diff --git a/resource_management/routes/interfaces.py b/resource_management/routes/interfaces.py index edf26cd5521c09fd8b37c938b83536209e1ddc61..6dfa7a297aef6753c3a50411b55ea5c5520b32a5 100644 --- a/resource_management/routes/interfaces.py +++ b/resource_management/routes/interfaces.py @@ -10,6 +10,7 @@ from resource_management import router_interfaces router = APIRouter() +FIRST_LAG_INDEX = 0 class InterfaceCounts(pydantic.BaseModel): total: int @@ -85,22 +86,29 @@ async def reserve_next_lag(fqdn: str) -> NextLAG: db.init_db_model(params['db']) with db.session_scope() as session: - records = session.query(model.LAG.name) \ + router_record = session.query(model.Router.id) \ + .filter_by(fqdn=fqdn).one() + lag_rows = session.query(model.LAG.name) \ .join(model.Router).filter_by(fqdn=fqdn).all() - names = set(r[0] for r in records) + lag_names = set(r[0] for r in lag_rows) + + def _next_lag_name(): + index = FIRST_LAG_INDEX + while True: + candidate_name = f'ae{index}' + if candidate_name not in lag_names: + new_lag_record = model.LAG( + name=candidate_name, + router_id=router_record[0], + availability=model.AvalabilityStates.RESERVED.name) + session.add(new_lag_record) + return candidate_name + index += 1 - def _next_lag_name(): - index = 1 - while True: - candidate = f'ae{index}' - if candidate not in names: - return candidate - index += 1 - - return { - 'fqdn': fqdn, - 'name': _next_lag_name() - } + return { + 'fqdn': fqdn, + 'name': _next_lag_name() + } @router.post('/next-physical/{fqdn}/{lag_name}') @@ -121,8 +129,12 @@ async def reserve_physical_bundle_member( def _find_available_physical(): for ifc in router.physical: - if not ifc.lag: + if ifc.availability \ + == model.AvalabilityStates.AVAILABLE.name: + ifc.availability = model.AvalabilityStates.RESERVED.name + session.add(ifc) return ifc.name + raise HTTPException( status_code=404, detail=f'no available physical ports for "{lag_name}"') diff --git a/test/test_interfaces_routes.py b/test/test_interfaces_routes.py index e9923a13bc152bb428926125b2a781136817e695..d1d4bfcfbc22f4cc6b9ba10fad1e2b37db79a0ea 100644 --- a/test/test_interfaces_routes.py +++ b/test/test_interfaces_routes.py @@ -3,6 +3,10 @@ import jsonschema from resource_management.routes.default import Version from resource_management.routes import interfaces from resource_management import router_interfaces +from resource_management import db +from resource_management.db import model + +import pytest def test_bad_method(client): @@ -29,13 +33,48 @@ def test_next_lag(client, resources_db, mocked_router, router_name): rv = client.post(f'/api/interfaces/next-lag/{router_name}') assert rv.status_code == 200 - jsonschema.validate(rv.json(), interfaces.NextLAG.schema()) + response = rv.json() + jsonschema.validate(response, interfaces.NextLAG.schema()) + + with db.session_scope() as session: + row = session.query(model.LAG). \ + filter_by(name=response['name']). \ + join(model.Router). \ + filter_by(fqdn=router_name).one() + assert row.availability == model.AvalabilityStates.RESERVED.name + + def test_next_physical(client, resources_db, mocked_router, router_name): + if '.lab.' in router_name: + pytest.skip('not all lab routers have available ports') + if router_name.startswith('rt'): + pytest.skip('not all rt* have available ports') + router_interfaces.load_new_router_interfaces(router_name) - rv = client.post(f'/api/interfaces/next-physical/{router_name}/ae123123') + with db.session_scope() as session: + rows = session.query(model.LAG.name). \ + join(model.Router). \ + filter_by(fqdn=router_name).all() + lag_name = rows[0][0] + + rv = client.post(f'/api/interfaces/next-physical/{router_name}/{lag_name}') assert rv.status_code == 200 - jsonschema.validate(rv.json(), interfaces.NextPhysicalInterface.schema()) + + response = rv.json() + jsonschema.validate(response, interfaces.NextPhysicalInterface.schema()) + + assert response['fqdn'] == router_name + assert response['lag'] == lag_name + + with db.session_scope() as session: + ifc_row = session.query(model.PhysicalInterface). \ + filter_by(name=response['name']). \ + join(model.Router). \ + filter_by(fqdn=router_name). \ + join(model.LAG). \ + filter_by(name=lag_name).one() + assert ifc_row.availability == model.AvalabilityStates.RESERVED.name