Skip to content
Snippets Groups Projects
Commit 078fda02 authored by Pelle Koster's avatar Pelle Koster
Browse files

Improve testing speed

parent e059a01c
No related branches found
No related tags found
1 merge request!29Improve testing speed
...@@ -12,6 +12,3 @@ dist ...@@ -12,6 +12,3 @@ dist
venv venv
.vscode .vscode
docs/build docs/build
errors.log
info.log
\ No newline at end of file
...@@ -9,13 +9,15 @@ from flask_cors import CORS ...@@ -9,13 +9,15 @@ from flask_cors import CORS
from inventory_provider import environment from inventory_provider import environment
def create_app(): def create_app(setup_logging=True):
""" """
overrides default settings with those found overrides default settings with those found
in the file read from env var SETTINGS_FILENAME in the file read from env var SETTINGS_FILENAME
:return: a new flask app instance :return: a new flask app instance
""" """
if setup_logging:
environment.setup_logging()
required_env_vars = [ required_env_vars = [
'FLASK_SETTINGS_FILENAME', 'INVENTORY_PROVIDER_CONFIG_FILENAME'] 'FLASK_SETTINGS_FILENAME', 'INVENTORY_PROVIDER_CONFIG_FILENAME']
...@@ -87,6 +89,4 @@ def create_app(): ...@@ -87,6 +89,4 @@ def create_app():
logging.info('Inventory Provider Flask app initialized') logging.info('Inventory Provider Flask app initialized')
environment.setup_logging()
return app return app
...@@ -7,7 +7,6 @@ import sentry_sdk ...@@ -7,7 +7,6 @@ import sentry_sdk
from sentry_sdk.integrations.flask import FlaskIntegration from sentry_sdk.integrations.flask import FlaskIntegration
import inventory_provider import inventory_provider
from inventory_provider import environment
sentry_dsn = os.getenv('SENTRY_DSN') sentry_dsn = os.getenv('SENTRY_DSN')
if sentry_dsn: if sentry_dsn:
...@@ -16,8 +15,6 @@ if sentry_dsn: ...@@ -16,8 +15,6 @@ if sentry_dsn:
integrations=[FlaskIntegration()], integrations=[FlaskIntegration()],
release=pkg_resources.get_distribution('inventory-provider').version) release=pkg_resources.get_distribution('inventory-provider').version)
environment.setup_logging()
app = inventory_provider.create_app() app = inventory_provider.create_app()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -61,11 +61,12 @@ def _save_proc(db_queue, params, dbid): ...@@ -61,11 +61,12 @@ def _save_proc(db_queue, params, dbid):
# TODO: do something to terminate the process ...? # TODO: do something to terminate the process ...?
def run(): def run(setup_logging=True):
""" """
save 'task-*' events to redis (all databases), never returns save 'task-*' events to redis (all databases), never returns
""" """
environment.setup_logging() if setup_logging:
environment.setup_logging()
with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f: with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
logging.info( logging.info(
......
...@@ -9,6 +9,7 @@ from typing import List ...@@ -9,6 +9,7 @@ from typing import List
import ncclient.transport.errors import ncclient.transport.errors
from celery import Task, states, chord from celery import Task, states, chord
from celery.result import AsyncResult from celery.result import AsyncResult
from celery import signals
from collections import defaultdict from collections import defaultdict
...@@ -36,12 +37,16 @@ FINALIZER_TIMEOUT_S = 300 ...@@ -36,12 +37,16 @@ FINALIZER_TIMEOUT_S = 300
# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks) # noqa: E501 # TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks) # noqa: E501
environment.setup_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
log_task_entry_and_exit = functools.partial( log_task_entry_and_exit = functools.partial(
environment.log_entry_and_exit, logger=logger) environment.log_entry_and_exit, logger=logger)
@signals.after_setup_logger.connect
def setup_logging(conf=None, **kwargs):
environment.setup_logging()
class InventoryTaskError(Exception): class InventoryTaskError(Exception):
pass pass
......
import ast import ast
import contextlib import contextlib
import copy
from functools import lru_cache
import json import json
import netifaces import netifaces
import os import os
import pathlib
import re import re
import tempfile import tempfile
import threading
from lxml import etree from lxml import etree
import pytest import pytest
...@@ -14,11 +16,32 @@ import inventory_provider ...@@ -14,11 +16,32 @@ import inventory_provider
from inventory_provider.tasks import worker from inventory_provider.tasks import worker
from inventory_provider import config from inventory_provider import config
TEST_DATA_DIRNAME = os.path.join(
os.path.dirname(__file__),
"data")
_bootstrap_semaphore = threading.Semaphore() TEST_DATA_DIR = pathlib.Path(__file__).parent / "data"
@lru_cache()
def read_json_test_data(filename: str):
return json.loads(TEST_DATA_DIR.joinpath(filename).read_text())
def without_cache(data: dict):
data = copy.copy(data)
def _is_cache(s):
if s.startswith('classifier-cache'):
return True
if s.startswith('joblog'):
return True
return False
keys_to_delete = filter(_is_cache, data.keys())
for k in list(keys_to_delete):
del data[k]
return data
DB_DATA = read_json_test_data("router-info.json")
DB_DATA_NOCACHE = without_cache(DB_DATA)
@pytest.fixture @pytest.fixture
...@@ -83,8 +106,7 @@ def data_config_filename(): ...@@ -83,8 +106,7 @@ def data_config_filename():
] ]
} }
with open(os.path.join(TEST_DATA_DIRNAME, 'gws-direct.json')) as gws: config['gws-direct'] = read_json_test_data('gws-direct.json')
config['gws-direct'] = json.loads(gws.read())
f.write(json.dumps(config).encode('utf-8')) f.write(json.dumps(config).encode('utf-8'))
f.flush() f.flush()
...@@ -97,27 +119,69 @@ def data_config(data_config_filename): ...@@ -97,27 +119,69 @@ def data_config(data_config_filename):
return config.load(f) return config.load(f)
class MockedRedis(object): class DictOverlay:
"""Since we're dealing with a very large dictionary from 'router-info.json' we only
want to read that dictionary once and prevent mutating it during tests. We instead
use this class to capture all mutations in additional layer that is reset every test
db = None Not thread safe
"""
TOMBSTONE = object() # record deletions
def __init__(self, *args, **kwargs): def __init__(self, base_dict: dict) -> None:
_bootstrap_semaphore.acquire() self.dict = base_dict
self.overlay = {}
def __getitem__(self, key):
if key not in self.overlay:
return self.dict[key]
value = self.overlay[key]
if value is self.TOMBSTONE:
raise KeyError(key)
return value
def __setitem__(self, key, value):
self.overlay[key] = value
def __delitem__(self, key):
self.overlay[key] = self.TOMBSTONE
def __contains__(self, key):
if key in self.overlay:
return self.overlay[key] is not self.TOMBSTONE
return key in self.dict
def get(self, key, default=None):
try: try:
if MockedRedis.db is None: return self[key]
MockedRedis.prep() except KeyError:
finally: return default
_bootstrap_semaphore.release()
def keys(self):
# allows us to create other mocks using a different data source file deleted_keys = {k for k, v in self.overlay.items() if v is self.TOMBSTONE}
@staticmethod return (self.dict.keys() | self.overlay.keys()) - deleted_keys
def prep(data_source_file="router-info.json"):
test_data_filename = os.path.join( def items(self):
TEST_DATA_DIRNAME, return ((key, self[key]) for key in self.keys())
data_source_file)
with open(test_data_filename) as f: _missing = object
MockedRedis.db = json.loads(f.read())
MockedRedis.db['db:latch'] = json.dumps({ def pop(self, key, default=_missing):
try:
value = self[key]
except KeyError:
if default is self._missing:
raise
return default
self.overlay[key] = self.TOMBSTONE
return value
class MockedRedis:
def __init__(self, *args, **kwargs):
self.db = DictOverlay(DB_DATA_NOCACHE)
self.db['db:latch'] = json.dumps({
'current': 0, 'current': 0,
'next': 0, 'next': 0,
'this': 0, 'this': 0,
...@@ -125,44 +189,33 @@ class MockedRedis(object): ...@@ -125,44 +189,33 @@ class MockedRedis(object):
'failure': False 'failure': False
}) })
# remove any cached data from the captured snapshot
def _is_cache(s):
if s.startswith('classifier-cache'):
return True
if s.startswith('joblog'):
return True
return False
keys_to_delete = filter(_is_cache, MockedRedis.db.keys())
for k in list(keys_to_delete):
del MockedRedis.db[k]
def set(self, name, value): def set(self, name, value):
MockedRedis.db[name] = value self.db[name] = value
def get(self, name): def get(self, name):
value = MockedRedis.db.get(name, None) value = self.db.get(name, None)
if value is None: if value is None:
return None return None
return value.encode('utf-8') return value.encode('utf-8')
def exists(self, name): def exists(self, name):
return name in MockedRedis.db return name in self.db
def delete(self, key): def delete(self, key):
if isinstance(key, bytes): if isinstance(key, bytes):
key = key.decode('utf-8') key = key.decode('utf-8')
# redis ignores delete for keys that don't exist
if key in MockedRedis.db: if key in self.db:
del MockedRedis.db[key] del self.db[key]
def scan_iter(self, glob=None, count='unused'): def scan_iter(self, glob=None, count='unused'):
if not glob: if not glob:
for k in list(MockedRedis.db.keys()): for k in list(self.db.keys()):
yield k.encode('utf-8') yield k.encode('utf-8')
m = re.match(r'^([^*]+)\*$', glob) m = re.match(r'^([^*]+)\*$', glob)
assert m # all expected globs are like this assert m # all expected globs are like this
for k in list(MockedRedis.db.keys()): for k in list(self.db.keys()):
if k.startswith(m.group(1)): if k.startswith(m.group(1)):
yield k.encode('utf-8') yield k.encode('utf-8')
...@@ -180,16 +233,13 @@ class MockedRedis(object): ...@@ -180,16 +233,13 @@ class MockedRedis(object):
return self return self
@pytest.fixture @pytest.fixture(scope="session")
def cached_test_data(): def cached_test_data():
filename = os.path.join(TEST_DATA_DIRNAME, "router-info.json") return DB_DATA
with open(filename) as f:
return json.loads(f.read())
@pytest.fixture @pytest.fixture
def flask_config_filename(): def flask_config_filename():
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
f.write('ENABLE_TESTING_ROUTES = True\n'.encode('utf-8')) f.write('ENABLE_TESTING_ROUTES = True\n'.encode('utf-8'))
f.flush() f.flush()
...@@ -198,17 +248,17 @@ def flask_config_filename(): ...@@ -198,17 +248,17 @@ def flask_config_filename():
@pytest.fixture @pytest.fixture
def mocked_redis(mocker): def mocked_redis(mocker):
MockedRedis.db = None # force data to be reloaded instance = MockedRedis()
mocker.patch( mocker.patch(
'inventory_provider.tasks.common.redis.StrictRedis', 'inventory_provider.tasks.common.redis.StrictRedis',
MockedRedis) return_value=instance)
@pytest.fixture @pytest.fixture
def client(flask_config_filename, data_config_filename, mocked_redis): def client(flask_config_filename, data_config_filename, mocked_redis):
os.environ['FLASK_SETTINGS_FILENAME'] = flask_config_filename os.environ['FLASK_SETTINGS_FILENAME'] = flask_config_filename
os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename
with inventory_provider.create_app().test_client() as c: with inventory_provider.create_app(setup_logging=False).test_client() as c:
yield c yield c
......
import glob
import json import json
import os import pathlib
import re import re
from unittest.mock import patch from unittest.mock import patch
...@@ -10,41 +9,31 @@ from ncclient.manager import make_device_handler, Manager ...@@ -10,41 +9,31 @@ from ncclient.manager import make_device_handler, Manager
from ncclient.transport import SSHSession from ncclient.transport import SSHSession
from ncclient.xml_ import NCElement from ncclient.xml_ import NCElement
import inventory_provider
from inventory_provider import juniper from inventory_provider import juniper
TEST_DATA_DIRNAME = os.path.realpath(os.path.join(
inventory_provider.__path__[0], TEST_DATA_DIR = pathlib.Path(__file__).parent / "../data"
"..",
"test",
"data"))
@pytest.fixture @pytest.fixture
def classifier_cache_test_entries(): def classifier_cache_test_entries():
filename = os.path.join( file = TEST_DATA_DIR.joinpath('classifier-cache-entries.json')
TEST_DATA_DIRNAME, 'classifier-cache-entries.json') return json.loads(file.read_text())
with open(filename) as f:
return json.loads(f.read())
def pytest_generate_tests(metafunc):
# TODO: can we really not get netconf data for all routers?
def _available_netconf_hosts():
for fn in glob.glob(os.path.join(TEST_DATA_DIRNAME, '*-netconf.xml')):
m = re.match('(.*)-netconf.xml', os.path.basename(fn))
assert m # sanity
yield m.group(1)
routers = list(_available_netconf_hosts()) @pytest.fixture(params=list(TEST_DATA_DIR.glob('*-netconf.xml')))
metafunc.parametrize("router", routers) def router(request):
file: pathlib.Path = request.param
m = re.match('(.*)-netconf.xml', file.name)
assert m # sanity
return m.group(1)
class MockedJunosRpc(object): class MockedJunosRpc(object):
def __init__(self, hostname): def __init__(self, hostname):
filename = os.path.join(TEST_DATA_DIRNAME, "%s-netconf.xml" % hostname) file = TEST_DATA_DIR.joinpath(f'{hostname}-netconf.xml')
self.config = etree.parse(filename) self.config = etree.parse(file)
def get_config(self): def get_config(self):
return self.config return self.config
...@@ -74,11 +63,10 @@ def netconf_doc(mocker, router, data_config): ...@@ -74,11 +63,10 @@ def netconf_doc(mocker, router, data_config):
@pytest.fixture @pytest.fixture
def interface_info_response(router): def interface_info_response(router):
filename = os.path.join(TEST_DATA_DIRNAME, 'interface_info', f'{router}.xml') file = TEST_DATA_DIR / f"interface_info/{router}.xml"
try: try:
with open(filename, 'r') as file: return file.read_text()
data = file.read()
return data
except FileNotFoundError: except FileNotFoundError:
pytest.skip(f'no corresponding interface_info doc for {router}, skipping') pytest.skip(f'no corresponding interface_info doc for {router}, skipping')
......
...@@ -4,6 +4,7 @@ and some data ends up in the right place ... otherwise not very detailed ...@@ -4,6 +4,7 @@ and some data ends up in the right place ... otherwise not very detailed
""" """
import re import re
import ncclient.transport.errors
import pytest import pytest
from inventory_provider.tasks import worker from inventory_provider.tasks import worker
...@@ -20,7 +21,15 @@ def backend_db(): ...@@ -20,7 +21,15 @@ def backend_db():
}).db }).db
def test_netconf_refresh_config(mocked_worker_module, router): @pytest.fixture
def mocked_juniper(mocker):
return mocker.patch(
"inventory_provider.juniper.get_interface_info_for_router",
side_effect=ncclient.transport.errors.SSHError
)
def test_netconf_refresh_config(mocked_worker_module, router, mocked_juniper):
if router in ['qfx.par.fr.geant.net', 'qfx.fra.de.geant.net']: if router in ['qfx.par.fr.geant.net', 'qfx.fra.de.geant.net']:
# expected to fail # expected to fail
pytest.skip(f'test data has no community string for {router}') pytest.skip(f'test data has no community string for {router}')
......
...@@ -64,7 +64,7 @@ def backend_db(): ...@@ -64,7 +64,7 @@ def backend_db():
def test_latchdb(data_config_filename, mocked_redis): def test_latchdb(data_config_filename, mocked_redis):
os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'] = data_config_filename
monitor.run() monitor.run(setup_logging=False)
db = backend_db() db = backend_db()
......
...@@ -601,26 +601,26 @@ def test_persist_ims_data(mocker, data_config, mocked_redis): ...@@ -601,26 +601,26 @@ def test_persist_ims_data(mocker, data_config, mocked_redis):
r.delete(k) r.delete(k)
persist_ims_data(data) persist_ims_data(data)
assert [k.decode("utf-8") for k in r.keys("ims:pop_nodes:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:pop_nodes:*")} == \
["ims:pop_nodes:LOC A", "ims:pop_nodes:LOC B"] {"ims:pop_nodes:LOC A", "ims:pop_nodes:LOC B"}
assert [k.decode("utf-8") for k in r.keys("ims:location:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:location:*")} == \
["ims:location:eq_a", "ims:location:eq_b"] {"ims:location:eq_a", "ims:location:eq_b"}
assert [k.decode("utf-8") for k in r.keys("ims:lg:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:lg:*")} == \
["ims:lg:lg_eq1", "ims:lg:lg_eq2"] {"ims:lg:lg_eq1", "ims:lg:lg_eq2"}
assert [k.decode("utf-8") for k in r.keys("ims:circuit_hierarchy:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:circuit_hierarchy:*")} == \
["ims:circuit_hierarchy:123", "ims:circuit_hierarchy:456"] {"ims:circuit_hierarchy:123", "ims:circuit_hierarchy:456"}
assert [k.decode("utf-8") for k in r.keys("ims:interface_services:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:interface_services:*")} == \
["ims:interface_services:if1", "ims:interface_services:if2"] {"ims:interface_services:if1", "ims:interface_services:if2"}
assert [k.decode("utf-8") for k in r.keys("ims:node_pair_services:*")] == \ assert {k.decode("utf-8") for k in r.keys("ims:node_pair_services:*")} == \
["ims:node_pair_services:np1", "ims:node_pair_services:np2"] {"ims:node_pair_services:np1", "ims:node_pair_services:np2"}
assert [k.decode("utf-8") for k in r.keys("poller_cache:*")] == \ assert {k.decode("utf-8") for k in r.keys("poller_cache:*")} == \
["poller_cache:eq1", "poller_cache:eq2"] {"poller_cache:eq1", "poller_cache:eq2"}
assert json.loads(r.get("ims:sid_services").decode("utf-8")) == \ assert json.loads(r.get("ims:sid_services").decode("utf-8")) == \
data["sid_services"] data["sid_services"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment