From f56c6df35b4e600b8a562657788ac12ae38a5d77 Mon Sep 17 00:00:00 2001
From: Robert Latta <robert.latta@geant.org>
Date: Fri, 17 Jan 2025 11:14:10 +0000
Subject: [PATCH] updated IMS class to use session. RE DBOARD3-1095

---
 inventory_provider/db/ims.py | 44 +++++++++++++++++++-----------------
 test/test_ims.py             | 31 +++++++++++--------------
 2 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/inventory_provider/db/ims.py b/inventory_provider/db/ims.py
index cb7a8a2..898be67 100644
--- a/inventory_provider/db/ims.py
+++ b/inventory_provider/db/ims.py
@@ -199,10 +199,23 @@ class IMS(object):
         self.base_url = base_url
         self.username = username
         self.password = password
-        self.verify_ssl = verify_ssl
-        self.bearer_token = None
-        self.bearer_token_init_time = 0
         self.reconnect_attempts = 0
+        self.session = requests.session()
+        self.session.verify = verify_ssl
+
+    @property
+    def bearer_token(self):
+        try:
+            return self.session.headers['Authorization'].replace('Bearer ', '')
+        except KeyError:
+            return None
+
+    @bearer_token.setter
+    def bearer_token(self, value):
+        if value is None:
+            self.session.headers.pop('Authorization', None)
+        else:
+            self.session.headers.update({'Authorization': f'Bearer {value}'})
 
     def __enter__(self):
         self._init_bearer_token()
@@ -220,10 +233,7 @@ class IMS(object):
 
         for i in range(IMS.PERMITTED_RECONNECT_ATTEMPTS):
             try:
-                response = requests.get(
-                    self.base_url + self.HEARTBEAT_PATH,
-                    headers={'Authorization': f'Bearer {self.bearer_token}'},
-                    verify=self.verify_ssl)
+                response = self.session.get(self.base_url + self.HEARTBEAT_PATH)
                 return response.status_code == requests.codes.ok
             except HTTPError as e:
                 logger.error(f'Error checking heartbeat (attempt {i + 1}): {e}')
@@ -241,9 +251,9 @@ class IMS(object):
 
         logger.debug(f'Logging in - Username: {self.username}'
                      f' - URL: {self.base_url + self.LOGIN_PATH}')
-        response = requests.post(
+        response = self.session.post(
             self.base_url + self.LOGIN_PATH,
-            auth=(self.username, self.password), verify=self.verify_ssl)
+            auth=(self.username, self.password))
         response.raise_for_status()
         self.bearer_token = response.text
 
@@ -251,11 +261,7 @@ class IMS(object):
         if self.bearer_token is None:
             return
         logout_url = self.base_url + self.LOGOUT_PATH
-        response = requests.get(
-            logout_url,
-            headers={'Authorization': f'Bearer {self.bearer_token}'},
-            verify=self.verify_ssl
-        )
+        response = self.session.get(logout_url)
         try:
             response.raise_for_status()
         except HTTPError as e:
@@ -267,10 +273,9 @@ class IMS(object):
             self._init_bearer_token()
         while True:
             logger.info('Clearing Dynamic Context Cache')
-            response = requests.put(
+            response = self.session.put(
                 f'{self.base_url + IMS.IMS_PATH}/ClearDynamicContextCache',
-                headers={'Authorization': f'Bearer {self.bearer_token}'},
-                verify=self.verify_ssl)
+            )
             if response.status_code == 401:
                 self._init_bearer_token()
                 continue
@@ -332,10 +337,7 @@ class IMS(object):
             return source
 
         while True:
-            response = requests.get(
-                url,
-                headers={'Authorization': f'Bearer {self.bearer_token}'},
-                params=params, verify=self.verify_ssl)
+            response = self.session.get(url, params=params)
             if _is_invalid_login_state(response):
                 self._init_bearer_token()
             else:
