diff --git a/inventory_provider/db/ims.py b/inventory_provider/db/ims.py
index 34183c450d4ccae71c126717bede8233c091f10c..898be674332dd4a9daffdd76d90de066f55f99d3 100644
--- a/inventory_provider/db/ims.py
+++ b/inventory_provider/db/ims.py
@@ -4,13 +4,13 @@ import logging
 import time
 
 import requests
+from requests import HTTPError
 
 from inventory_provider import environment
 
 
 # Navigation Properties
 # http://149.210.162.190:81/ImsVersions/4.19.9/html/86d07a57-fa45-835e-d4a2-a789c4acbc96.htm  # noqa
-from requests import HTTPError
 
 logger = logging.getLogger(__name__)
 log_entry_and_exit = functools.partial(
@@ -19,8 +19,8 @@ log_entry_and_exit = functools.partial(
 
 # http://149.210.162.190:81/ImsVersions/21.9/html/50e6a1b1-3910-2091-63d5-e13777b2194e.htm  # noqa
 CIRCUIT_CUSTOMER_RELATION = {
-    "Circuit": 2,
-    "Customer": 4
+    'Circuit': 2,
+    'Customer': 4
 }
 # http://149.210.162.190:81/ImsVersions/20.1/html/86d07a57-fa45-835e-d4a2-a789c4acbc96.htm  # noqa
 CIRCUIT_PROPERTIES = {
@@ -189,59 +189,95 @@ class IMSError(Exception):
 
 class IMS(object):
 
-    TIMEOUT_THRESHOLD = 1200
     PERMITTED_RECONNECT_ATTEMPTS = 3
     LOGIN_PATH = '/login'
+    LOGOUT_PATH = '/login/logout'
+    HEARTBEAT_PATH = '/login/heartbeat'
     IMS_PATH = '/ims'
 
-    cache = {}
-    base_url = None
-    bearer_token = None
-    bearer_token_init_time = 0
-    reconnect_attempts = 0
-
-    def __init__(self, base_url, username, password, bearer_token=None,
-                 verify_ssl=False):
-        IMS.base_url = base_url
+    def __init__(self, base_url, username, password, verify_ssl=False):
+        self.base_url = base_url
         self.username = username
         self.password = password
-        self.verify_ssl = verify_ssl
-        IMS.bearer_token = bearer_token
-
-    @classmethod
-    def _init_bearer_token(cls, username, password, verify_ssl=False):
-        re_init_time = time.time()
-        if not cls.bearer_token or \
-                re_init_time - cls.bearer_token_init_time \
-                > cls.TIMEOUT_THRESHOLD:
-            cls.reconnect_attempts = 0
+        self.reconnect_attempts = 0
+        self.session = requests.session()
+        self.session.verify = verify_ssl
+
+    @property
+    def bearer_token(self):
+        try:
+            return self.session.headers['Authorization'].replace('Bearer ', '')
+        except KeyError:
+            return None
+
+    @bearer_token.setter
+    def bearer_token(self, value):
+        if value is None:
+            self.session.headers.pop('Authorization', None)
         else:
-            cls.reconnect_attempts += 1
-            if cls.reconnect_attempts > cls.PERMITTED_RECONNECT_ATTEMPTS:
+            self.session.headers.update({'Authorization': f'Bearer {value}'})
+
+    def __enter__(self):
+        self._init_bearer_token()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.logout()
+
+    def __del__(self):
+        self.logout()
+
+    def _is_heartbeat_ok(self):
+        if not self.bearer_token:
+            return False
+
+        for i in range(IMS.PERMITTED_RECONNECT_ATTEMPTS):
+            try:
+                response = self.session.get(self.base_url + self.HEARTBEAT_PATH)
+                return response.status_code == requests.codes.ok
+            except HTTPError as e:
+                logger.error(f'Error checking heartbeat (attempt {i + 1}): {e}')
+        raise IMSError("Failed to get heartbeat")
+
+    def _init_bearer_token(self):
+        if not self.bearer_token:
+            self.reconnect_attempts = 0
+        elif self._is_heartbeat_ok():
+            return
+        else:
+            self.reconnect_attempts += 1
+            if self.reconnect_attempts > self.PERMITTED_RECONNECT_ATTEMPTS:
                 raise IMSError('Too many reconnection attempts made')
 
-        logger.debug(f'Logging in - Username: {username}'
-                     f' - URL: {cls.base_url + cls.LOGIN_PATH}')
-        response = requests.post(
-            cls.base_url + cls.LOGIN_PATH,
-            auth=(username, password), verify=verify_ssl)
+        logger.debug(f'Logging in - Username: {self.username}'
+                     f' - URL: {self.base_url + self.LOGIN_PATH}')
+        response = self.session.post(
+            self.base_url + self.LOGIN_PATH,
+            auth=(self.username, self.password))
         response.raise_for_status()
-        cls.bearer_token_init_time = re_init_time
-        cls.bearer_token = response.text
+        self.bearer_token = response.text
+
+    def logout(self):
+        if self.bearer_token is None:
+            return
+        logout_url = self.base_url + self.LOGOUT_PATH
+        response = self.session.get(logout_url)
+        try:
+            response.raise_for_status()
+        except HTTPError as e:
+            logger.error(f'Error logging out: {e}')
+        self.bearer_token = None
 
     def clear_dynamic_context_cache(self):
-        if not IMS.bearer_token:
-            IMS._init_bearer_token(
-                self.username, self.password, self.verify_ssl)
+        if not self.bearer_token:
+            self._init_bearer_token()
         while True:
             logger.info('Clearing Dynamic Context Cache')
-            response = requests.put(
+            response = self.session.put(
                 f'{self.base_url + IMS.IMS_PATH}/ClearDynamicContextCache',
-                headers={'Authorization': f'Bearer {self.bearer_token}'},
-                verify=self.verify_ssl)
+            )
             if response.status_code == 401:
-                IMS._init_bearer_token(
-                    self.username, self.password, self.verify_ssl)
+                self._init_bearer_token()
                 continue
             response.raise_for_status()
             break
@@ -257,9 +293,8 @@ class IMS(object):
             params['navigationproperty'] = navigation_properties if isinstance(
                 navigation_properties, int) else sum(navigation_properties)
 
-        if not IMS.bearer_token:
-            IMS._init_bearer_token(
-                self.username, self.password, self.verify_ssl)
+        if not self.bearer_token:
+            self._init_bearer_token()
 
         def _is_invalid_login_state(response_):
             if response_.status_code == requests.codes.unauthorized:
@@ -275,11 +310,11 @@ class IMS(object):
                 except Exception as e:
                     t = response_.text
                     if len(t) > 100:
-                        message_text = f"{t[:50]} ... {t[-50:]}"
+                        message_text = f'{t[:50]} ... {t[-50:]}'
                     else:
                         message_text = t
-                    logger.debug(f"unexpected response: {message_text}\n{e}"
-                                 "\nre-raising")
+                    logger.debug(f'unexpected response: {message_text}\n{e}'
+                                 '\nre-raising')
                     raise e
                 if r and 'haserrors' in r and r['haserrors']:
                     for e in r['Errors']:
@@ -302,13 +337,9 @@ class IMS(object):
             return source
 
         while True:
-            response = requests.get(
-                url,
-                headers={'Authorization': f'Bearer {self.bearer_token}'},
-                params=params, verify=self.verify_ssl)
+            response = self.session.get(url, params=params)
             if _is_invalid_login_state(response):
-                IMS._init_bearer_token(
-                    self.username, self.password, self.verify_ssl)
+                self._init_bearer_token()
             else:
                 response.raise_for_status()
                 orig = response.json()
@@ -376,12 +407,12 @@ class IMS(object):
                 elif r.status_code == 504:
                     gateway_error_count += 1
                     logger.debug(
-                        f"GATEWAY TIME-OUT for {url}"
-                        f"  --  COUNT: {gateway_error_count}")
+                        f'GATEWAY TIME-OUT for {url}'
+                        f'  --  COUNT: {gateway_error_count}')
                     if gateway_error_count > 4:
                         raise e
                     time.sleep(5)
-                    logger.debug("WAKING UP")
+                    logger.debug('WAKING UP')
                     more_to_come = True
                     continue
                 else:
diff --git a/test/test_ims.py b/test/test_ims.py
index 44b83313a7acbff0c5be7e9e3d329acbb1fe782d..afa88c2b27f4a10f59dbd9aa6ec81a916af0bee2 100644
--- a/test/test_ims.py
+++ b/test/test_ims.py
@@ -1,160 +1,193 @@
-from requests import HTTPError
+import json
 
-import inventory_provider
-from inventory_provider.db.ims import IMSError
+import requests
+from requests import HTTPError
+import pytest
+from requests.models import Response
 
+from inventory_provider.db.ims import IMS, IMSError
 
-class MockResponse:
-    def __init__(self, json_data, status_code):
-        self.json_data = json_data
-        self.status_code = status_code
-        self.text = '' if json_data else 'No records found for Entity:XXXXX'
 
-    def json(self):
-        return self.json_data
+class PatchedIMS(IMS):
 
-    def raise_for_status(self):
+    def __del__(self):
+        # patched to stop call to logout
         pass
 
 
-def test_ims_class_login(mocker):
-    mock_post = mocker.patch('inventory_provider.db.ims.requests.post')
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    mock_post.return_value = MockResponse("my_bearer_token", 200)
-    mock_post.return_value.text = "my_bearer_token"
-
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password')
-    ds.get_entity_by_id('Node', 1234)
-    mock_post.assert_called_once_with(
-        'dummy_base/login',
-        auth=('dummy_username', 'dummy_password'), verify=False)
-    mock_get.assert_called_once_with(
-        'dummy_base/ims/Node/1234',
-        headers={'Authorization': 'Bearer my_bearer_token'},
-        params=None, verify=False)
-
-
-def test_ims_failed_response(mocker):
-    mock_post = mocker.patch('inventory_provider.db.ims.requests.post')
-    mock_post.return_value = MockResponse("my_bearer_token", 200)
-    mock_post.return_value.text = "my_bearer_token"
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    mock_get.return_value.status_code = 200
-    mock_get.return_value.json.side_effect = IMSError("IMS exception")
-
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password')
-
-    try:
-        ds.get_entity_by_id('Node', 1234)
-    except Exception as e:
-        assert isinstance(e, IMSError)
-
-
-def test_ims_class_entity_by_id(mocker):
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt')
-
-    ds.get_entity_by_id('Node', 1234)
-    mock_get.assert_called_once_with(
-        'dummy_base/ims/Node/1234',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params=None, verify=False)
-
-
-def test_ims_class_entity_by_name(mocker):
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt')
-
-    ds.get_entity_by_name('Node', 'dummy_name')
-    mock_get.assert_called_once_with(
-        'dummy_base/ims/Node/byname/"dummy_name"',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params=None, verify=False)
-
-
-def test_ims_class_filtered_entities(mocker):
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt')
-
-    list(ds.get_filtered_entities(
-        'Node', 'dummy_param=dummy value', step_count=50))
-    mock_get.assert_called_once_with(
-        'dummy_base/ims/Node/filtered/dummy_param=dummy value',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={
-            'paginatorStartElement': 0,
-            'paginatorNumberOfElements': 50
-        }, verify=False)
-
-    def side_effect(*args, **kargs):
-        if kargs['params']['paginatorStartElement'] == 0:
-            return MockResponse([1, 2], 200)
-        return MockResponse([3], 200)
-
-    mock_multi_get = mocker.patch(
-        'inventory_provider.db.ims.requests.get', side_effect=side_effect)
-
-    res = list(ds.get_filtered_entities(
-        'Node', 'dummy_param=dummy value', step_count=2))
-    mock_multi_get.assert_any_call(
-        'dummy_base/ims/Node/filtered/dummy_param=dummy value',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={
-            'paginatorStartElement': 0,
-            'paginatorNumberOfElements': 2
-        }, verify=False)
-    mock_multi_get.assert_any_call(
-        'dummy_base/ims/Node/filtered/dummy_param=dummy value',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={
-            'paginatorStartElement': 2,
-            'paginatorNumberOfElements': 2
-        }, verify=False)
-    assert mock_multi_get.call_count == 2
-    assert res == [1, 2, 3]
-
-    def side_effect_no_recs(*args, **kargs):
-        if kargs['params']['paginatorStartElement'] == 0:
-            return MockResponse([1, 2], 200)
-        e = HTTPError()
-        e.response = MockResponse('', 404)
-        raise e
-
-    mocker.patch('inventory_provider.db.ims.requests.get',
-                 side_effect=side_effect_no_recs)
-
-    res = list(ds.get_filtered_entities(
-        'Node', 'dummy_param=dummy value', step_count=2))
-    assert res == [1, 2]
-
-
-def test_ims_class_get_all_entities(mocker):
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt')
-
-    list(ds.get_all_entities('Node', step_count=10))
-    mock_get.assert_called_once_with(
-        'dummy_base/ims/Node/all',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={
-            'paginatorStartElement': 0,
-            'paginatorNumberOfElements': 10
-        }, verify=False)
-
-
-def test_ims_class_navigation_properties(mocker):
-    mock_get = mocker.patch('inventory_provider.db.ims.requests.get')
-    ds = inventory_provider.db.ims.IMS(
-        'dummy_base', 'dummy_username', 'dummy_password', 'dummy_bt')
-
-    ds.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3])
-    mock_get.assert_called_with(
-        'dummy_base/ims/Node/1234',
-        headers={'Authorization': 'Bearer dummy_bt'},
-        params={'navigationproperty': 6}, verify=False)
+BASE_URL = "http://dummy_base"
+USERNAME = "dummy_username"
+PASSWORD = "dummy_password"
+
+
+@pytest.fixture
+def ims_instance(mocker):
+    ims_instance = PatchedIMS(BASE_URL, USERNAME, PASSWORD)
+    mock_post = mocker.patch.object(ims_instance.session, 'post')
+    mock_post.return_value.status_code = 200
+    mock_post.return_value.text = 'dummy_bt'
+    return ims_instance
+
+
+def test_init_bearer_token(ims_instance):
+
+    ims_instance._init_bearer_token()
+    assert ims_instance.bearer_token == 'dummy_bt'
+    assert ims_instance.reconnect_attempts == 0
+
+
+def test_heartbeat(mocker, ims_instance):
+
+    # there should be no bearer token set
+    assert ims_instance._is_heartbeat_ok() is False
+
+    # set the bearer token
+    ims_instance.bearer_token = 'dummy_bt'
+
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+
+    mocked_get.return_value.status_code = 200
+    assert ims_instance._is_heartbeat_ok()
+
+    mocked_get.return_value.status_code = 401
+    assert ims_instance._is_heartbeat_ok() is False
+
+    dummy_exception = HTTPError('http://dummy_base/heartbeat', 500, 'Dummy', None, None)
+    responses = [dummy_exception] * IMS.PERMITTED_RECONNECT_ATTEMPTS
+    mocked_get.side_effect = responses
+    with pytest.raises(IMSError) as exception_info:
+        ims_instance._is_heartbeat_ok()
+        assert exception_info.value.message == 'Dummy'
+
+    r = Response()
+    r.status_code = 200
+    responses[-1] = r
+    mocked_get.side_effect = responses
+    assert ims_instance._is_heartbeat_ok()
+
+    r.status_code = 401
+    responses[-1] = r
+    mocked_get.side_effect = responses
+    assert ims_instance._is_heartbeat_ok() is False
+
+
+def test_get_entity_by_id(mocker, ims_instance):
+    entity = 'dummy_entity_type'
+    entity_id = '12345'
+    expected_entity = {'id': entity_id, 'name': 'Test Entity'}
+
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_get.return_value.status_code = 200
+    mocked_get.return_value.json.return_value = expected_entity
+
+    result = ims_instance.get_entity_by_id(entity, entity_id)
+
+    assert result == expected_entity
+
+
+def test_get_entity_by_name(mocker, ims_instance):
+    entity = 'dummy_entity_tyoe'
+    entity_name = 'dummy_entity_name'
+    expected_entity = {'id': '1234', 'name': entity_name}
+
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_get.return_value.status_code = 200
+    mocked_get.return_value.json.return_value = expected_entity
+
+    result = ims_instance.get_entity_by_name(entity, entity_name)
+
+    assert result == expected_entity
+
+
+def test_get_filtered_entities(mocker, ims_instance):
+    entity = 'dummy_entity_tyoe'
+    entity_filter = 'dummy_param=dummy value'
+    expected_result = [{'id': '1234', 'name': 'entity_name'}]
+
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_get.return_value.status_code = 200
+    mocked_get.return_value.json.return_value = expected_result
+
+    result = ims_instance.get_filtered_entities(entity, entity_filter)
+
+    assert list(result) == expected_result
+
+
+def test_get_entities(mocker, ims_instance):
+    entity = 'dummy_entity_type'
+    expected_result = [{'id': '1234', 'name': 'entity_name'}]
+
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_get.return_value.status_code = 200
+    mocked_get.return_value.json.return_value = expected_result
+
+    result = ims_instance.get_entities(entity)
+
+    assert list(result) == expected_result
+
+
+@pytest.mark.parametrize('step_count', [1, 3, 4, 9, 10])
+def test_get_paged_entities(step_count, ims_instance, mocker):
+    entity = 'dummy_entity_type'
+    expected_result = [
+        {'id': '12341', 'name': 'entity_name_1'},
+        {'id': '12342', 'name': 'entity_name_2'},
+        {'id': '12343', 'name': 'entity_name_3'},
+        {'id': '12344', 'name': 'entity_name_4'},
+        {'id': '12345', 'name': 'entity_name_5'},
+        {'id': '12346', 'name': 'entity_name_6'},
+        {'id': '12347', 'name': 'entity_name_7'},
+        {'id': '12348', 'name': 'entity_name_8'},
+        {'id': '12349', 'name': 'entity_name_9'},
+    ]
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_responses = []
+
+    for i in range(0, len(expected_result) + 1, step_count):
+        resp = Response()
+        resp.status_code = 200
+        if i >= len(expected_result):
+            resp_content = {}
+        else:
+            resp_content = expected_result[i:i + step_count]
+        resp._content = str.encode(json.dumps(resp_content))
+        mocked_responses.append(resp)
+    mocked_get.side_effect = mocked_responses
+
+    result = ims_instance.get_all_entities(entity, step_count=step_count)
+    assert list(result) == expected_result
+
+
+def test_context_manager(ims_instance, mocker):
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocked_get.return_value.status_code = 200
+    mocked_get.return_value.json.return_value = []
+
+    with ims_instance as ims:
+        list(ims_instance.get_all_entities('dummy'))
+        assert ims.bearer_token == 'dummy_bt'
+
+    assert ims_instance.bearer_token is None
+
+
+def test_navigation_properties(ims_instance, mocker):
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+
+    ims_instance.get_entity_by_id('Node', 1234, navigation_properties=[1, 2, 3])
+    mocked_get.assert_called_with(
+        f'{BASE_URL}{IMS.IMS_PATH}/Node/1234', params={'navigationproperty': 6})
+
+
+def test_get_entities_gateway_error(ims_instance, mocker):
+    mocked_get = mocker.patch.object(ims_instance.session, 'get')
+    mocker.patch('inventory_provider.db.ims.time')
+
+    error_message = 'Dummy Gateway Error'
+    r = Response()
+    r.status_code = requests.codes.gateway_timeout
+    r._content = str.encode(error_message)
+    mocked_get.side_effect = [r] * 5
+    with pytest.raises(HTTPError) as exception_info:
+        list(ims_instance.get_entities('dummy'))
+        assert exception_info.value.message == error_message
+        assert mocked_get.call_count == 4