diff --git a/gso/api/api_v1/endpoints/imports.py b/gso/api/api_v1/endpoints/imports.py index 74390a80af68c197e7757d056996c015633916d3..c83b0919d70ae5988885ee1d4efd1b874589330d 100644 --- a/gso/api/api_v1/endpoints/imports.py +++ b/gso/api/api_v1/endpoints/imports.py @@ -81,11 +81,10 @@ class RouterImportModel(BaseModel): ts_port: int router_vendor: RouterVendor router_role: RouterRole - is_ias_connected: Optional[bool] = None - router_access_via_ts: Optional[bool] = None - router_lo_ipv4_address: Optional[ipaddress.IPv4Address] = None - router_lo_ipv6_address: Optional[ipaddress.IPv6Address] = None - router_lo_iso_address: Optional[str] = None + is_ias_connected: bool + router_lo_ipv4_address: ipaddress.IPv4Address + router_lo_ipv6_address: ipaddress.IPv6Address + router_lo_iso_address: str router_si_ipv4_network: Optional[ipaddress.IPv4Network] = None router_ias_lt_ipv4_network: Optional[ipaddress.IPv4Network] = None router_ias_lt_ipv6_network: Optional[ipaddress.IPv6Network] = None diff --git a/gso/main.py b/gso/main.py index d94b25391a9238f0a9fafc75239d99dacb49a550..9c8223bb7da91ecc554bea1e38e9faf1e787feb1 100644 --- a/gso/main.py +++ b/gso/main.py @@ -9,8 +9,11 @@ import gso.workflows # noqa: F401 from gso import load_gso_cli from gso.api.api_v1.api import api_router -app = OrchestratorCore(base_settings=AppSettings()) -app.include_router(api_router, prefix="/api") + +def init_gso_app(settings: AppSettings) -> OrchestratorCore: + app = OrchestratorCore(base_settings=settings) + app.include_router(api_router, prefix="/api") + return app def init_cli_app() -> typer.Typer: @@ -18,5 +21,7 @@ def init_cli_app() -> typer.Typer: return core_cli() +app = init_gso_app(settings=AppSettings()) + if __name__ == "__main__": init_cli_app() diff --git a/gso/workflows/tasks/import_router.py b/gso/workflows/tasks/import_router.py index 5b0468b6e23f10060e49ad65f10678190af8d5a4..547f8c8ae29f79d58b80a3ec14e94b8890a8bb91 100644 --- a/gso/workflows/tasks/import_router.py +++ b/gso/workflows/tasks/import_router.py @@ -47,11 +47,10 @@ def initial_input_form_generator() -> FormGenerator: ts_port: int router_vendor: RouterVendor router_role: RouterRole - is_ias_connected: Optional[bool] = None - router_access_via_ts: Optional[bool] = None - router_lo_ipv4_address: Optional[ipaddress.IPv4Address] = None - router_lo_ipv6_address: Optional[ipaddress.IPv6Address] = None - router_lo_iso_address: Optional[str] = None + is_ias_connected: bool + router_lo_ipv4_address: ipaddress.IPv4Address + router_lo_ipv6_address: ipaddress.IPv6Address + router_lo_iso_address: str router_si_ipv4_network: Optional[ipaddress.IPv4Network] = None router_ias_lt_ipv4_network: Optional[ipaddress.IPv4Network] = None router_ias_lt_ipv6_network: Optional[ipaddress.IPv6Network] = None @@ -72,7 +71,7 @@ def get_site_by_name(site_name: str) -> Site: .first() ) if not subscription: - raise ValueError(f"Site with name {site_name} not found") + raise ValueError(f"Site with name {site_name} not found.") return Site.from_subscription(subscription.subscription_id) diff --git a/gso/workflows/tasks/import_site.py b/gso/workflows/tasks/import_site.py index 22bd55d91852ced80e48ed44d0b520d938f9c781..6402ae2a31cf05234200cc53cca17afd987fb3f1 100644 --- a/gso/workflows/tasks/import_site.py +++ b/gso/workflows/tasks/import_site.py @@ -1,4 +1,4 @@ -from uuid import UUID, uuid4 +from uuid import UUID from orchestrator.db.models import ProductTable from orchestrator.forms import FormPage diff --git a/requirements.txt b/requirements.txt index b454c47cb2b55eef9fab6467b91c60dc2d5cbe8e..91809af8718ac3417dd76cc0f01e024c5c67a3cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pydantic requests pytest +faker responses black isort diff --git a/test/conftest.py b/test/conftest.py index 75ce00bc67ee6cee82d4b1ba305b7855daf9adf1..138adc0bb496222e54b365d4c6426a257c5bd808 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,8 +3,21 @@ import json import os import socket import tempfile +from pathlib import Path +import orchestrator import pytest +from alembic import command +from alembic.config import Config +from orchestrator import app_settings +from orchestrator.db import Database, db +from orchestrator.db.database import ENGINE_ARGUMENTS, SESSION_ARGUMENTS, BaseModel +from sqlalchemy import create_engine +from sqlalchemy.engine import make_url +from sqlalchemy.orm import scoped_session, sessionmaker +from starlette.testclient import TestClient + +from gso.main import init_gso_app @pytest.fixture(scope="session") @@ -74,3 +87,128 @@ def data_config_filename(configuration_data) -> str: os.environ["OSS_PARAMS_FILENAME"] = f.name yield f.name + + +@pytest.fixture(scope="session") +def db_uri(): + """Provide the database uri configuration to run the migration on.""" + + return os.environ.get("DATABASE_URI_TEST", "postgresql://nwa:nwa@localhost/nwa-workflows-test") + + +def run_migrations(db_uri: str) -> None: + """Configure the alembic migration and run the migration on the database. + + Args: + ---- + db_uri: The database uri configuration to run the migration on. + + Returns: + ------- + None + """ + + path = Path(__file__).resolve().parent + app_settings.DATABASE_URI = db_uri + alembic_cfg = Config(file_=path / "../gso/alembic.ini") + alembic_cfg.set_main_option("sqlalchemy.url", db_uri) + + alembic_cfg.set_main_option("script_location", str(path / "../gso/migrations")) + version_locations = alembic_cfg.get_main_option("version_locations") + alembic_cfg.set_main_option( + "version_locations", f"{version_locations} {os.path.dirname(orchestrator.__file__)}/migrations/versions/schema" + ) + + command.upgrade(alembic_cfg, "heads") + + +@pytest.fixture(scope="session") +def database(db_uri): + """Create database and run migrations and cleanup after wards. + + Args: + ---- + db_uri: The database uri configuration to run the migration on. + """ + + db.update(Database(db_uri)) + url = make_url(db_uri) + db_to_create = url.database + url = url.set(database="postgres") + + engine = create_engine(url) + with engine.connect() as conn: + conn.execute("COMMIT;") + conn.execute(f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname='{db_to_create}';") # noqa + conn.execute(f'DROP DATABASE IF EXISTS "{db_to_create}";') # noqa + conn.execute("COMMIT;") + conn.execute(f'CREATE DATABASE "{db_to_create}";') # noqa + + run_migrations(db_uri) + db.wrapped_database.engine = create_engine(db_uri, **ENGINE_ARGUMENTS) + + try: + yield + finally: + db.wrapped_database.engine.dispose() + with engine.connect() as conn: + conn.execute("COMMIT;") + # Terminate all connections to the database + conn.execute( + f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname='{db_to_create}';" # noqa + ) + conn.execute(f'DROP DATABASE IF EXISTS "{db_to_create}";') # noqa + + +@pytest.fixture(autouse=True) +def db_session(database): + """Ensure that tests are executed within a transactional scope that automatically rolls back after completion. + + This fixture facilitates a pattern known as 'transactional tests'. At the start, it establishes a connection and + begins an overarching transaction. Any database operations performed within the test function—whether they commit + or not happen within the context of this master transaction. + + From the perspective of the test function, it seems as though changes are getting committed to the database, + enabling the tests to query and assert the persistence of data. Yet, once the test completes, this fixture + intervenes to roll back the master transaction. This ensures a clean slate after each test, preventing tests from + polluting the database state for subsequent tests. + + Benefits: + - Each test runs in isolation with a pristine database state. + - Avoids the overhead of recreating the database schema or re-seeding data between tests. + + Args: + ---- + database: A fixture reference that initializes the database. + """ + + with contextlib.closing(db.wrapped_database.engine.connect()) as test_connection: + # Create a new session factory for this context. + session_factory = sessionmaker(bind=test_connection, **SESSION_ARGUMENTS) + scoped_session_instance = scoped_session(session_factory, scopefunc=db.wrapped_database._scopefunc) + + # Point the database session to this new scoped session. + db.wrapped_database.session_factory = session_factory + db.wrapped_database.scoped_session = scoped_session_instance + + # Set the query for the base model. + BaseModel.set_query(scoped_session_instance.query_property()) + transaction = test_connection.begin() + try: + yield + finally: + transaction.rollback() + scoped_session_instance.remove() + + +@pytest.fixture(scope="session", autouse=True) +def fastapi_app(database, db_uri): + """Load the GSO FastAPI app for testing purposes.""" + + app_settings.DATABASE_URI = db_uri + return init_gso_app(settings=app_settings) + + +@pytest.fixture(scope="session") +def test_client(fastapi_app): + return TestClient(fastapi_app) diff --git a/test/test_imports.py b/test/test_imports.py new file mode 100644 index 0000000000000000000000000000000000000000..27ea48e6dd1336e02c8143bfc6ce5916399356bf --- /dev/null +++ b/test/test_imports.py @@ -0,0 +1,105 @@ +import pytest +from faker import Faker +from orchestrator.db import SubscriptionTable +from orchestrator.services import subscriptions + +from gso.products.product_blocks.router import RouterRole, RouterVendor +from gso.products.product_blocks.site import SiteTier + + +class TestImportEndpoints: + @pytest.fixture(autouse=True) + def setup(self, test_client): + self.faker = Faker() + self.client = test_client + self.site_import_endpoint = "/api/imports/sites" + self.router_import_endpoint = "/api/imports/routers" + self.site_data = { + "site_name": self.faker.name(), + "site_city": self.faker.city(), + "site_country": self.faker.country(), + "site_country_code": self.faker.country_code(), + "site_latitude": float(self.faker.latitude()), + "site_longitude": float(self.faker.longitude()), + "site_bgp_community_id": self.faker.pyint(), + "site_internal_id": self.faker.pyint(), + "site_tier": SiteTier.TIER1, + "site_ts_address": self.faker.ipv4(), + "customer": "Geant", + } + self.router_data = { + "hostname": "127.0.0.1", + "router_role": RouterRole.PE, + "router_vendor": RouterVendor.JUNIPER, + "router_site": self.site_data["site_name"], + "ts_port": 1234, + "customer": "Geant", + "is_ias_connected": True, + "router_lo_ipv4_address": self.faker.ipv4(), + "router_lo_ipv6_address": self.faker.ipv6(), + "router_lo_iso_address": "TestAddress", + } + + def test_import_site_endpoint(self): + assert SubscriptionTable.query.all() == [] + # Post data to the endpoint + response = self.client.post(self.site_import_endpoint, json=self.site_data) + assert response.status_code == 201 + assert "detail" in response.json() + assert "pid" in response.json() + subscription = subscriptions.retrieve_subscription_by_subscription_instance_value( + resource_type="site_name", value=self.site_data["site_name"] + ) + assert subscription is not None + + def test_import_site_endpoint_with_existing_site(self): + response = self.client.post(self.site_import_endpoint, json=self.site_data) + assert SubscriptionTable.query.count() == 1 + assert response.status_code == 201 + + response = self.client.post(self.site_import_endpoint, json=self.site_data) + assert response.status_code == 409 + assert SubscriptionTable.query.count() == 1 + + def test_import_site_endpoint_with_invalid_data(self): + # invalid data, missing site_latitude and invalid site_longitude + site_data = self.site_data.copy() + site_data.pop("site_latitude") + site_data["site_longitude"] = "invalid" + assert SubscriptionTable.query.count() == 0 + response = self.client.post(self.site_import_endpoint, json=site_data) + assert response.status_code == 422 + assert SubscriptionTable.query.count() == 0 + response = response.json() + assert response["detail"][0]["loc"] == ["body", "site_latitude"] + assert response["detail"][0]["msg"] == "field required" + assert response["detail"][1]["loc"] == ["body", "site_longitude"] + assert response["detail"][1]["msg"] == "value is not a valid float" + + def test_import_router_endpoint(self): + # Create a site first + response = self.client.post(self.site_import_endpoint, json=self.site_data) + assert response.status_code == 201 + assert SubscriptionTable.query.count() == 1 + + response = self.client.post(self.router_import_endpoint, json=self.router_data) + assert response.status_code == 201 + assert SubscriptionTable.query.count() == 2 + + def test_import_router_endpoint_with_invalid_data(self): + response = self.client.post(self.site_import_endpoint, json=self.site_data) + assert response.status_code == 201 + assert SubscriptionTable.query.count() == 1 + + # invalid data, missing hostname and invalid router_lo_ipv6_address + router_data = self.router_data.copy() + router_data.pop("hostname") + router_data["router_lo_ipv6_address"] = "invalid" + response = self.client.post(self.router_import_endpoint, json=router_data) + assert response.status_code == 422 + assert SubscriptionTable.query.count() == 1 + response = response.json() + assert response["detail"][0]["loc"] == ["body", "hostname"] + assert response["detail"][0]["msg"] == "field required" + assert response["detail"][1]["loc"] == ["body", "router_lo_ipv6_address"] + assert response["detail"][1]["msg"] == "value is not a valid IPv6 address"