diff --git a/test/test_ims.py b/test/test_ims.py
index 4ef41d1..afa88c2 100644
--- a/test/test_ims.py
+++ b/test/test_ims.py
@@ -22,16 +22,16 @@ PASSWORD = "dummy_password"
 
 @pytest.fixture
 def ims_instance(mocker):
-    mock_post = mocker.patch('inventory_provider.db.ims.requests.post')
+    ims_instance = PatchedIMS(BASE_URL, USERNAME, PASSWORD)
+    mock_post = mocker.patch.object(ims_instance.session, 'post')
     mock_post.return_value.status_code = 200
     mock_post.return_value.text = 'dummy_bt'
-    return PatchedIMS(BASE_URL, USERNAME, PASSWORD)
+    return ims_instance
 
 
 def test_init_bearer_token(ims_instance):
 
     ims_instance._init_bearer_token()
-
     assert ims_instance.bearer_token == 'dummy_bt'
     assert ims_instance.reconnect_attempts == 0
 
@@ -44,7 +44,7 @@ def test_heartbeat(mocker, ims_instance):
     # set the bearer token
     ims_instance.bearer_token = 'dummy_bt'
 
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
 
     mocked_get.return_value.status_code = 200
     assert ims_instance._is_heartbeat_ok()
@@ -76,7 +76,7 @@ def test_get_entity_by_id(mocker, ims_instance):
     entity_id = '12345'
     expected_entity = {'id': entity_id, 'name': 'Test Entity'}
 
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_get.return_value.status_code = 200
     mocked_get.return_value.json.return_value = expected_entity
 
@@ -90,7 +90,7 @@ def test_get_entity_by_name(mocker, ims_instance):
     entity_name = 'dummy_entity_name'
     expected_entity = {'id': '1234', 'name': entity_name}
 
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_get.return_value.status_code = 200
     mocked_get.return_value.json.return_value = expected_entity
 
@@ -104,7 +104,7 @@ def test_get_filtered_entities(mocker, ims_instance):
     entity_filter = 'dummy_param=dummy value'
     expected_result = [{'id': '1234', 'name': 'entity_name'}]
 
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_get.return_value.status_code = 200
     mocked_get.return_value.json.return_value = expected_result
 
@@ -117,7 +117,7 @@ def test_get_entities(mocker, ims_instance):
     entity = 'dummy_entity_type'
     expected_result = [{'id': '1234', 'name': 'entity_name'}]
 
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_get.return_value.status_code = 200
     mocked_get.return_value.json.return_value = expected_result
 
@@ -140,7 +140,7 @@ def test_get_paged_entities(step_count, ims_instance, mocker):
         {'id': '12348', 'name': 'entity_name_8'},
         {'id': '12349', 'name': 'entity_name_9'},
     ]
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_responses = []
 
     for i in range(0, len(expected_result) + 1, step_count):
@@ -159,10 +159,7 @@ def test_get_paged_entities(step_count, ims_instance, mocker):
 
 
 def test_context_manager(ims_instance, mocker):
-    mock_post = mocker.patch('inventory_provider.db.ims.requests.post')
-    mock_post.return_value.status_code = 200
-    mock_post.return_value.text = 'dummy_bt'
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocked_get.return_value.status_code = 200
     mocked_get.return_value.json.return_value = []
 
@@ -174,17 +171,15 @@ def test_context_manager(ims_instance, mocker):
 
 
 def test_navigation_properties(ims_instance, mocker):
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
 
     ims_instance.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3])
     mocked_get.assert_called_with(
-        f'{BASE_URL}{IMS.IMS_PATH}/Node/1234',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={'navigationproperty': 6}, verify=False)
+        f'{BASE_URL}{IMS.IMS_PATH}/Node/1234', params={'navigationproperty': 6})
 
 
 def test_get_entities_gateway_error(ims_instance, mocker):
-    mocked_get = mocker.patch('inventory_provider.db.ims.requests.get')
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
     mocker.patch('inventory_provider.db.ims.time')
 
     error_message = 'Dummy Gateway Error'
-- 
GitLab