diff --git a/inventory_provider/db/ims.py b/inventory_provider/db/ims.py index 34183c450d4ccae71c126717bede8233c091f10c..898be674332dd4a9daffdd76d90de066f55f99d3 100644 --- a/inventory_provider/db/ims.py +++ b/inventory_provider/db/ims.py @@ -4,13 +4,13 @@ import logging import time import requests +from requests import HTTPError from inventory_provider import environment # Navigation Properties # http://149.210.162.190:81/ImsVersions/4.19.9/html/86d07a57-fa45-835e-d4a2-a789c4acbc96.htm # noqa -from requests import HTTPError logger = logging.getLogger(__name__) log_entry_and_exit = functools.partial( @@ -19,8 +19,8 @@ log_entry_and_exit = functools.partial( # http://149.210.162.190:81/ImsVersions/21.9/html/50e6a1b1-3910-2091-63d5-e13777b2194e.htm # noqa CIRCUIT_CUSTOMER_RELATION = { - "Circuit": 2, - "Customer": 4 + 'Circuit': 2, + 'Customer': 4 } # http://149.210.162.190:81/ImsVersions/20.1/html/86d07a57-fa45-835e-d4a2-a789c4acbc96.htm # noqa CIRCUIT_PROPERTIES = { @@ -189,59 +189,95 @@ class IMSError(Exception): class IMS(object): - TIMEOUT_THRESHOLD = 1200 PERMITTED_RECONNECT_ATTEMPTS = 3 LOGIN_PATH = '/login' + LOGOUT_PATH = '/login/logout' + HEARTBEAT_PATH = '/login/heartbeat' IMS_PATH = '/ims' - cache = {} - base_url = None - bearer_token = None - bearer_token_init_time = 0 - reconnect_attempts = 0 - - def __init__(self, base_url, username, password, bearer_token=None, - verify_ssl=False): - IMS.base_url = base_url + def __init__(self, base_url, username, password, verify_ssl=False): + self.base_url = base_url self.username = username self.password = password - self.verify_ssl = verify_ssl - IMS.bearer_token = bearer_token - - @classmethod - def _init_bearer_token(cls, username, password, verify_ssl=False): - re_init_time = time.time() - if not cls.bearer_token or \ - re_init_time - cls.bearer_token_init_time \ - > cls.TIMEOUT_THRESHOLD: - cls.reconnect_attempts = 0 + self.reconnect_attempts = 0 + self.session = requests.session() + self.session.verify = verify_ssl + + @property + def bearer_token(self): + try: + return self.session.headers['Authorization'].replace('Bearer ', '') + except KeyError: + return None + + @bearer_token.setter + def bearer_token(self, value): + if value is None: + self.session.headers.pop('Authorization', None) else: - cls.reconnect_attempts += 1 - if cls.reconnect_attempts > cls.PERMITTED_RECONNECT_ATTEMPTS: + self.session.headers.update({'Authorization': f'Bearer {value}'}) + + def __enter__(self): + self._init_bearer_token() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.logout() + + def __del__(self): + self.logout() + + def _is_heartbeat_ok(self): + if not self.bearer_token: + return False + + for i in range(IMS.PERMITTED_RECONNECT_ATTEMPTS): + try: + response = self.session.get(self.base_url + self.HEARTBEAT_PATH) + return response.status_code == requests.codes.ok + except HTTPError as e: + logger.error(f'Error checking heartbeat (attempt {i + 1}): {e}') + raise IMSError("Failed to get heartbeat") + + def _init_bearer_token(self): + if not self.bearer_token: + self.reconnect_attempts = 0 + elif self._is_heartbeat_ok(): + return + else: + self.reconnect_attempts += 1 + if self.reconnect_attempts > self.PERMITTED_RECONNECT_ATTEMPTS: raise IMSError('Too many reconnection attempts made') - logger.debug(f'Logging in - Username: {username}' - f' - URL: {cls.base_url + cls.LOGIN_PATH}') - response = requests.post( - cls.base_url + cls.LOGIN_PATH, - auth=(username, password), verify=verify_ssl) + logger.debug(f'Logging in - Username: {self.username}' + f' - URL: {self.base_url + self.LOGIN_PATH}') + response = self.session.post( + self.base_url + self.LOGIN_PATH, + auth=(self.username, self.password)) response.raise_for_status() - cls.bearer_token_init_time = re_init_time - cls.bearer_token = response.text + self.bearer_token = response.text + + def logout(self): + if self.bearer_token is None: + return + logout_url = self.base_url + self.LOGOUT_PATH + response = self.session.get(logout_url) + try: + response.raise_for_status() + except HTTPError as e: + logger.error(f'Error logging out: {e}') + self.bearer_token = None def clear_dynamic_context_cache(self): - if not IMS.bearer_token: - IMS._init_bearer_token( - self.username, self.password, self.verify_ssl) + if not self.bearer_token: + self._init_bearer_token() while True: logger.info('Clearing Dynamic Context Cache') - response = requests.put( + response = self.session.put( f'{self.base_url + IMS.IMS_PATH}/ClearDynamicContextCache', - headers={'Authorization': f'Bearer {self.bearer_token}'}, - verify=self.verify_ssl) + ) if response.status_code == 401: - IMS._init_bearer_token( - self.username, self.password, self.verify_ssl) + self._init_bearer_token() continue response.raise_for_status() break @@ -257,9 +293,8 @@ class IMS(object): params['navigationproperty'] = navigation_properties if isinstance( navigation_properties, int) else sum(navigation_properties) - if not IMS.bearer_token: - IMS._init_bearer_token( - self.username, self.password, self.verify_ssl) + if not self.bearer_token: + self._init_bearer_token() def _is_invalid_login_state(response_): if response_.status_code == requests.codes.unauthorized: @@ -275,11 +310,11 @@ class IMS(object): except Exception as e: t = response_.text if len(t) > 100: - message_text = f"{t[:50]} ... {t[-50:]}" + message_text = f'{t[:50]} ... {t[-50:]}' else: message_text = t - logger.debug(f"unexpected response: {message_text}\n{e}" - "\nre-raising") + logger.debug(f'unexpected response: {message_text}\n{e}' + '\nre-raising') raise e if r and 'haserrors' in r and r['haserrors']: for e in r['Errors']: @@ -302,13 +337,9 @@ class IMS(object): return source while True: - response = requests.get( - url, - headers={'Authorization': f'Bearer {self.bearer_token}'}, - params=params, verify=self.verify_ssl) + response = self.session.get(url, params=params) if _is_invalid_login_state(response): - IMS._init_bearer_token( - self.username, self.password, self.verify_ssl) + self._init_bearer_token() else: response.raise_for_status() orig = response.json() @@ -376,12 +407,12 @@ class IMS(object): elif r.status_code == 504: gateway_error_count += 1 logger.debug( - f"GATEWAY TIME-OUT for {url}" - f" -- COUNT: {gateway_error_count}") + f'GATEWAY TIME-OUT for {url}' + f' -- COUNT: {gateway_error_count}') if gateway_error_count > 4: raise e time.sleep(5) - logger.debug("WAKING UP") + logger.debug('WAKING UP') more_to_come = True continue else: diff --git a/test/test_ims.py b/test/test_ims.py index 44b83313a7acbff0c5be7e9e3d329acbb1fe782d..afa88c2b27f4a10f59dbd9aa6ec81a916af0bee2 100644 --- a/test/test_ims.py +++ b/test/test_ims.py @@ -1,160 +1,193 @@ -from requests import HTTPError +import json -import inventory_provider -from inventory_provider.db.ims import IMSError +import requests +from requests import HTTPError +import pytest +from requests.models import Response +from inventory_provider.db.ims import IMS, IMSError -class MockResponse: - def __init__(self, json_data, status_code): - self.json_data = json_data - self.status_code = status_code - self.text = '' if json_data else 'No records found for Entity:XXXXX' - def json(self): - return self.json_data +class PatchedIMS(IMS): - def raise_for_status(self): + def __del__(self): + # patched to stop call to logout pass -def test_ims_class_login(mocker): - mock_post = mocker.patch('inventory_provider.db.ims.requests.post') - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - mock_post.return_value = MockResponse("my_bearer_token", 200) - mock_post.return_value.text = "my_bearer_token" - - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password') - ds.get_entity_by_id('Node', 1234) - mock_post.assert_called_once_with( - 'dummy_base/login', - auth=('dummy_username', 'dummy_password'), verify=False) - mock_get.assert_called_once_with( - 'dummy_base/ims/Node/1234', - headers={'Authorization': 'Bearer my_bearer_token'}, - params=None, verify=False) - - -def test_ims_failed_response(mocker): - mock_post = mocker.patch('inventory_provider.db.ims.requests.post') - mock_post.return_value = MockResponse("my_bearer_token", 200) - mock_post.return_value.text = "my_bearer_token" - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - mock_get.return_value.status_code = 200 - mock_get.return_value.json.side_effect = IMSError("IMS exception") - - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password') - - try: - ds.get_entity_by_id('Node', 1234) - except Exception as e: - assert isinstance(e, IMSError) - - -def test_ims_class_entity_by_id(mocker): - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt') - - ds.get_entity_by_id('Node', 1234) - mock_get.assert_called_once_with( - 'dummy_base/ims/Node/1234', - headers={'Authorization': 'Bearer dummy_bt'}, - params=None, verify=False) - - -def test_ims_class_entity_by_name(mocker): - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt') - - ds.get_entity_by_name('Node', 'dummy_name') - mock_get.assert_called_once_with( - 'dummy_base/ims/Node/byname/"dummy_name"', - headers={'Authorization': 'Bearer dummy_bt'}, - params=None, verify=False) - - -def test_ims_class_filtered_entities(mocker): - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt') - - list(ds.get_filtered_entities( - 'Node', 'dummy_param=dummy value', step_count=50)) - mock_get.assert_called_once_with( - 'dummy_base/ims/Node/filtered/dummy_param=dummy value', - headers={'Authorization': 'Bearer dummy_bt'}, - params={ - 'paginatorStartElement': 0, - 'paginatorNumberOfElements': 50 - }, verify=False) - - def side_effect(*args, **kargs): - if kargs['params']['paginatorStartElement'] == 0: - return MockResponse([1, 2], 200) - return MockResponse([3], 200) - - mock_multi_get = mocker.patch( - 'inventory_provider.db.ims.requests.get', side_effect=side_effect) - - res = list(ds.get_filtered_entities( - 'Node', 'dummy_param=dummy value', step_count=2)) - mock_multi_get.assert_any_call( - 'dummy_base/ims/Node/filtered/dummy_param=dummy value', - headers={'Authorization': 'Bearer dummy_bt'}, - params={ - 'paginatorStartElement': 0, - 'paginatorNumberOfElements': 2 - }, verify=False) - mock_multi_get.assert_any_call( - 'dummy_base/ims/Node/filtered/dummy_param=dummy value', - headers={'Authorization': 'Bearer dummy_bt'}, - params={ - 'paginatorStartElement': 2, - 'paginatorNumberOfElements': 2 - }, verify=False) - assert mock_multi_get.call_count == 2 - assert res == [1, 2, 3] - - def side_effect_no_recs(*args, **kargs): - if kargs['params']['paginatorStartElement'] == 0: - return MockResponse([1, 2], 200) - e = HTTPError() - e.response = MockResponse('', 404) - raise e - - mocker.patch('inventory_provider.db.ims.requests.get', - side_effect=side_effect_no_recs) - - res = list(ds.get_filtered_entities( - 'Node', 'dummy_param=dummy value', step_count=2)) - assert res == [1, 2] - - -def test_ims_class_get_all_entities(mocker): - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt') - - list(ds.get_all_entities('Node', step_count=10)) - mock_get.assert_called_once_with( - 'dummy_base/ims/Node/all', - headers={'Authorization': 'Bearer dummy_bt'}, - params={ - 'paginatorStartElement': 0, - 'paginatorNumberOfElements': 10 - }, verify=False) - - -def test_ims_class_navigation_properties(mocker): - mock_get = mocker.patch('inventory_provider.db.ims.requests.get') - ds = inventory_provider.db.ims.IMS( - 'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt') - - ds.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3]) - mock_get.assert_called_with( - 'dummy_base/ims/Node/1234', - headers={'Authorization': 'Bearer dummy_bt'}, - params={'navigationproperty': 6}, verify=False) +BASE_URL = "http://dummy_base" +USERNAME = "dummy_username" +PASSWORD = "dummy_password" + + +@pytest.fixture +def ims_instance(mocker): + ims_instance = PatchedIMS(BASE_URL, USERNAME, PASSWORD) + mock_post = mocker.patch.object(ims_instance.session, 'post') + mock_post.return_value.status_code = 200 + mock_post.return_value.text = 'dummy_bt' + return ims_instance + + +def test_init_bearer_token(ims_instance): + + ims_instance._init_bearer_token() + assert ims_instance.bearer_token == 'dummy_bt' + assert ims_instance.reconnect_attempts == 0 + + +def test_heartbeat(mocker, ims_instance): + + # there should be no bearer token set + assert ims_instance._is_heartbeat_ok() is False + + # set the bearer token + ims_instance.bearer_token = 'dummy_bt' + + mocked_get = mocker.patch.object(ims_instance.session, 'get') + + mocked_get.return_value.status_code = 200 + assert ims_instance._is_heartbeat_ok() + + mocked_get.return_value.status_code = 401 + assert ims_instance._is_heartbeat_ok() is False + + dummy_exception = HTTPError('http://dummy_base/heartbeat', 500, 'Dummy', None, None) + responses = [dummy_exception] * IMS.PERMITTED_RECONNECT_ATTEMPTS + mocked_get.side_effect = responses + with pytest.raises(IMSError) as exception_info: + ims_instance._is_heartbeat_ok() + assert exception_info.value.message == 'Dummy' + + r = Response() + r.status_code = 200 + responses[-1] = r + mocked_get.side_effect = responses + assert ims_instance._is_heartbeat_ok() + + r.status_code = 401 + responses[-1] = r + mocked_get.side_effect = responses + assert ims_instance._is_heartbeat_ok() is False + + +def test_get_entity_by_id(mocker, ims_instance): + entity = 'dummy_entity_type' + entity_id = '12345' + expected_entity = {'id': entity_id, 'name': 'Test Entity'} + + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json.return_value = expected_entity + + result = ims_instance.get_entity_by_id(entity, entity_id) + + assert result == expected_entity + + +def test_get_entity_by_name(mocker, ims_instance): + entity = 'dummy_entity_tyoe' + entity_name = 'dummy_entity_name' + expected_entity = {'id': '1234', 'name': entity_name} + + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json.return_value = expected_entity + + result = ims_instance.get_entity_by_name(entity, entity_name) + + assert result == expected_entity + + +def test_get_filtered_entities(mocker, ims_instance): + entity = 'dummy_entity_tyoe' + entity_filter = 'dummy_param=dummy value' + expected_result = [{'id': '1234', 'name': 'entity_name'}] + + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json.return_value = expected_result + + result = ims_instance.get_filtered_entities(entity, entity_filter) + + assert list(result) == expected_result + + +def test_get_entities(mocker, ims_instance): + entity = 'dummy_entity_type' + expected_result = [{'id': '1234', 'name': 'entity_name'}] + + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json.return_value = expected_result + + result = ims_instance.get_entities(entity) + + assert list(result) == expected_result + + +@pytest.mark.parametrize('step_count', [1, 3, 4, 9, 10]) +def test_get_paged_entities(step_count, ims_instance, mocker): + entity = 'dummy_entity_type' + expected_result = [ + {'id': '12341', 'name': 'entity_name_1'}, + {'id': '12342', 'name': 'entity_name_2'}, + {'id': '12343', 'name': 'entity_name_3'}, + {'id': '12344', 'name': 'entity_name_4'}, + {'id': '12345', 'name': 'entity_name_5'}, + {'id': '12346', 'name': 'entity_name_6'}, + {'id': '12347', 'name': 'entity_name_7'}, + {'id': '12348', 'name': 'entity_name_8'}, + {'id': '12349', 'name': 'entity_name_9'}, + ] + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_responses = [] + + for i in range(0, len(expected_result) + 1, step_count): + resp = Response() + resp.status_code = 200 + if i >= len(expected_result): + resp_content = {} + else: + resp_content = expected_result[i:i + step_count] + resp._content = str.encode(json.dumps(resp_content)) + mocked_responses.append(resp) + mocked_get.side_effect = mocked_responses + + result = ims_instance.get_all_entities(entity, step_count=step_count) + assert list(result) == expected_result + + +def test_context_manager(ims_instance, mocker): + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json.return_value = [] + + with ims_instance as ims: + list(ims_instance.get_all_entities('dummy')) + assert ims.bearer_token == 'dummy_bt' + + assert ims_instance.bearer_token is None + + +def test_navigation_properties(ims_instance, mocker): + mocked_get = mocker.patch.object(ims_instance.session, 'get') + + ims_instance.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3]) + mocked_get.assert_called_with( + f'{BASE_URL}{IMS.IMS_PATH}/Node/1234', params={'navigationproperty': 6}) + + +def test_get_entities_gateway_error(ims_instance, mocker): + mocked_get = mocker.patch.object(ims_instance.session, 'get') + mocker.patch('inventory_provider.db.ims.time') + + error_message = 'Dummy Gateway Error' + r = Response() + r.status_code = requests.codes.gateway_timeout + r._content = str.encode(error_message) + mocked_get.side_effect = [r] * 5 + with pytest.raises(HTTPError) as exception_info: + list(ims_instance.get_entities('dummy')) + assert exception_info.value.message == error_message + assert mocked_get.call_count == 4