From 6e62575f75375af493a76a394264a293cb80c7b8 Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 10:00:14 +0200
Subject: [PATCH] use flask-sqlalchemy for the main db

---
 compendium_v2/__init__.py                     |  12 +-
 compendium_v2/db/__init__.py                  |  50 +-
 compendium_v2/db/model.py                     |  48 +-
 compendium_v2/migrations/env.py               |   2 +-
 compendium_v2/migrations/migration_utils.py   |   8 +-
 compendium_v2/publishers/helpers.py           |  13 +-
 .../publishers/survey_publisher_2022.py       | 485 +++++++++---------
 .../publishers/survey_publisher_v1.py         | 272 +++++-----
 compendium_v2/routes/budget.py                |  31 +-
 compendium_v2/routes/charging.py              |  29 +-
 compendium_v2/routes/ec_projects.py           |  24 +-
 compendium_v2/routes/funding.py               |  27 +-
 compendium_v2/routes/organization.py          |  29 +-
 compendium_v2/routes/staff.py                 |  24 +-
 compendium_v2/survey_db/__init__.py           |   5 -
 requirements.txt                              |   1 +
 setup.py                                      |   1 +
 test/conftest.py                              | 122 ++---
 test/test_survey_publisher_2022.py            |  28 +-
 test/test_survey_publisher_v1.py              |  23 +-
 20 files changed, 545 insertions(+), 689 deletions(-)

diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py
index 622565f2..7a69bf9e 100644
--- a/compendium_v2/__init__.py
+++ b/compendium_v2/__init__.py
@@ -8,7 +8,7 @@ from flask import Flask
 from flask_cors import CORS  # for debugging
 
 from compendium_v2 import config, environment
-
+from compendium_v2.db import db
 from compendium_v2.migrations import migration_utils
 
 
@@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask:
     return app
 
 
+def _create_app_with_db(app_config) -> Flask:
+    # used by the tests and the publishers
+    app = _create_app(app_config)
+    app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI']
+    db.init_app(app)
+    return app
+
+
 def create_app() -> Flask:
     """
     overrides default settings with those found
@@ -46,7 +54,7 @@ def create_app() -> Flask:
     with open(os.environ['SETTINGS_FILENAME']) as f:
         app_config = config.load(f)
 
-    app = _create_app(app_config)
+    app = _create_app_with_db(app_config)
 
     logging.info('Flask app initialized')
 
diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py
index ce48aaa0..13719e0a 100644
--- a/compendium_v2/db/__init__.py
+++ b/compendium_v2/db/__init__.py
@@ -1,45 +1,17 @@
-import contextlib
 import logging
-from typing import Optional, Union, Callable, Iterator
 
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker, Session
+from flask_sqlalchemy import SQLAlchemy
+from sqlalchemy import MetaData
 
-logger = logging.getLogger(__name__)
-_SESSION_MAKER: Union[None, sessionmaker] = None
-
-
-@contextlib.contextmanager
-def session_scope(
-        callback_before_close: Optional[Callable] = None) -> Iterator[Session]:
-    # best practice is to keep session scope separate from data processing
-    # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html
-
-    assert _SESSION_MAKER
-    session = _SESSION_MAKER()
-    try:
-        yield session
-        session.commit()
-        if callback_before_close:
-            callback_before_close()
-    except SQLAlchemyError:
-        logger.error('caught sql layer exception, rolling back')
-        session.rollback()
-        raise  # re-raise, will be handled by main consumer
-    finally:
-        session.close()
-
-
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
-    return (f'postgresql://{db_username}:{db_password}'
-            f'@{db_hostname}:{port}/{db_name}')
 
+logger = logging.getLogger(__name__)
 
-def init_db_model(dsn):
-    global _SESSION_MAKER
+metadata_obj = MetaData(naming_convention={
+    "ix": "ix_%(column_0_label)s",
+    "uq": "uq_%(table_name)s_%(column_0_name)s",
+    "ck": "ck_%(table_name)s_%(constraint_name)s",
+    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+    "pk": "pk_%(table_name)s",
+})
 
-    # cf. https://docs.sqlalchemy.org/en
-    #        /latest/orm/extensions/automap.html
-    engine = create_engine(dsn, pool_size=10)
-    _SESSION_MAKER = sessionmaker(bind=engine)
+db = SQLAlchemy(metadata=metadata_obj)
diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py
index 67672743..2199f759 100644
--- a/compendium_v2/db/model.py
+++ b/compendium_v2/db/model.py
@@ -4,42 +4,34 @@ from enum import Enum
 from typing import Optional
 from typing_extensions import Annotated
 
-from sqlalchemy import MetaData, String
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy import String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 from sqlalchemy.schema import ForeignKey
 
+from compendium_v2.db import db
 
-logger = logging.getLogger(__name__)
-
-convention = {
-    "ix": "ix_%(column_0_label)s",
-    "uq": "uq_%(table_name)s_%(column_0_name)s",
-    "ck": "ck_%(table_name)s_%(constraint_name)s",
-    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
-    "pk": "pk_%(table_name)s",
-}
 
-metadata_obj = MetaData(naming_convention=convention)
+logger = logging.getLogger(__name__)
 
 str128 = Annotated[str, 128]
+str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)]
+str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)]
 int_pk = Annotated[int, mapped_column(primary_key=True)]
 int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)]
 
 
-class Base(DeclarativeBase):
-    metadata = metadata_obj
-    type_annotation_map = {
-        str128: String(128),
-    }
+# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet.
+# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
+# mypy: disable-error-code="name-defined"
 
 
-class NREN(Base):
+class NREN(db.Model):
     __tablename__ = 'nren'
     id: Mapped[int_pk]
     name: Mapped[str128]
 
 
-class BudgetEntry(Base):
+class BudgetEntry(db.Model):
     __tablename__ = 'budgets'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -47,7 +39,7 @@ class BudgetEntry(Base):
     budget: Mapped[Decimal]
 
 
-class FundingSource(Base):
+class FundingSource(db.Model):
     __tablename__ = 'funding_source'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -67,7 +59,7 @@ class FeeType(Enum):
     other = "other"
 
 
-class ChargingStructure(Base):
+class ChargingStructure(db.Model):
     __tablename__ = 'charging_structure'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -75,7 +67,7 @@ class ChargingStructure(Base):
     fee_type: Mapped[Optional[FeeType]]
 
 
-class NrenStaff(Base):
+class NrenStaff(db.Model):
     __tablename__ = 'nren_staff'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -86,7 +78,7 @@ class NrenStaff(Base):
     non_technical_fte: Mapped[Decimal]
 
 
-class ParentOrganization(Base):
+class ParentOrganization(db.Model):
     __tablename__ = 'parent_organization'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -94,18 +86,18 @@ class ParentOrganization(Base):
     organization: Mapped[str128]
 
 
-class SubOrganization(Base):
+class SubOrganization(db.Model):
     __tablename__ = 'sub_organization'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
     year: Mapped[int_pk]
-    organization: Mapped[str128] = mapped_column(primary_key=True)
+    organization: Mapped[str128_pk]
     role: Mapped[str128]
 
 
-class ECProject(Base):
+class ECProject(db.Model):
     __tablename__ = 'ec_project'
     nren_id: Mapped[int_pk_fkNREN]
-    nren: Mapped[NREN] = relationship(NREN, lazy='joined')
+    nren: Mapped[NREN] = relationship(lazy='joined')
     year: Mapped[int_pk]
-    project: Mapped[str] = mapped_column(String(256), primary_key=True)
+    project: Mapped[str256_pk]
diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py
index 0307be33..5ea9c8d3 100644
--- a/compendium_v2/migrations/env.py
+++ b/compendium_v2/migrations/env.py
@@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config
 from sqlalchemy import pool
 
 from alembic import context
-from compendium_v2.db.model import metadata_obj
+from compendium_v2.db import metadata_obj
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py
index 3b29b540..f25f03c9 100644
--- a/compendium_v2/migrations/migration_utils.py
+++ b/compendium_v2/migrations/migration_utils.py
@@ -1,7 +1,6 @@
 import logging
 import os
 
-from compendium_v2 import db
 from alembic.config import Config
 from alembic import command
 
@@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY):
     command.upgrade(alembic_config, 'head')
 
 
+def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
+    return (f'postgresql://{db_username}:{db_password}'
+            f'@{db_hostname}:{port}/{db_name}')
+
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
-    upgrade(db.postgresql_dsn(
+    upgrade(postgresql_dsn(
         db_username='compendium',
         db_password='compendium321',
         db_hostname='localhost',
diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py
index d95ad4e1..43d1bf02 100644
--- a/compendium_v2/publishers/helpers.py
+++ b/compendium_v2/publishers/helpers.py
@@ -1,21 +1,20 @@
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from sqlalchemy import select
+
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
 
 
 def init_db(config):
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
     dsn_survey = config['SURVEY_DATABASE_URI']
     survey_db.init_db_model(dsn_survey)
 
 
-def get_uppercase_nren_dict(session):
+def get_uppercase_nren_dict():
     """
-    :param session: db session that is used to query the known NRENs
     :return: a dictionary of all known NRENs db entities keyed on the
              uppercased name
     """
-    current_nrens = session.query(model.NREN).all()
+    current_nrens = db.session.scalars(select(model.NREN))
     nren_dict = {nren.name.upper(): nren for nren in current_nrens}
     # add aliases that are used in the source data:
     nren_dict['ASNET'] = nren_dict['ASNET-AM']
diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py
index 710c899d..ea88ee77 100644
--- a/compendium_v2/publishers/survey_publisher_2022.py
+++ b/compendium_v2/publishers/survey_publisher_2022.py
@@ -16,11 +16,12 @@ import html
 from sqlalchemy import text
 from collections import defaultdict
 
+import compendium_v2
 from compendium_v2.db.model import FeeType
 from compendium_v2.environment import setup_logging
 from compendium_v2.config import load
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
 from compendium_v2.publishers import helpers
 
 setup_logging()
@@ -133,163 +134,151 @@ def query_question(question: enum.Enum):
         return survey.execute(text(query))
 
 
-def transfer_budget():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_budget()
-        for row in rows:
+def transfer_budget(nren_dict):
+    rows = query_budget()
+    for row in rows:
+        nren_name = row[0].upper()
+        _budget = row[1]
+        try:
+            budget = float(_budget.replace('"', '').replace(',', ''))
+        except ValueError:
+            logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
+            continue
+
+        if budget > 200:
+            logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        budget_entry = model.BudgetEntry(
+            nren=nren_dict[nren_name],
+            budget=budget,
+            year=2022,
+        )
+        db.session.merge(budget_entry)
+    db.session.commit()
+
+
+def transfer_funding_sources(nren_dict):
+    sourcedata = {}
+    for source, data in query_funding_sources():
+        for row in data:
             nren_name = row[0].upper()
-            _budget = row[1]
+            _value = row[1]
             try:
-                budget = float(_budget.replace('"', '').replace(',', ''))
+                value = float(_value.replace('"', '').replace(',', ''))
             except ValueError:
-                logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
-                continue
-
-            if budget > 200:
-                logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
-
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
-                continue
+                name = source.name
+                logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
+                value = 0
 
-            budget_entry = model.BudgetEntry(
-                nren=nren_dict[nren_name],
-                budget=budget,
-                year=2022,
+            nren_info = sourcedata.setdefault(
+                nren_name,
+                {source_type: 0 for source_type in FundingSource}
             )
-            session.merge(budget_entry)
-        session.commit()
-
-
-def transfer_funding_sources():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        sourcedata = {}
-        for source, data in query_funding_sources():
-            for row in data:
-                nren_name = row[0].upper()
-                _value = row[1]
-                try:
-                    value = float(_value.replace('"', '').replace(',', ''))
-                except ValueError:
-                    name = source.name
-                    logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
-                    value = 0
-
-                nren_info = sourcedata.setdefault(
-                    nren_name,
-                    {source_type: 0 for source_type in FundingSource}
-                )
-                nren_info[source] = value
-
-        for nren_name, nren_info in sourcedata.items():
-            total = sum(nren_info.values())
-
-            if not math.isclose(total, 100, abs_tol=0.01):
-                logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+            nren_info[source] = value
+
+    for nren_name, nren_info in sourcedata.items():
+        total = sum(nren_info.values())
+
+        if not math.isclose(total, 100, abs_tol=0.01):
+            logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        funding_source = model.FundingSource(
+            nren=nren_dict[nren_name],
+            year=2022,
+            client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
+            european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
+            gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
+            commercial=nren_info[FundingSource.COMMERCIAL],
+            other=nren_info[FundingSource.OTHER],
+        )
+        db.session.merge(funding_source)
+    db.session.commit()
+
+
+def transfer_staff_data(nren_dict):
+    data = {}
+    for question in StaffQuestion:
+        rows = query_question(question)
+        for row in rows:
+            nren_name = row[0].upper()
+            _value = row[1]
+            try:
+                value = float(_value.replace('"', '').replace(',', ''))
+            except ValueError:
+                value = 0
 
             if nren_name not in nren_dict:
                 logger.info(f'{nren_name} unknown. Skipping.')
                 continue
 
-            funding_source = model.FundingSource(
-                nren=nren_dict[nren_name],
-                year=2022,
-                client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
-                european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
-                gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
-                commercial=nren_info[FundingSource.COMMERCIAL],
-                other=nren_info[FundingSource.OTHER],
-            )
-            session.merge(funding_source)
-        session.commit()
-
-
-def transfer_staff_data():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        data = {}
-        for question in StaffQuestion:
-            rows = query_question(question)
-            for row in rows:
-                nren_name = row[0].upper()
-                _value = row[1]
-                try:
-                    value = float(_value.replace('"', '').replace(',', ''))
-                except ValueError:
-                    value = 0
-
-                if nren_name not in nren_dict:
-                    logger.info(f'{nren_name} unknown. Skipping.')
-                    continue
-
-                # initialize on first use, so we don't add data for nrens with no answers
-                data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
-                    question] = value
-
-        for nren_name, nren_info in data.items():
-            if sum([nren_info[question] for question in StaffQuestion]) == 0:
-                logger.info(f'{nren_name} has no staff data. Deleting if exists.')
-                session.query(model.NrenStaff).filter(
-                    model.NrenStaff.nren_id == nren_dict[nren_name].id,
-                    model.NrenStaff.year == 2022,
-                ).delete()
-                continue
-
-            employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
-            technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
-
-            if not math.isclose(employed, technical, abs_tol=0.01):
-                logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
-                            f' ({employed} != {technical})')
-
-            staff_data = model.NrenStaff(
-                nren_id=nren_dict[nren_name].id,
-                year=2022,
-                permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
-                subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
-                technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
-                non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
-            )
-            session.merge(staff_data)
-        session.commit()
-
-
-def transfer_nren_parent_org():
+            # initialize on first use, so we don't add data for nrens with no answers
+            data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
+                question] = value
+
+    for nren_name, nren_info in data.items():
+        if sum([nren_info[question] for question in StaffQuestion]) == 0:
+            logger.info(f'{nren_name} has no staff data. Deleting if exists.')
+            db.session.query(model.NrenStaff).filter(
+                model.NrenStaff.nren_id == nren_dict[nren_name].id,
+                model.NrenStaff.year == 2022,
+            ).delete()
+            continue
+
+        employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
+        technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
+
+        if not math.isclose(employed, technical, abs_tol=0.01):
+            logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
+                        f' ({employed} != {technical})')
+
+        staff_data = model.NrenStaff(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
+            subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
+            technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
+            non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
+        )
+        db.session.merge(staff_data)
+    db.session.commit()
+
+
+def transfer_nren_parent_org(nren_dict):
     # clean up the data a bit by removing some strings
     strings_to_replace = [
         'We are affiliated to '
     ]
 
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_question(OrgQuestion.PARENT_ORG_NAME)
-        for row in rows:
-            nren_name = row[0].upper()
-            value = str(row[1]).replace('"', '')
+    rows = query_question(OrgQuestion.PARENT_ORG_NAME)
+    for row in rows:
+        nren_name = row[0].upper()
+        value = str(row[1]).replace('"', '')
 
-            for string in strings_to_replace:
-                value = value.replace(string, '')
+        for string in strings_to_replace:
+            value = value.replace(string, '')
 
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
-                continue
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
 
-            parent_org = model.ParentOrganization(
-                nren_id=nren_dict[nren_name].id,
-                year=2022,
-                organization=value,
-            )
-            session.merge(parent_org)
-        session.commit()
+        parent_org = model.ParentOrganization(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            organization=value,
+        )
+        db.session.merge(parent_org)
+    db.session.commit()
 
 
-def transfer_nren_sub_org():
+def transfer_nren_sub_org(nren_dict):
     suborg_questions = [
         (OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE),
         (OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE),
@@ -298,140 +287,134 @@ def transfer_nren_sub_org():
         (OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE)
     ]
 
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        lookup = defaultdict(list)
-
-        for name, choice, role in suborg_questions:
-            _name_rows = query_question(name)
-            _choice_rows = query_question(choice)
-            _role_rows = list(query_question(role))
-            for _name, _choice in zip(_name_rows, _choice_rows):
-                nren_name = _name[0].upper()
-                suborg_name = _name[1].replace('"', '').strip()
-                role_choice = _choice[1].replace('"', '').strip()
-
-                if nren_name not in nren_dict:
-                    logger.info(f'{nren_name} unknown. Skipping.')
-                    continue
-
-                if role_choice.lower() == 'other':
-                    for _role in _role_rows:
-                        if _role[0] == _name[0]:
-                            role = _role[1].replace('"', '').strip()
-                            break
-                else:
-                    role = role_choice
-
-                lookup[nren_name].append((suborg_name, role))
-
-        for nren_name, suborgs in lookup.items():
-            for suborg_name, role in suborgs:
-                suborg = model.SubOrganization(
-                    nren_id=nren_dict[nren_name].id,
-                    year=2022,
-                    organization=suborg_name,
-                    role=role,
-                )
-                session.merge(suborg)
-        session.commit()
-
-
-def transfer_charging_structure():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_question(ChargingStructure.charging_structure)
-        for row in rows:
-            nren_name = row[0].upper()
-            value = row[1].replace('"', '').strip()
+    lookup = defaultdict(list)
+
+    for name, choice, role in suborg_questions:
+        _name_rows = query_question(name)
+        _choice_rows = query_question(choice)
+        _role_rows = list(query_question(role))
+        for _name, _choice in zip(_name_rows, _choice_rows):
+            nren_name = _name[0].upper()
+            suborg_name = _name[1].replace('"', '').strip()
+            role_choice = _choice[1].replace('"', '').strip()
 
             if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+                logger.info(f'{nren_name} unknown. Skipping.')
                 continue
 
-            if "do not charge" in value:
-                charging_structure = FeeType.no_charge
-            elif "combination" in value:
-                charging_structure = FeeType.combination
-            elif "flat" in value:
-                charging_structure = FeeType.flat_fee
-            elif "usage-based" in value:
-                charging_structure = FeeType.usage_based_fee
-            elif "Other" in value:
-                charging_structure = FeeType.other
+            if role_choice.lower() == 'other':
+                for _role in _role_rows:
+                    if _role[0] == _name[0]:
+                        role = _role[1].replace('"', '').strip()
+                        break
             else:
-                charging_structure = None
+                role = role_choice
+
+            lookup[nren_name].append((suborg_name, role))
 
-            charging_structure = model.ChargingStructure(
+    for nren_name, suborgs in lookup.items():
+        for suborg_name, role in suborgs:
+            suborg = model.SubOrganization(
                 nren_id=nren_dict[nren_name].id,
                 year=2022,
-                fee_type=charging_structure,
+                organization=suborg_name,
+                role=role,
             )
-            session.merge(charging_structure)
-        session.commit()
-
-
-def transfer_ec_projects():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # delete all existing EC projects, in case something changed
-        session.query(model.ECProject).filter(
-            model.ECProject.year == 2022,
-        ).delete()
-
-        rows = query_question(ECQuestion.EC_PROJECT)
-        for row in rows:
-            nren_name = row[0].upper()
-
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
+            db.session.merge(suborg)
+    db.session.commit()
+
+
+def transfer_charging_structure(nren_dict):
+    rows = query_question(ChargingStructure.charging_structure)
+    for row in rows:
+        nren_name = row[0].upper()
+        value = row[1].replace('"', '').strip()
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+            continue
+
+        if "do not charge" in value:
+            charging_structure = FeeType.no_charge
+        elif "combination" in value:
+            charging_structure = FeeType.combination
+        elif "flat" in value:
+            charging_structure = FeeType.flat_fee
+        elif "usage-based" in value:
+            charging_structure = FeeType.usage_based_fee
+        elif "Other" in value:
+            charging_structure = FeeType.other
+        else:
+            charging_structure = None
+
+        charging_structure = model.ChargingStructure(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            fee_type=charging_structure,
+        )
+        db.session.merge(charging_structure)
+    db.session.commit()
+
+
+def transfer_ec_projects(nren_dict):
+    # delete all existing EC projects, in case something changed
+    db.session.query(model.ECProject).filter(
+        model.ECProject.year == 2022,
+    ).delete()
+
+    rows = query_question(ECQuestion.EC_PROJECT)
+    for row in rows:
+        nren_name = row[0].upper()
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        try:
+            value = json.loads(row[1])
+        except json.decoder.JSONDecodeError:
+            logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
+            continue
+
+        for val in value:
+            if not val:
+                logger.info(f'Invalid EC project value for {nren_name}: {val}.')
                 continue
 
-            try:
-                value = json.loads(row[1])
-            except json.decoder.JSONDecodeError:
-                logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
-                continue
-
-            for val in value:
-                if not val:
-                    logger.info(f'Invalid EC project value for {nren_name}: {val}.')
-                    continue
-
-                # strip html entities/NBSP from val
-                val = html.unescape(val).replace('\xa0', ' ')
+            # strip html entities/NBSP from val
+            val = html.unescape(val).replace('\xa0', ' ')
 
-                # some answers include contract numbers, which we don't want here
-                val = val.split('(contract n')[0]
+            # some answers include contract numbers, which we don't want here
+            val = val.split('(contract n')[0]
 
-                ec_project = model.ECProject(
-                    nren_id=nren_dict[nren_name].id,
-                    year=2022,
-                    project=str(val).strip()
-                )
-                session.add(ec_project)
-        session.commit()
+            ec_project = model.ECProject(
+                nren_id=nren_dict[nren_name].id,
+                year=2022,
+                project=str(val).strip()
+            )
+            db.session.add(ec_project)
+    db.session.commit()
 
 
-def _cli(config):
+def _cli(config, app):
     helpers.init_db(config)
-    transfer_budget()
-    transfer_funding_sources()
-    transfer_staff_data()
-    transfer_nren_parent_org()
-    transfer_nren_sub_org()
-    transfer_charging_structure()
-    transfer_ec_projects()
+    with app.app_context():
+        nren_dict = helpers.get_uppercase_nren_dict()
+        transfer_budget(nren_dict)
+        transfer_funding_sources(nren_dict)
+        transfer_staff_data(nren_dict)
+        transfer_nren_parent_org(nren_dict)
+        transfer_nren_sub_org(nren_dict)
+        transfer_charging_structure(nren_dict)
+        transfer_ec_projects(nren_dict)
 
 
 @click.command()
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
-    _cli(app_config)
+    app = compendium_v2._create_app_with_db(app_config)
+    _cli(app_config, app)
 
 
 if __name__ == "__main__":
diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py
index ba0e4e2a..5f4cd8b2 100644
--- a/compendium_v2/publishers/survey_publisher_v1.py
+++ b/compendium_v2/publishers/survey_publisher_v1.py
@@ -11,11 +11,12 @@ import logging
 import math
 import click
 
+import compendium_v2
 from compendium_v2.environment import setup_logging
-from compendium_v2 import db, survey_db
+from compendium_v2 import survey_db
 from compendium_v2.background_task import parse_excel_data
 from compendium_v2.config import load
-from compendium_v2.db import model
+from compendium_v2.db import db, model
 from compendium_v2.survey_db import model as survey_model
 from compendium_v2.publishers import helpers
 
@@ -24,11 +25,8 @@ setup_logging()
 logger = logging.getLogger('survey-publisher-v1')
 
 
-def db_budget_migration():
-    with survey_db.session_scope() as survey_session, \
-            db.session_scope() as session:
-
-        nren_dict = helpers.get_uppercase_nren_dict(session)
+def db_budget_migration(nren_dict):
+    with survey_db.session_scope() as survey_session:
 
         # move data from Survey DB budget table
         data = survey_session.query(survey_model.Nrens)
@@ -49,7 +47,7 @@ def db_budget_migration():
                     budget=float(budget.budget),
                     year=year
                 )
-                session.merge(budget_entry)
+                db.session.merge(budget_entry)
 
         # Import the data from excel sheet to database
         exceldata = parse_excel_data.fetch_budget_excel_data()
@@ -63,165 +61,153 @@ def db_budget_migration():
                 logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
 
             budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year)
-            session.merge(budget_entry)
-        session.commit()
-
-
-def db_funding_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # Import the data to database
-        data = parse_excel_data.fetch_funding_excel_data()
-
-        for (abbrev, year, client_institution,
-             european_funding,
-             gov_public_bodies,
-             commercial, other) in data:
-
-            _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
-            total = sum(_data)
-            if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
-                logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
-
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
-
-            budget_entry = model.FundingSource(
-                nren=nren_dict[abbrev],
-                year=year,
-                client_institutions=client_institution,
-                european_funding=european_funding,
-                gov_public_bodies=gov_public_bodies,
-                commercial=commercial,
-                other=other)
-            session.merge(budget_entry)
-        session.commit()
-
-
-def db_charging_structure_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # Import the data to database
-        data = parse_excel_data.fetch_charging_structure_excel_data()
-
-        for (abbrev, year, charging_structure) in data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
-
-            charging_structure_entry = model.ChargingStructure(
-                nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
-            session.merge(charging_structure_entry)
-        session.commit()
-
-
-def db_staffing_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        staff_data = parse_excel_data.fetch_staffing_excel_data()
-
-        nren_staff_map = {}
-        for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping staff data.')
-                continue
-
-            nren = nren_dict[abbrev]
+            db.session.merge(budget_entry)
+        db.session.commit()
+
+
+def db_funding_migration(nren_dict):
+    # Import the data to database
+    data = parse_excel_data.fetch_funding_excel_data()
+
+    for (abbrev, year, client_institution,
+            european_funding,
+            gov_public_bodies,
+            commercial, other) in data:
+
+        _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
+        total = sum(_data)
+        if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
+            logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
+
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
+
+        budget_entry = model.FundingSource(
+            nren=nren_dict[abbrev],
+            year=year,
+            client_institutions=client_institution,
+            european_funding=european_funding,
+            gov_public_bodies=gov_public_bodies,
+            commercial=commercial,
+            other=other)
+        db.session.merge(budget_entry)
+    db.session.commit()
+
+
+def db_charging_structure_migration(nren_dict):
+    # Import the data to database
+    data = parse_excel_data.fetch_charging_structure_excel_data()
+
+    for (abbrev, year, charging_structure) in data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
+
+        charging_structure_entry = model.ChargingStructure(
+            nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
+        db.session.merge(charging_structure_entry)
+    db.session.commit()
+
+
+def db_staffing_migration(nren_dict):
+    staff_data = parse_excel_data.fetch_staffing_excel_data()
+
+    nren_staff_map = {}
+    for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping staff data.')
+            continue
+
+        nren = nren_dict[abbrev]
+        nren_staff_map[(nren.id, year)] = model.NrenStaff(
+            nren=nren,
+            nren_id=nren.id,
+            year=year,
+            permanent_fte=permanent_fte,
+            subcontracted_fte=subcontracted_fte,
+            technical_fte=0,
+            non_technical_fte=0
+        )
+
+    function_data = parse_excel_data.fetch_staff_function_excel_data()
+    for (abbrev, year, technical_fte, non_technical_fte) in function_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping staff function data.')
+            continue
+
+        nren = nren_dict[abbrev]
+        if (nren.id, year) in nren_staff_map:
+            nren_staff_map[(nren.id, year)].technical_fte = technical_fte
+            nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
+        else:
             nren_staff_map[(nren.id, year)] = model.NrenStaff(
                 nren=nren,
                 nren_id=nren.id,
                 year=year,
-                permanent_fte=permanent_fte,
-                subcontracted_fte=subcontracted_fte,
-                technical_fte=0,
-                non_technical_fte=0
+                permanent_fte=0,
+                subcontracted_fte=0,
+                technical_fte=technical_fte,
+                non_technical_fte=non_technical_fte
             )
 
-        function_data = parse_excel_data.fetch_staff_function_excel_data()
-        for (abbrev, year, technical_fte, non_technical_fte) in function_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping staff function data.')
-                continue
-
-            nren = nren_dict[abbrev]
-            if (nren.id, year) in nren_staff_map:
-                nren_staff_map[(nren.id, year)].technical_fte = technical_fte
-                nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
-            else:
-                nren_staff_map[(nren.id, year)] = model.NrenStaff(
-                    nren=nren,
-                    nren_id=nren.id,
-                    year=year,
-                    permanent_fte=0,
-                    subcontracted_fte=0,
-                    technical_fte=technical_fte,
-                    non_technical_fte=non_technical_fte
-                )
-
-        for nren_staff_model in nren_staff_map.values():
-            employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
-            technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
-            if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
-                logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
-                               f' FTE do not equal across employed/technical categories ({employed} != {technical})')
-
-            session.merge(nren_staff_model)
+    for nren_staff_model in nren_staff_map.values():
+        employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
+        technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
+        if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
+            logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
+                           f' FTE do not equal across employed/technical categories ({employed} != {technical})')
 
-        session.commit()
+        db.session.merge(nren_staff_model)
 
+    db.session.commit()
 
-def db_ecprojects_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
-        for (abbrev, year, project) in ecproject_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
 
-            nren = nren_dict[abbrev]
-            ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
-            session.merge(ecproject_entry)
-        session.commit()
+def db_ecprojects_migration(nren_dict):
+    ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
+    for (abbrev, year, project) in ecproject_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
 
+        nren = nren_dict[abbrev]
+        ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
+        db.session.merge(ecproject_entry)
+    db.session.commit()
 
-def db_organizations_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
 
-        organization_data = parse_excel_data.fetch_organization_excel_data()
-        for (abbrev, year, org) in organization_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
+def db_organizations_migration(nren_dict):
+    organization_data = parse_excel_data.fetch_organization_excel_data()
+    for (abbrev, year, org) in organization_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
 
-            nren = nren_dict[abbrev]
-            org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
-            session.merge(org_entry)
-        session.commit()
+        nren = nren_dict[abbrev]
+        org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
+        db.session.merge(org_entry)
+    db.session.commit()
 
 
-def _cli(config):
+def _cli(config, app):
     helpers.init_db(config)
-    db_budget_migration()
-    db_funding_migration()
-    db_charging_structure_migration()
-    db_staffing_migration()
-    db_ecprojects_migration()
-    db_organizations_migration()
+    with app.app_context():
+        nren_dict = helpers.get_uppercase_nren_dict()
+        db_budget_migration(nren_dict)
+        db_funding_migration(nren_dict)
+        db_charging_structure_migration(nren_dict)
+        db_staffing_migration(nren_dict)
+        db_ecprojects_migration(nren_dict)
+        db_organizations_migration(nren_dict)
 
 
 @click.command()
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
+    app = compendium_v2._create_app_with_db(app_config)
     print("survey-publisher-v1 starting")
-    _cli(app_config)
+    _cli(app_config, app)
 
 
 if __name__ == "__main__":
diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py
index 763b3af3..1f4ff850 100644
--- a/compendium_v2/routes/budget.py
+++ b/compendium_v2/routes/budget.py
@@ -1,28 +1,17 @@
 import logging
 from typing import Any
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import BudgetEntry
 from compendium_v2.routes import common
 
-routes = Blueprint('budget', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('budget', __name__)
 logger = logging.getLogger(__name__)
 
-col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061',
-           '#bd7ebe', '#ffb55a', '#ffee65',
-           '#beb9db', '#fdcce5', '#8bd3c7']
-
 BUDGET_RESPONSE_SCHEMA = {
     '$schema': 'http://json-schema.org/draft-07/schema#',
 
@@ -58,15 +47,15 @@ def budget_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.BudgetEntry):
+    def _extract_data(entry: BudgetEntry):
         return {
             'NREN': entry.nren.name,
             'BUDGET': float(entry.budget),
             'BUDGET_YEAR': entry.year,
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                          for entry in session.query(model.BudgetEntry)],
-                         key=lambda d: (d['BUDGET_YEAR'], d['NREN']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))],
+        key=lambda d: (d['BUDGET_YEAR'], d['NREN'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py
index 0b2c2384..57e8114f 100644
--- a/compendium_v2/routes/charging.py
+++ b/compendium_v2/routes/charging.py
@@ -1,21 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('charging', __name__)
-
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ChargingStructure
+from compendium_v2.routes import common
 
 
+routes = Blueprint('charging', __name__)
 logger = logging.getLogger(__name__)
 
 CHARGING_STRUCTURE_RESPONSE_SCHEMA = {
@@ -53,16 +47,15 @@ def charging_structure_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.ChargingStructure):
+    def _extract_data(entry: ChargingStructure):
         return {
             'NREN': entry.nren.name,
             'YEAR': int(entry.year),
             'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None,
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                          for entry in session.query(model.ChargingStructure)
-                         .all()],
-                         key=lambda d: (d['NREN'], d['YEAR']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))],
+        key=lambda d: (d['NREN'], d['YEAR'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py
index 7114718d..b58d8931 100644
--- a/compendium_v2/routes/ec_projects.py
+++ b/compendium_v2/routes/ec_projects.py
@@ -1,22 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('ec-projects', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ECProject
+from compendium_v2.routes import common
 
 
+routes = Blueprint('ec-projects', __name__)
 logger = logging.getLogger(__name__)
 
 EC_PROJECTS_RESPONSE_SCHEMA = {
@@ -55,13 +48,12 @@ def ec_projects_view() -> Any:
     :return:
     """
 
-    def _extract_project(entry: model.ECProject):
+    def _extract_project(entry: ECProject):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
             'project': entry.project
         }
 
-    with db.session_scope() as session:
-        result = [_extract_project(project) for project in session.query(model.ECProject)]
+    result = [_extract_project(project) for project in db.session.scalars(select(ECProject))]
     return jsonify(result)
diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py
index ed0e26c8..c02bf136 100644
--- a/compendium_v2/routes/funding.py
+++ b/compendium_v2/routes/funding.py
@@ -1,22 +1,15 @@
 import logging
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
 from compendium_v2.routes import common
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import FundingSource
 from typing import Any
 
-routes = Blueprint('funding', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('funding', __name__)
 logger = logging.getLogger(__name__)
 
 FUNDING_RESPONSE_SCHEMA = {
@@ -59,7 +52,7 @@ def funding_source_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.FundingSource):
+    def _extract_data(entry: FundingSource):
         return {
             'NREN': entry.nren.name,
             'YEAR': entry.year,
@@ -70,8 +63,8 @@ def funding_source_view() -> Any:
             'OTHER': float(entry.other)
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                         for entry in session.query(model.FundingSource)],
-                         key=lambda d: (d['NREN'], d['YEAR']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))],
+        key=lambda d: (d['NREN'], d['YEAR'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py
index 61a43354..8e8ebc8d 100644
--- a/compendium_v2/routes/organization.py
+++ b/compendium_v2/routes/organization.py
@@ -1,22 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('organization', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ParentOrganization, SubOrganization
+from compendium_v2.routes import common
 
 
+routes = Blueprint('organization', __name__)
 logger = logging.getLogger(__name__)
 
 ORGANIZATION_RESPONSE_SCHEMA = {
@@ -71,15 +64,14 @@ def parent_organization_view() -> Any:
     :return:
     """
 
-    def _extract_parent(entry: model.ParentOrganization):
+    def _extract_parent(entry: ParentOrganization):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
             'name': entry.organization
         }
 
-    with db.session_scope() as session:
-        result = [_extract_parent(org) for org in session.query(model.ParentOrganization)]
+    result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))]
     return jsonify(result)
 
 
@@ -98,7 +90,7 @@ def sub_organization_view() -> Any:
     :return:
     """
 
-    def _extract_sub(entry: model.SubOrganization):
+    def _extract_sub(entry: SubOrganization):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
@@ -106,6 +98,5 @@ def sub_organization_view() -> Any:
             'role': entry.role
         }
 
-    with db.session_scope() as session:
-        result = [_extract_sub(org) for org in session.query(model.SubOrganization)]
+    result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))]
     return jsonify(result)
diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py
index 73e79c48..b8478266 100644
--- a/compendium_v2/routes/staff.py
+++ b/compendium_v2/routes/staff.py
@@ -1,22 +1,15 @@
 import logging
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
+from compendium_v2.db import db
+from compendium_v2.db.model import NREN, NrenStaff
 from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('staff', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('staff', __name__)
 logger = logging.getLogger(__name__)
 
 STAFF_RESPONSE_SCHEMA = {
@@ -57,7 +50,7 @@ def staff_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.NrenStaff):
+    def _extract_data(entry: NrenStaff):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
@@ -67,7 +60,6 @@ def staff_view() -> Any:
             'non_technical_fte': float(entry.non_technical_fte)
         }
 
-    with db.session_scope() as session:
-        entries = [_extract_data(entry) for entry in session.query(
-            model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())]
+    entries = [_extract_data(entry) for entry in db.session.scalars(
+        select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))]
     return jsonify(entries)
diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py
index ce48aaa0..1550ddcb 100644
--- a/compendium_v2/survey_db/__init__.py
+++ b/compendium_v2/survey_db/__init__.py
@@ -31,11 +31,6 @@ def session_scope(
         session.close()
 
 
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
-    return (f'postgresql://{db_username}:{db_password}'
-            f'@{db_hostname}:{port}/{db_name}')
-
-
 def init_db_model(dsn):
     global _SESSION_MAKER
 
diff --git a/requirements.txt b/requirements.txt
index 12ec2076..98804e02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ click~=8.1
 jsonschema~=4.17
 flask~=2.2
 flask-cors~=3.0
+flask-sqlalchemy~=3.0
 openpyxl~=3.1
 psycopg2-binary~=2.9
 SQLAlchemy~=2.0
diff --git a/setup.py b/setup.py
index 213288e5..d051ddd7 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ setup(
         'jsonschema~=4.17',
         'flask~=2.2',
         'flask-cors~=3.0',
+        'flask-sqlalchemy~=3.0',
         'openpyxl~=3.1',
         'psycopg2-binary~=2.9',
         'SQLAlchemy~=2.0',
diff --git a/test/conftest.py b/test/conftest.py
index 79186212..ba6b9a43 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,20 +1,15 @@
-import json
+import csv
 import os
-import tempfile
-import random
-
 import pytest
-import compendium_v2
-from compendium_v2 import db
-from compendium_v2.db import model
-from compendium_v2 import survey_db
-from compendium_v2.survey_db import model as survey_model
+import random
 
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import StaticPool
 
-import csv
+import compendium_v2
+from compendium_v2.db import db, model
+from compendium_v2.survey_db import model as survey_model
 
 
 def _test_data_csv(filename):
@@ -43,61 +38,29 @@ def mocked_survey_db(mocker):
 
 
 @pytest.fixture
-def mocked_db(mocker):
-    # cf. https://stackoverflow.com/a/33057675
-    engine = create_engine(
-        'sqlite://',
-        connect_args={'check_same_thread': False},
-        poolclass=StaticPool,
-        echo=False)
-    model.Base.metadata.create_all(engine)
-    mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine))
-    mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None)
-    mocker.patch('compendium_v2.migrate_database', lambda config: None)
-
-
-@pytest.fixture
-def test_budget_data():
-    with db.session_scope() as session:
+def test_budget_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("BudgetTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
             budget = row["budget"]
             year = row["year"]
 
-            session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
-
-    with survey_db.session_scope() as session:
-        data = _test_data_csv("BudgetTestData.csv")
-        nrens = set()
-        budgets_data = []
-        for row in data:
-            nren = row["nren"]
-            budget = row["budget"]
-            year = row["year"]
-            country_code = row["nren"]
-
-            nrens.add(nren)
-
-            budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code))
-
-        for nren in nrens:
-            session.add(survey_model.Nrens(abbreviation=nren, country_code=nren))
-
-        session.add_all(budgets_data)
+            db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
+        db.session.commit()
 
 
 @pytest.fixture
-def test_funding_source_data():
-    with db.session_scope() as session:
+def test_funding_source_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("FundingSourceTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -108,7 +71,7 @@ def test_funding_source_data():
             commercial = row["commercial"]
             other = row["other"]
 
-            session.add(
+            db.session.add(
                 model.FundingSource(
                     nren=nren, year=year,
                     client_institutions=client,
@@ -117,10 +80,11 @@ def test_funding_source_data():
                     commercial=commercial,
                     other=other)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_staff_data():
+def test_staff_data(app):
 
     # generator of random test data for 5 years and 100 nrens
     def _generate_rows():
@@ -135,12 +99,12 @@ def test_staff_data():
                     "non_technical_fte": random.randint(0, 100)
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         data = list(_generate_rows())
 
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -150,7 +114,7 @@ def test_staff_data():
             technical_fte = row["technical_fte"]
             non_technical_fte = row["non_technical_fte"]
 
-            session.add(
+            db.session.add(
                 model.NrenStaff(
                     nren=nren,
                     year=year,
@@ -160,30 +124,29 @@ def test_staff_data():
                     non_technical_fte=non_technical_fte
                 )
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def data_config_filename(dummy_config):
-    with tempfile.NamedTemporaryFile() as f:
-        f.write(json.dumps(dummy_config).encode('utf-8'))
-        f.flush()
-        yield f.name
+def app(dummy_config):
+    app = compendium_v2._create_app_with_db(dummy_config)
+    with app.app_context():
+        db.create_all()
+    yield app
 
 
 @pytest.fixture
-def client(data_config_filename, mocked_db, mocked_survey_db):
-    os.environ['SETTINGS_FILENAME'] = data_config_filename
-    with compendium_v2.create_app().test_client() as c:
-        yield c
+def client(app):
+    return app.test_client()
 
 
 @pytest.fixture
-def test_charging_structure_data():
-    with db.session_scope() as session:
+def test_charging_structure_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("ChargingStructureTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -192,15 +155,16 @@ def test_charging_structure_data():
             if fee_type == "null":
                 fee_type = None
 
-            session.add(
+            db.session.add(
                 model.ChargingStructure(
                     nren=nren, year=year,
                     fee_type=fee_type)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_organization_data():
+def test_organization_data(app):
     def _generate_sub_org_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -220,21 +184,21 @@ def test_organization_data():
                     'name': 'org' + str(year)
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         org_data = list(_generate_org_data())
         sub_org_data = list(_generate_sub_org_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for org in org_data:
             nren = nren_dict[org["nren"]]
             year = org["year"]
             name = org["name"]
 
-            session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
+            db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
 
         for sub_org in sub_org_data:
             nren = nren_dict[sub_org["nren"]]
@@ -242,13 +206,13 @@ def test_organization_data():
             name = sub_org["name"]
             role = sub_org["role"]
 
-            session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
+            db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
 
-        session.commit()
+        db.session.commit()
 
 
 @pytest.fixture
-def test_ec_project_data():
+def test_ec_project_data(app):
     def _generate_ec_project_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -264,19 +228,19 @@ def test_ec_project_data():
                         'project': 'ec_project2',
                     }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         ec_project_data = list(_generate_ec_project_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in ec_project_data)}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for ec_project in ec_project_data:
             nren = nren_dict[ec_project["nren"]]
             year = ec_project["year"]
             project = ec_project["project"]
 
-            session.add(model.ECProject(nren=nren, year=year, project=project))
+            db.session.add(model.ECProject(nren=nren, year=year, project=project))
 
-        session.commit()
+        db.session.commit()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index ec8802ed..433def2c 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -1,5 +1,4 @@
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db, model
 from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
     StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
 
@@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority"
         ]
 
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(app, mocker, dummy_config):
     global org_data
 
     def get_rows_as_tuples(*args, **kwargs):
@@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config):
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data)
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
 
-    with db.session_scope() as session:
-        nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+    nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
+    with app.app_context():
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app)
 
-    with db.session_scope() as session:
-        budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
+    with app.app_context():
+        budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
         assert len(budgets) == 3
         assert budgets[0].nren.name.lower() == 'nren1'
         assert budgets[0].budget == 100
 
-        funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
+        funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
         assert len(funding_sources) == 3
         assert funding_sources[0].nren.name.lower() == 'nren1'
         assert funding_sources[0].client_institutions == 10
@@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config):
         assert funding_sources[2].european_funding == 30
         assert funding_sources[2].other == 30
 
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
+        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
 
         assert len(staff_data) == 3
         assert staff_data[0].nren.name.lower() == 'nren1'
@@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config):
         assert staff_data[2].permanent_fte == 30
         assert staff_data[2].subcontracted_fte == 0
 
-        _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
+        _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
 
         assert len(_org_data) == 2
         assert _org_data[0].nren.name.lower() == 'nren1'
@@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config):
         assert _org_data[1].nren.name.lower() == 'nren3'
         assert _org_data[1].organization == 'Org3'
 
-        charging_structures = session.query(model.ChargingStructure).order_by(
+        charging_structures = db.session.query(model.ChargingStructure).order_by(
             model.ChargingStructure.nren_id.asc()).all()
         assert len(charging_structures) == 3
         assert charging_structures[0].nren.name.lower() == 'nren1'
@@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config):
         assert charging_structures[2].nren.name.lower() == 'nren3'
         assert charging_structures[2].fee_type == model.FeeType.other
 
-        _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
+        _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
 
         assert len(_ec_data) == 3
         assert _ec_data[0].nren.name.lower() == 'nren2'
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index c7ebc2d9..a6dbbaa1 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
 EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
 
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(mocked_survey_db, app, mocker, dummy_config):
     mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
 
-    with db.session_scope() as session:
+    with app.app_context():
         nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app)
 
-    with db.session_scope() as session:
-        budget_count = session.query(model.BudgetEntry.year).count()
+    with app.app_context():
+        budget_count = db.session.query(model.BudgetEntry.year).count()
         assert budget_count
-        funding_source_count = session.query(model.FundingSource.year).count()
+        funding_source_count = db.session.query(model.FundingSource.year).count()
         assert funding_source_count
-        charging_structure_count = session.query(model.ChargingStructure.year).count()
+        charging_structure_count = db.session.query(model.ChargingStructure.year).count()
         assert charging_structure_count
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
+        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
 
         # data should only be saved for the NRENs we have saved in the database
         staff_data_nrens = set([staff.nren.name for staff in staff_data])
@@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config):
         assert kifu_data[5].technical_fte == 133
         assert kifu_data[5].non_technical_fte == 45
 
-        ecproject_data = session.query(model.ECProject).all()
+        ecproject_data = db.session.query(model.ECProject).all()
         # test a couple of random entries
         surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
         assert len(surf2017) == 1
@@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config):
         assert len(kifu2019) == 4
         assert kifu2019[3].project == 'SuperHeroes for Science'
 
-        parent_data = session.query(model.ParentOrganization).all()
+        parent_data = db.session.query(model.ParentOrganization).all()
         # test a random entry
         asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
         assert len(asnet2021) == 1
-- 
GitLab