diff --git a/.gitignore b/.gitignore
index aa9230d7a2f0ea3062d8aac1004ad5c5c76e975e..e8bf92ead3e777406d4d7eeb484af7b7bfb557ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,3 @@ dist
 venv
 .vscode
 docs/build
-
-errors.log
-info.log
\ No newline at end of file
diff --git a/inventory_provider/__init__.py b/inventory_provider/__init__.py
index 73970ad2da56426c1820c43a4a87c8d620bf6471..7adb3e53f0fcd61c4ad780df3f927271a3f5f6ec 100644
--- a/inventory_provider/__init__.py
+++ b/inventory_provider/__init__.py
@@ -9,13 +9,15 @@ from flask_cors import CORS
 
 from inventory_provider import environment
 
 
-def create_app():
+def create_app(setup_logging=True):
     """
     overrides default settings with those found
     in the file read from env var SETTINGS_FILENAME
 
     :return: a new flask app instance
     """
+    if setup_logging:
+        environment.setup_logging()
     required_env_vars = [
         'FLASK_SETTINGS_FILENAME', 'INVENTORY_PROVIDER_CONFIG_FILENAME']
@@ -87,6 +89,4 @@ def create_app():
 
     logging.info('Inventory Provider Flask app initialized')
 
-    environment.setup_logging()
-
     return app
diff --git a/inventory_provider/app.py b/inventory_provider/app.py
index afb7ee739917f31ddae54a16281cd4be989f06c1..c26d2a0a07b5d2b484b0101d82c3ac3dddc2ebd2 100644
--- a/inventory_provider/app.py
+++ b/inventory_provider/app.py
@@ -7,7 +7,6 @@ import sentry_sdk
 from sentry_sdk.integrations.flask import FlaskIntegration
 
 import inventory_provider
-from inventory_provider import environment
 
 sentry_dsn = os.getenv('SENTRY_DSN')
 if sentry_dsn:
@@ -16,8 +15,6 @@ if sentry_dsn:
         integrations=[FlaskIntegration()],
         release=pkg_resources.get_distribution('inventory-provider').version)
 
-environment.setup_logging()
-
 app = inventory_provider.create_app()
 
 if __name__ == "__main__":
diff --git a/inventory_provider/tasks/monitor.py b/inventory_provider/tasks/monitor.py
index 258c751197ccc4ad584a31aab6e4973e97c71c28..7fabb166a3f8e35ce311c7e139bce0dc1978485b 100644
--- a/inventory_provider/tasks/monitor.py
+++ b/inventory_provider/tasks/monitor.py
@@ -61,11 +61,12 @@ def _save_proc(db_queue, params, dbid):
     # TODO: do something to terminate the process ...?
 
 
-def run():
+def run(setup_logging=True):
     """
     save 'task-*' events to redis (all databases), never returns
     """
-    environment.setup_logging()
+    if setup_logging:
+        environment.setup_logging()
 
     with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
         logging.info(
diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py
index cbfeb3c76a57882ad8661771bda87a3e5578f87c..a638589739a53e960982c16227bfba59d1ad9f3b 100644
--- a/inventory_provider/tasks/worker.py
+++ b/inventory_provider/tasks/worker.py
@@ -9,6 +9,7 @@ from typing import List
 import ncclient.transport.errors
 from celery import Task, states, chord
 from celery.result import AsyncResult
+from celery import signals
 
 from collections import defaultdict
 
@@ -36,12 +37,16 @@
 FINALIZER_TIMEOUT_S = 300
 
 # TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501
 
-environment.setup_logging()
 logger = logging.getLogger(__name__)
 log_task_entry_and_exit = functools.partial(
     environment.log_entry_and_exit, logger=logger)
 
 
+@signals.after_setup_logger.connect
+def setup_logging(conf=None, **kwargs):
+    environment.setup_logging()
+
+
 class InventoryTaskError(Exception):
     pass
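For context: `after_setup_logger` is Celery's documented signal for deferring logging configuration until a worker actually starts, which is what replaces the import-time `environment.setup_logging()` call above. A minimal standalone sketch of the pattern (the app name and in-memory broker here are placeholders, not taken from this repo):

```python
import logging

from celery import Celery, signals

app = Celery('sketch', broker='memory://')  # placeholder name/broker


@signals.after_setup_logger.connect
def setup_logging(*args, **kwargs):
    # fires once per worker, after Celery has configured its logger,
    # so merely importing the module no longer mutates global logging state
    logging.getLogger(__name__).info('worker logging configured')
```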
diff --git a/test/conftest.py b/test/conftest.py
index a0e90799d5cf488d0cd2fea2746d56a1d30c196d..245f1f3b97d78dba94b814833eed2a6fc13c4b6d 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,11 +1,13 @@
 import ast
 import contextlib
+import copy
+from functools import lru_cache
 import json
 import netifaces
 import os
+import pathlib
 import re
 import tempfile
-import threading
 
 from lxml import etree
 import pytest
@@ -14,11 +16,32 @@ import inventory_provider
 from inventory_provider.tasks import worker
 from inventory_provider import config
 
-TEST_DATA_DIRNAME = os.path.join(
-    os.path.dirname(__file__),
-    "data")
 
-_bootstrap_semaphore = threading.Semaphore()
+TEST_DATA_DIR = pathlib.Path(__file__).parent / "data"
+
+
+@lru_cache()
+def read_json_test_data(filename: str):
+    return json.loads(TEST_DATA_DIR.joinpath(filename).read_text())
+
+
+def without_cache(data: dict):
+    data = copy.copy(data)
+
+    def _is_cache(s):
+        if s.startswith('classifier-cache'):
+            return True
+        if s.startswith('joblog'):
+            return True
+        return False
+    keys_to_delete = filter(_is_cache, data.keys())
+    for k in list(keys_to_delete):
+        del data[k]
+    return data
+
+
+DB_DATA = read_json_test_data("router-info.json")
+DB_DATA_NOCACHE = without_cache(DB_DATA)
 
 
 @pytest.fixture
@@ -83,8 +106,7 @@ def data_config_filename():
         ]
     }
 
-    with open(os.path.join(TEST_DATA_DIRNAME, 'gws-direct.json')) as gws:
-        config['gws-direct'] = json.loads(gws.read())
+    config['gws-direct'] = read_json_test_data('gws-direct.json')
 
     f.write(json.dumps(config).encode('utf-8'))
     f.flush()
@@ -97,27 +119,69 @@ def data_config(data_config_filename):
     return config.load(f)
 
 
-class MockedRedis(object):
+class DictOverlay:
+    """Since we're dealing with a very large dictionary from 'router-info.json' we only
+    want to read that dictionary once and prevent mutating it during tests. We instead
+    use this class to capture all mutations in an additional layer, reset for every test.
 
-    db = None
+    Not thread safe.
+    """
+    TOMBSTONE = object()  # records deletions
 
-    def __init__(self, *args, **kwargs):
-        _bootstrap_semaphore.acquire()
+    def __init__(self, base_dict: dict) -> None:
+        self.dict = base_dict
+        self.overlay = {}
+
+    def __getitem__(self, key):
+        if key not in self.overlay:
+            return self.dict[key]
+        value = self.overlay[key]
+        if value is self.TOMBSTONE:
+            raise KeyError(key)
+        return value
+
+    def __setitem__(self, key, value):
+        self.overlay[key] = value
+
+    def __delitem__(self, key):
+        self.overlay[key] = self.TOMBSTONE
+
+    def __contains__(self, key):
+        if key in self.overlay:
+            return self.overlay[key] is not self.TOMBSTONE
+        return key in self.dict
+
+    def get(self, key, default=None):
         try:
-            if MockedRedis.db is None:
-                MockedRedis.prep()
-        finally:
-            _bootstrap_semaphore.release()
-
-    # allows us to create other mocks using a different data source file
-    @staticmethod
-    def prep(data_source_file="router-info.json"):
-        test_data_filename = os.path.join(
-            TEST_DATA_DIRNAME,
-            data_source_file)
-        with open(test_data_filename) as f:
-            MockedRedis.db = json.loads(f.read())
-        MockedRedis.db['db:latch'] = json.dumps({
+            return self[key]
+        except KeyError:
+            return default
+
+    def keys(self):
+        deleted_keys = {k for k, v in self.overlay.items() if v is self.TOMBSTONE}
+        return (self.dict.keys() | self.overlay.keys()) - deleted_keys
+
+    def items(self):
+        return ((key, self[key]) for key in self.keys())
+
+    _missing = object()  # sentinel: distinguishes "no default given" from None
+
+    def pop(self, key, default=_missing):
+        try:
+            value = self[key]
+        except KeyError:
+            if default is self._missing:
+                raise
+            return default
+
+        self.overlay[key] = self.TOMBSTONE
+        return value
+
+
+class MockedRedis:
+    def __init__(self, *args, **kwargs):
+        self.db = DictOverlay(DB_DATA_NOCACHE)
+        self.db['db:latch'] = json.dumps({
             'current': 0,
             'next': 0,
             'this': 0,
@@ -125,44 +189,33 @@
             'failure': False
         })
 
-        # remove any cached data from the captured snapshot
-        def _is_cache(s):
-            if s.startswith('classifier-cache'):
-                return True
-            if s.startswith('joblog'):
-                return True
-            return False
-        keys_to_delete = filter(_is_cache, MockedRedis.db.keys())
-        for k in list(keys_to_delete):
-            del MockedRedis.db[k]
-
     def set(self, name, value):
-        MockedRedis.db[name] = value
+        self.db[name] = value
 
     def get(self, name):
-        value = MockedRedis.db.get(name, None)
+        value = self.db.get(name, None)
         if value is None:
            return None
         return value.encode('utf-8')
 
     def exists(self, name):
-        return name in MockedRedis.db
+        return name in self.db
 
     def delete(self, key):
         if isinstance(key, bytes):
             key = key.decode('utf-8')
-        # redis ignores delete for keys that don't exist
-        if key in MockedRedis.db:
-            del MockedRedis.db[key]
+
+        if key in self.db:
+            del self.db[key]
 
     def scan_iter(self, glob=None, count='unused'):
         if not glob:
-            for k in list(MockedRedis.db.keys()):
+            for k in list(self.db.keys()):
                 yield k.encode('utf-8')
 
         m = re.match(r'^([^*]+)\*$', glob)
         assert m  # all expected globs are like this
-        for k in list(MockedRedis.db.keys()):
+        for k in list(self.db.keys()):
             if k.startswith(m.group(1)):
                 yield k.encode('utf-8')
@@ -180,16 +233,13 @@
         return self
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def cached_test_data():
-    filename = os.path.join(TEST_DATA_DIRNAME, "router-info.json")
-    with open(filename) as f:
-        return json.loads(f.read())
+    return DB_DATA
 
 
 @pytest.fixture
 def flask_config_filename():
-
     with tempfile.NamedTemporaryFile() as f:
         f.write('ENABLE_TESTING_ROUTES = True\n'.encode('utf-8'))
         f.flush()
@@ -198,17 +248,17 @@ def flask_config_filename():
 
 @pytest.fixture
 def mocked_redis(mocker):
-    MockedRedis.db = None  # force data to be reloaded
+    instance = MockedRedis()
     mocker.patch(
         'inventory_provider.tasks.common.redis.StrictRedis',
-        MockedRedis)
+        return_value=instance)
 
 
 @pytest.fixture
 def client(flask_config_filename, data_config_filename, mocked_redis):
     os.environ['FLASK_SETTINGS_FILENAME'] = flask_config_filename
     os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename
-    with inventory_provider.create_app().test_client() as c:
+    with inventory_provider.create_app(setup_logging=False).test_client() as c:
        yield c
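`DictOverlay` in a nutshell: reads fall through to the base dict, while writes and deletes are captured in a per-instance overlay, so the large shared snapshot is parsed once and never mutated. A quick illustration of the semantics, assuming the `DictOverlay` class added above is importable (toy data, not from the test fixtures):

```python
base = {'a': 1, 'b': 2}
overlay = DictOverlay(base)

overlay['c'] = 3   # write lands in the overlay only
del overlay['a']   # recorded as a TOMBSTONE; base still has 'a'

assert sorted(overlay.keys()) == ['b', 'c']
assert 'a' not in overlay
assert base == {'a': 1, 'b': 2}  # shared snapshot untouched
```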
diff --git a/test/per_router/conftest.py b/test/per_router/conftest.py
index bb581421d103815471fa6229ac4ab6073a382d83..7f4842c96ded38787ceaa90956424f0d17d57289 100644
--- a/test/per_router/conftest.py
+++ b/test/per_router/conftest.py
@@ -1,6 +1,5 @@
-import glob
 import json
-import os
+import pathlib
 import re
 
 from unittest.mock import patch
@@ -10,41 +9,31 @@ from ncclient.manager import make_device_handler, Manager
 from ncclient.transport import SSHSession
 from ncclient.xml_ import NCElement
 
-import inventory_provider
 from inventory_provider import juniper
 
-TEST_DATA_DIRNAME = os.path.realpath(os.path.join(
-    inventory_provider.__path__[0],
-    "..",
-    "test",
-    "data"))
+
+TEST_DATA_DIR = pathlib.Path(__file__).parent / "../data"
 
 
 @pytest.fixture
 def classifier_cache_test_entries():
-    filename = os.path.join(
-        TEST_DATA_DIRNAME, 'classifier-cache-entries.json')
-    with open(filename) as f:
-        return json.loads(f.read())
-
+    file = TEST_DATA_DIR.joinpath('classifier-cache-entries.json')
+    return json.loads(file.read_text())
 
-def pytest_generate_tests(metafunc):
-    # TODO: can we really not get netconf data for all routers?
-    def _available_netconf_hosts():
-        for fn in glob.glob(os.path.join(TEST_DATA_DIRNAME, '*-netconf.xml')):
-            m = re.match('(.*)-netconf.xml', os.path.basename(fn))
-            assert m  # sanity
-            yield m.group(1)
-
-    routers = list(_available_netconf_hosts())
-    metafunc.parametrize("router", routers)
+
+@pytest.fixture(params=list(TEST_DATA_DIR.glob('*-netconf.xml')))
+def router(request):
+    file: pathlib.Path = request.param
+    m = re.match('(.*)-netconf.xml', file.name)
+    assert m  # sanity
+    return m.group(1)
 
 
 class MockedJunosRpc(object):
 
     def __init__(self, hostname):
-        filename = os.path.join(TEST_DATA_DIRNAME, "%s-netconf.xml" % hostname)
-        self.config = etree.parse(filename)
+        file = TEST_DATA_DIR.joinpath(f'{hostname}-netconf.xml')
+        self.config = etree.parse(file)
 
     def get_config(self):
         return self.config
@@ -74,11 +63,10 @@ def netconf_doc(mocker, router, data_config):
 
 @pytest.fixture
 def interface_info_response(router):
-    filename = os.path.join(TEST_DATA_DIRNAME, 'interface_info', f'{router}.xml')
+    file = TEST_DATA_DIR / f"interface_info/{router}.xml"
+
     try:
-        with open(filename, 'r') as file:
-            data = file.read()
-        return data
+        return file.read_text()
     except FileNotFoundError:
         pytest.skip(f'no corresponding interface_info doc for {router}, skipping')
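Replacing `pytest_generate_tests` with a parametrized fixture keeps the same fan-out: every test that requests `router` still runs once per `*-netconf.xml` file. A self-contained sketch of the mechanism (the `data` directory and `*.xml` pattern here are hypothetical):

```python
import pathlib

import pytest

DATA_DIR = pathlib.Path(__file__).parent / 'data'  # hypothetical layout


@pytest.fixture(params=sorted(DATA_DIR.glob('*.xml')), ids=lambda p: p.stem)
def doc(request):
    # pytest collects one test case per matching file
    return request.param.read_text()


def test_not_empty(doc):
    assert doc.strip()
```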
diff --git a/test/per_router/test_celery_worker.py b/test/per_router/test_celery_worker.py
index fb2da012f81e6d36ad88355f266e5c183a1733bb..50734ddbdd510944660600851e1a0ee4b6b05eb5 100644
--- a/test/per_router/test_celery_worker.py
+++ b/test/per_router/test_celery_worker.py
@@ -4,6 +4,7 @@ and some data ends up in the right place ...
 otherwise not very detailed
 """
 import re
+import ncclient.transport.errors
 import pytest
 
 from inventory_provider.tasks import worker
@@ -20,7 +21,15 @@ def backend_db():
     }).db
 
 
-def test_netconf_refresh_config(mocked_worker_module, router):
+@pytest.fixture
+def mocked_juniper(mocker):
+    return mocker.patch(
+        "inventory_provider.juniper.get_interface_info_for_router",
+        side_effect=ncclient.transport.errors.SSHError
+    )
+
+
+def test_netconf_refresh_config(mocked_worker_module, router, mocked_juniper):
     if router in ['qfx.par.fr.geant.net', 'qfx.fra.de.geant.net']:
         # expected to fail
         pytest.skip(f'test data has no community string for {router}')
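On the `mocked_juniper` fixture above: passing an exception class as `side_effect` makes every call to the patched target raise it, so the test exercises the worker's SSH failure path without a live device. The behaviour in isolation, using only the standard library's `unittest.mock` (the hostname argument is a made-up example):

```python
from unittest import mock

import ncclient.transport.errors

fetch = mock.Mock(side_effect=ncclient.transport.errors.SSHError)

try:
    fetch('router.example.net')  # any call raises the configured error
except ncclient.transport.errors.SSHError:
    print('SSHError raised as configured')
```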
diff --git a/test/test_monitoring.py b/test/test_monitoring.py
index 66fa82851b4a4ccbeba942185f20a9567699e0e8..9365a3a869d600e1978aa143d2fb5bbaa9d80619 100644
--- a/test/test_monitoring.py
+++ b/test/test_monitoring.py
@@ -64,7 +64,7 @@ def backend_db():
 def test_latchdb(data_config_filename, mocked_redis):
 
     os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename
-    monitor.run()
+    monitor.run(setup_logging=False)
 
     db = backend_db()
diff --git a/test/test_worker.py b/test/test_worker.py
index 5d3c2c4f8c2955cd65a624890a91d46b68ea64bd..50be2e8f0acd6639716598ae2ac033718fa8e109 100644
--- a/test/test_worker.py
+++ b/test/test_worker.py
@@ -601,26 +601,26 @@ def test_persist_ims_data(mocker, data_config, mocked_redis):
             r.delete(k)
 
     persist_ims_data(data)
 
-    assert [k.decode("utf-8") for k in r.keys("ims:pop_nodes:*")] == \
-        ["ims:pop_nodes:LOC A", "ims:pop_nodes:LOC B"]
+    assert {k.decode("utf-8") for k in r.keys("ims:pop_nodes:*")} == \
+        {"ims:pop_nodes:LOC A", "ims:pop_nodes:LOC B"}
 
-    assert [k.decode("utf-8") for k in r.keys("ims:location:*")] == \
-        ["ims:location:eq_a", "ims:location:eq_b"]
+    assert {k.decode("utf-8") for k in r.keys("ims:location:*")} == \
+        {"ims:location:eq_a", "ims:location:eq_b"}
 
-    assert [k.decode("utf-8") for k in r.keys("ims:lg:*")] == \
-        ["ims:lg:lg_eq1", "ims:lg:lg_eq2"]
+    assert {k.decode("utf-8") for k in r.keys("ims:lg:*")} == \
+        {"ims:lg:lg_eq1", "ims:lg:lg_eq2"}
 
-    assert [k.decode("utf-8") for k in r.keys("ims:circuit_hierarchy:*")] == \
-        ["ims:circuit_hierarchy:123", "ims:circuit_hierarchy:456"]
+    assert {k.decode("utf-8") for k in r.keys("ims:circuit_hierarchy:*")} == \
+        {"ims:circuit_hierarchy:123", "ims:circuit_hierarchy:456"}
 
-    assert [k.decode("utf-8") for k in r.keys("ims:interface_services:*")] == \
-        ["ims:interface_services:if1", "ims:interface_services:if2"]
+    assert {k.decode("utf-8") for k in r.keys("ims:interface_services:*")} == \
+        {"ims:interface_services:if1", "ims:interface_services:if2"}
 
-    assert [k.decode("utf-8") for k in r.keys("ims:node_pair_services:*")] == \
-        ["ims:node_pair_services:np1", "ims:node_pair_services:np2"]
+    assert {k.decode("utf-8") for k in r.keys("ims:node_pair_services:*")} == \
+        {"ims:node_pair_services:np1", "ims:node_pair_services:np2"}
 
-    assert [k.decode("utf-8") for k in r.keys("poller_cache:*")] == \
-        ["poller_cache:eq1", "poller_cache:eq2"]
+    assert {k.decode("utf-8") for k in r.keys("poller_cache:*")} == \
+        {"poller_cache:eq1", "poller_cache:eq2"}
 
     assert json.loads(r.get("ims:sid_services").decode("utf-8")) == \
         data["sid_services"]
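The list-to-set switch in these assertions matters because Redis gives no ordering guarantee for key scans, and the mocked backend now iterates a set-backed `DictOverlay.keys()`, so result order is arbitrary. Comparing as sets checks membership only; a minimal illustration (toy key names):

```python
returned = [b'poller_cache:eq2', b'poller_cache:eq1']  # arbitrary order

# an order-sensitive list comparison would be flaky here ...
assert [k.decode('utf-8') for k in returned] != \
    ['poller_cache:eq1', 'poller_cache:eq2']

# ... while the set comparison is stable regardless of order
assert {k.decode('utf-8') for k in returned} == \
    {'poller_cache:eq1', 'poller_cache:eq2'}
```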