From f56c6df35b4e600b8a562657788ac12ae38a5d77 Mon Sep 17 00:00:00 2001 From: Robert Latta <robert.latta@geant.org> Date: Fri, 17 Jan 2025 11:14:10 +0000 Subject: [PATCH] updated IMS class to use session. RE DBOARD3-1095 --- inventory_provider/db/ims.py | 44 +++++++++++++++++++----------------- test/test_ims.py | 31 +++++++++++-------------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/inventory_provider/db/ims.py b/inventory_provider/db/ims.py index cb7a8a2..898be67 100644 --- a/inventory_provider/db/ims.py +++ b/inventory_provider/db/ims.py @@ -199,10 +199,23 @@ class IMS(object): self.base_url = base_url self.username = username self.password = password - self.verify_ssl = verify_ssl - self.bearer_token = None - self.bearer_token_init_time = 0 self.reconnect_attempts = 0 + self.session = requests.session() + self.session.verify = verify_ssl + + @property + def bearer_token(self): + try: + return self.session.headers['Authorization'].replace('Bearer ', '') + except KeyError: + return None + + @bearer_token.setter + def bearer_token(self, value): + if value is None: + self.session.headers.pop('Authorization', None) + else: + self.session.headers.update({'Authorization': f'Bearer {value}'}) def __enter__(self): self._init_bearer_token() @@ -220,10 +233,7 @@ class IMS(object): for i in range(IMS.PERMITTED_RECONNECT_ATTEMPTS): try: - response = requests.get( - self.base_url + self.HEARTBEAT_PATH, - headers={'Authorization': f'Bearer {self.bearer_token}'}, - verify=self.verify_ssl) + response = self.session.get(self.base_url + self.HEARTBEAT_PATH) return response.status_code == requests.codes.ok except HTTPError as e: logger.error(f'Error checking heartbeat (attempt {i + 1}): {e}') @@ -241,9 +251,9 @@ class IMS(object): logger.debug(f'Logging in - Username: {self.username}' f' - URL: {self.base_url + self.LOGIN_PATH}') - response = requests.post( + response = self.session.post( self.base_url + self.LOGIN_PATH, - auth=(self.username, self.password), verify=self.verify_ssl) + auth=(self.username, self.password)) response.raise_for_status() self.bearer_token = response.text @@ -251,11 +261,7 @@ class IMS(object): if self.bearer_token is None: return logout_url = self.base_url + self.LOGOUT_PATH - response = requests.get( - logout_url, - headers={'Authorization': f'Bearer {self.bearer_token}'}, - verify=self.verify_ssl - ) + response = self.session.get(logout_url) try: response.raise_for_status() except HTTPError as e: @@ -267,10 +273,9 @@ class IMS(object): self._init_bearer_token() while True: logger.info('Clearing Dynamic Context Cache') - response = requests.put( + response = self.session.put( f'{self.base_url + IMS.IMS_PATH}/ClearDynamicContextCache', - headers={'Authorization': f'Bearer {self.bearer_token}'}, - verify=self.verify_ssl) + ) if response.status_code == 401: self._init_bearer_token() continue @@ -332,10 +337,7 @@ class IMS(object): return source while True: - response = requests.get( - url, - headers={'Authorization': f'Bearer {self.bearer_token}'}, - params=params, verify=self.verify_ssl) + response = self.session.get(url, params=params) if _is_invalid_login_state(response): self._init_bearer_token() else: diff --git a/test/test_ims.py b/test/test_ims.py index 4ef41d1..afa88c2 100644 --- a/test/test_ims.py +++ b/test/test_ims.py @@ -22,16 +22,16 @@ PASSWORD = "dummy_password" @pytest.fixture def ims_instance(mocker): - mock_post = mocker.patch('inventory_provider.db.ims.requests.post') + ims_instance = PatchedIMS(BASE_URL, USERNAME, PASSWORD) + mock_post = mocker.patch.object(ims_instance.session, 'post') mock_post.return_value.status_code = 200 mock_post.return_value.text = 'dummy_bt' - return PatchedIMS(BASE_URL, USERNAME, PASSWORD) + return ims_instance def test_init_bearer_token(ims_instance): ims_instance._init_bearer_token() - assert ims_instance.bearer_token == 'dummy_bt' assert ims_instance.reconnect_attempts == 0 @@ -44,7 +44,7 @@ def test_heartbeat(mocker, ims_instance): # set the bearer token ims_instance.bearer_token = 'dummy_bt' - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 assert ims_instance._is_heartbeat_ok() @@ -76,7 +76,7 @@ def test_get_entity_by_id(mocker, ims_instance): entity_id = '12345' expected_entity = {'id': entity_id, 'name': 'Test Entity'} - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 mocked_get.return_value.json.return_value = expected_entity @@ -90,7 +90,7 @@ def test_get_entity_by_name(mocker, ims_instance): entity_name = 'dummy_entity_name' expected_entity = {'id': '1234', 'name': entity_name} - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 mocked_get.return_value.json.return_value = expected_entity @@ -104,7 +104,7 @@ def test_get_filtered_entities(mocker, ims_instance): entity_filter = 'dummy_param=dummy value' expected_result = [{'id': '1234', 'name': 'entity_name'}] - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 mocked_get.return_value.json.return_value = expected_result @@ -117,7 +117,7 @@ def test_get_entities(mocker, ims_instance): entity = 'dummy_entity_type' expected_result = [{'id': '1234', 'name': 'entity_name'}] - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 mocked_get.return_value.json.return_value = expected_result @@ -140,7 +140,7 @@ def test_get_paged_entities(step_count, ims_instance, mocker): {'id': '12348', 'name': 'entity_name_8'}, {'id': '12349', 'name': 'entity_name_9'}, ] - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_responses = [] for i in range(0, len(expected_result) + 1, step_count): @@ -159,10 +159,7 @@ def test_get_paged_entities(step_count, ims_instance, mocker): def test_context_manager(ims_instance, mocker): - mock_post = mocker.patch('inventory_provider.db.ims.requests.post') - mock_post.return_value.status_code = 200 - mock_post.return_value.text = 'dummy_bt' - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocked_get.return_value.status_code = 200 mocked_get.return_value.json.return_value = [] @@ -174,17 +171,15 @@ def test_context_manager(ims_instance, mocker): def test_navigation_properties(ims_instance, mocker): - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') ims_instance.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3]) mocked_get.assert_called_with( - f'{BASE_URL}{IMS.IMS_PATH}/Node/1234', - headers={'Authorization': 'Bearer dummy_bt'}, - params={'navigationproperty': 6}, verify=False) + f'{BASE_URL}{IMS.IMS_PATH}/Node/1234', params={'navigationproperty': 6}) def test_get_entities_gateway_error(ims_instance, mocker): - mocked_get = mocker.patch('inventory_provider.db.ims.requests.get') + mocked_get = mocker.patch.object(ims_instance.session, 'get') mocker.patch('inventory_provider.db.ims.time') error_message = 'Dummy Gateway Error' -- GitLab