From 6e62575f75375af493a76a394264a293cb80c7b8 Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 10:00:14 +0200
Subject: [PATCH 1/4] use flask-sqlalchemy for the main db

---
 compendium_v2/__init__.py                     |  12 +-
 compendium_v2/db/__init__.py                  |  50 +-
 compendium_v2/db/model.py                     |  48 +-
 compendium_v2/migrations/env.py               |   2 +-
 compendium_v2/migrations/migration_utils.py   |   8 +-
 compendium_v2/publishers/helpers.py           |  13 +-
 .../publishers/survey_publisher_2022.py       | 485 +++++++++---------
 .../publishers/survey_publisher_v1.py         | 272 +++++-----
 compendium_v2/routes/budget.py                |  31 +-
 compendium_v2/routes/charging.py              |  29 +-
 compendium_v2/routes/ec_projects.py           |  24 +-
 compendium_v2/routes/funding.py               |  27 +-
 compendium_v2/routes/organization.py          |  29 +-
 compendium_v2/routes/staff.py                 |  24 +-
 compendium_v2/survey_db/__init__.py           |   5 -
 requirements.txt                              |   1 +
 setup.py                                      |   1 +
 test/conftest.py                              | 122 ++---
 test/test_survey_publisher_2022.py            |  28 +-
 test/test_survey_publisher_v1.py              |  23 +-
 20 files changed, 545 insertions(+), 689 deletions(-)

diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py
index 622565f2..7a69bf9e 100644
--- a/compendium_v2/__init__.py
+++ b/compendium_v2/__init__.py
@@ -8,7 +8,7 @@ from flask import Flask
 from flask_cors import CORS  # for debugging
 
 from compendium_v2 import config, environment
-
+from compendium_v2.db import db
 from compendium_v2.migrations import migration_utils
 
 
@@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask:
     return app
 
 
+def _create_app_with_db(app_config) -> Flask:
+    # used by the tests and the publishers
+    app = _create_app(app_config)
+    app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI']
+    db.init_app(app)
+    return app
+
+
 def create_app() -> Flask:
     """
     overrides default settings with those found
@@ -46,7 +54,7 @@ def create_app() -> Flask:
     with open(os.environ['SETTINGS_FILENAME']) as f:
         app_config = config.load(f)
 
-    app = _create_app(app_config)
+    app = _create_app_with_db(app_config)
 
     logging.info('Flask app initialized')
 
diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py
index ce48aaa0..13719e0a 100644
--- a/compendium_v2/db/__init__.py
+++ b/compendium_v2/db/__init__.py
@@ -1,45 +1,17 @@
-import contextlib
 import logging
-from typing import Optional, Union, Callable, Iterator
 
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker, Session
+from flask_sqlalchemy import SQLAlchemy
+from sqlalchemy import MetaData
 
-logger = logging.getLogger(__name__)
-_SESSION_MAKER: Union[None, sessionmaker] = None
-
-
-@contextlib.contextmanager
-def session_scope(
-        callback_before_close: Optional[Callable] = None) -> Iterator[Session]:
-    # best practice is to keep session scope separate from data processing
-    # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html
-
-    assert _SESSION_MAKER
-    session = _SESSION_MAKER()
-    try:
-        yield session
-        session.commit()
-        if callback_before_close:
-            callback_before_close()
-    except SQLAlchemyError:
-        logger.error('caught sql layer exception, rolling back')
-        session.rollback()
-        raise  # re-raise, will be handled by main consumer
-    finally:
-        session.close()
-
-
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
-    return (f'postgresql://{db_username}:{db_password}'
-            f'@{db_hostname}:{port}/{db_name}')
 
+logger = logging.getLogger(__name__)
 
-def init_db_model(dsn):
-    global _SESSION_MAKER
+metadata_obj = MetaData(naming_convention={
+    "ix": "ix_%(column_0_label)s",
+    "uq": "uq_%(table_name)s_%(column_0_name)s",
+    "ck": "ck_%(table_name)s_%(constraint_name)s",
+    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+    "pk": "pk_%(table_name)s",
+})
 
-    # cf. https://docs.sqlalchemy.org/en
-    #        /latest/orm/extensions/automap.html
-    engine = create_engine(dsn, pool_size=10)
-    _SESSION_MAKER = sessionmaker(bind=engine)
+db = SQLAlchemy(metadata=metadata_obj)
diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py
index 67672743..2199f759 100644
--- a/compendium_v2/db/model.py
+++ b/compendium_v2/db/model.py
@@ -4,42 +4,34 @@ from enum import Enum
 from typing import Optional
 from typing_extensions import Annotated
 
-from sqlalchemy import MetaData, String
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy import String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 from sqlalchemy.schema import ForeignKey
 
+from compendium_v2.db import db
 
-logger = logging.getLogger(__name__)
-
-convention = {
-    "ix": "ix_%(column_0_label)s",
-    "uq": "uq_%(table_name)s_%(column_0_name)s",
-    "ck": "ck_%(table_name)s_%(constraint_name)s",
-    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
-    "pk": "pk_%(table_name)s",
-}
 
-metadata_obj = MetaData(naming_convention=convention)
+logger = logging.getLogger(__name__)
 
 str128 = Annotated[str, 128]
+str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)]
+str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)]
 int_pk = Annotated[int, mapped_column(primary_key=True)]
 int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)]
 
 
-class Base(DeclarativeBase):
-    metadata = metadata_obj
-    type_annotation_map = {
-        str128: String(128),
-    }
+# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet.
+# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
+# mypy: disable-error-code="name-defined"
 
 
-class NREN(Base):
+class NREN(db.Model):
     __tablename__ = 'nren'
     id: Mapped[int_pk]
     name: Mapped[str128]
 
 
-class BudgetEntry(Base):
+class BudgetEntry(db.Model):
     __tablename__ = 'budgets'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -47,7 +39,7 @@ class BudgetEntry(Base):
     budget: Mapped[Decimal]
 
 
-class FundingSource(Base):
+class FundingSource(db.Model):
     __tablename__ = 'funding_source'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -67,7 +59,7 @@ class FeeType(Enum):
     other = "other"
 
 
-class ChargingStructure(Base):
+class ChargingStructure(db.Model):
     __tablename__ = 'charging_structure'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -75,7 +67,7 @@ class ChargingStructure(Base):
     fee_type: Mapped[Optional[FeeType]]
 
 
-class NrenStaff(Base):
+class NrenStaff(db.Model):
     __tablename__ = 'nren_staff'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -86,7 +78,7 @@ class NrenStaff(Base):
     non_technical_fte: Mapped[Decimal]
 
 
-class ParentOrganization(Base):
+class ParentOrganization(db.Model):
     __tablename__ = 'parent_organization'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
@@ -94,18 +86,18 @@ class ParentOrganization(Base):
     organization: Mapped[str128]
 
 
-class SubOrganization(Base):
+class SubOrganization(db.Model):
     __tablename__ = 'sub_organization'
     nren_id: Mapped[int_pk_fkNREN]
     nren: Mapped[NREN] = relationship(lazy='joined')
     year: Mapped[int_pk]
-    organization: Mapped[str128] = mapped_column(primary_key=True)
+    organization: Mapped[str128_pk]
     role: Mapped[str128]
 
 
-class ECProject(Base):
+class ECProject(db.Model):
     __tablename__ = 'ec_project'
     nren_id: Mapped[int_pk_fkNREN]
-    nren: Mapped[NREN] = relationship(NREN, lazy='joined')
+    nren: Mapped[NREN] = relationship(lazy='joined')
     year: Mapped[int_pk]
-    project: Mapped[str] = mapped_column(String(256), primary_key=True)
+    project: Mapped[str256_pk]
diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py
index 0307be33..5ea9c8d3 100644
--- a/compendium_v2/migrations/env.py
+++ b/compendium_v2/migrations/env.py
@@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config
 from sqlalchemy import pool
 
 from alembic import context
-from compendium_v2.db.model import metadata_obj
+from compendium_v2.db import metadata_obj
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py
index 3b29b540..f25f03c9 100644
--- a/compendium_v2/migrations/migration_utils.py
+++ b/compendium_v2/migrations/migration_utils.py
@@ -1,7 +1,6 @@
 import logging
 import os
 
-from compendium_v2 import db
 from alembic.config import Config
 from alembic import command
 
@@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY):
     command.upgrade(alembic_config, 'head')
 
 
+def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
+    return (f'postgresql://{db_username}:{db_password}'
+            f'@{db_hostname}:{port}/{db_name}')
+
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
-    upgrade(db.postgresql_dsn(
+    upgrade(postgresql_dsn(
         db_username='compendium',
         db_password='compendium321',
         db_hostname='localhost',
diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py
index d95ad4e1..43d1bf02 100644
--- a/compendium_v2/publishers/helpers.py
+++ b/compendium_v2/publishers/helpers.py
@@ -1,21 +1,20 @@
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from sqlalchemy import select
+
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
 
 
 def init_db(config):
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
     dsn_survey = config['SURVEY_DATABASE_URI']
     survey_db.init_db_model(dsn_survey)
 
 
-def get_uppercase_nren_dict(session):
+def get_uppercase_nren_dict():
     """
-    :param session: db session that is used to query the known NRENs
     :return: a dictionary of all known NRENs db entities keyed on the
              uppercased name
     """
-    current_nrens = session.query(model.NREN).all()
+    current_nrens = db.session.scalars(select(model.NREN))
     nren_dict = {nren.name.upper(): nren for nren in current_nrens}
     # add aliases that are used in the source data:
     nren_dict['ASNET'] = nren_dict['ASNET-AM']
diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py
index 710c899d..ea88ee77 100644
--- a/compendium_v2/publishers/survey_publisher_2022.py
+++ b/compendium_v2/publishers/survey_publisher_2022.py
@@ -16,11 +16,12 @@ import html
 from sqlalchemy import text
 from collections import defaultdict
 
+import compendium_v2
 from compendium_v2.db.model import FeeType
 from compendium_v2.environment import setup_logging
 from compendium_v2.config import load
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
 from compendium_v2.publishers import helpers
 
 setup_logging()
@@ -133,163 +134,151 @@ def query_question(question: enum.Enum):
         return survey.execute(text(query))
 
 
-def transfer_budget():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_budget()
-        for row in rows:
+def transfer_budget(nren_dict):
+    rows = query_budget()
+    for row in rows:
+        nren_name = row[0].upper()
+        _budget = row[1]
+        try:
+            budget = float(_budget.replace('"', '').replace(',', ''))
+        except ValueError:
+            logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
+            continue
+
+        if budget > 200:
+            logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        budget_entry = model.BudgetEntry(
+            nren=nren_dict[nren_name],
+            budget=budget,
+            year=2022,
+        )
+        db.session.merge(budget_entry)
+    db.session.commit()
+
+
+def transfer_funding_sources(nren_dict):
+    sourcedata = {}
+    for source, data in query_funding_sources():
+        for row in data:
             nren_name = row[0].upper()
-            _budget = row[1]
+            _value = row[1]
             try:
-                budget = float(_budget.replace('"', '').replace(',', ''))
+                value = float(_value.replace('"', '').replace(',', ''))
             except ValueError:
-                logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
-                continue
-
-            if budget > 200:
-                logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
-
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
-                continue
+                name = source.name
+                logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
+                value = 0
 
-            budget_entry = model.BudgetEntry(
-                nren=nren_dict[nren_name],
-                budget=budget,
-                year=2022,
+            nren_info = sourcedata.setdefault(
+                nren_name,
+                {source_type: 0 for source_type in FundingSource}
             )
-            session.merge(budget_entry)
-        session.commit()
-
-
-def transfer_funding_sources():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        sourcedata = {}
-        for source, data in query_funding_sources():
-            for row in data:
-                nren_name = row[0].upper()
-                _value = row[1]
-                try:
-                    value = float(_value.replace('"', '').replace(',', ''))
-                except ValueError:
-                    name = source.name
-                    logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
-                    value = 0
-
-                nren_info = sourcedata.setdefault(
-                    nren_name,
-                    {source_type: 0 for source_type in FundingSource}
-                )
-                nren_info[source] = value
-
-        for nren_name, nren_info in sourcedata.items():
-            total = sum(nren_info.values())
-
-            if not math.isclose(total, 100, abs_tol=0.01):
-                logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+            nren_info[source] = value
+
+    for nren_name, nren_info in sourcedata.items():
+        total = sum(nren_info.values())
+
+        if not math.isclose(total, 100, abs_tol=0.01):
+            logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        funding_source = model.FundingSource(
+            nren=nren_dict[nren_name],
+            year=2022,
+            client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
+            european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
+            gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
+            commercial=nren_info[FundingSource.COMMERCIAL],
+            other=nren_info[FundingSource.OTHER],
+        )
+        db.session.merge(funding_source)
+    db.session.commit()
+
+
+def transfer_staff_data(nren_dict):
+    data = {}
+    for question in StaffQuestion:
+        rows = query_question(question)
+        for row in rows:
+            nren_name = row[0].upper()
+            _value = row[1]
+            try:
+                value = float(_value.replace('"', '').replace(',', ''))
+            except ValueError:
+                value = 0
 
             if nren_name not in nren_dict:
                 logger.info(f'{nren_name} unknown. Skipping.')
                 continue
 
-            funding_source = model.FundingSource(
-                nren=nren_dict[nren_name],
-                year=2022,
-                client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
-                european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
-                gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
-                commercial=nren_info[FundingSource.COMMERCIAL],
-                other=nren_info[FundingSource.OTHER],
-            )
-            session.merge(funding_source)
-        session.commit()
-
-
-def transfer_staff_data():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        data = {}
-        for question in StaffQuestion:
-            rows = query_question(question)
-            for row in rows:
-                nren_name = row[0].upper()
-                _value = row[1]
-                try:
-                    value = float(_value.replace('"', '').replace(',', ''))
-                except ValueError:
-                    value = 0
-
-                if nren_name not in nren_dict:
-                    logger.info(f'{nren_name} unknown. Skipping.')
-                    continue
-
-                # initialize on first use, so we don't add data for nrens with no answers
-                data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
-                    question] = value
-
-        for nren_name, nren_info in data.items():
-            if sum([nren_info[question] for question in StaffQuestion]) == 0:
-                logger.info(f'{nren_name} has no staff data. Deleting if exists.')
-                session.query(model.NrenStaff).filter(
-                    model.NrenStaff.nren_id == nren_dict[nren_name].id,
-                    model.NrenStaff.year == 2022,
-                ).delete()
-                continue
-
-            employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
-            technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
-
-            if not math.isclose(employed, technical, abs_tol=0.01):
-                logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
-                            f' ({employed} != {technical})')
-
-            staff_data = model.NrenStaff(
-                nren_id=nren_dict[nren_name].id,
-                year=2022,
-                permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
-                subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
-                technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
-                non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
-            )
-            session.merge(staff_data)
-        session.commit()
-
-
-def transfer_nren_parent_org():
+            # initialize on first use, so we don't add data for nrens with no answers
+            data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
+                question] = value
+
+    for nren_name, nren_info in data.items():
+        if sum([nren_info[question] for question in StaffQuestion]) == 0:
+            logger.info(f'{nren_name} has no staff data. Deleting if exists.')
+            db.session.query(model.NrenStaff).filter(
+                model.NrenStaff.nren_id == nren_dict[nren_name].id,
+                model.NrenStaff.year == 2022,
+            ).delete()
+            continue
+
+        employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
+        technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
+
+        if not math.isclose(employed, technical, abs_tol=0.01):
+            logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
+                        f' ({employed} != {technical})')
+
+        staff_data = model.NrenStaff(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
+            subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
+            technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
+            non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
+        )
+        db.session.merge(staff_data)
+    db.session.commit()
+
+
+def transfer_nren_parent_org(nren_dict):
     # clean up the data a bit by removing some strings
     strings_to_replace = [
         'We are affiliated to '
     ]
 
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_question(OrgQuestion.PARENT_ORG_NAME)
-        for row in rows:
-            nren_name = row[0].upper()
-            value = str(row[1]).replace('"', '')
+    rows = query_question(OrgQuestion.PARENT_ORG_NAME)
+    for row in rows:
+        nren_name = row[0].upper()
+        value = str(row[1]).replace('"', '')
 
-            for string in strings_to_replace:
-                value = value.replace(string, '')
+        for string in strings_to_replace:
+            value = value.replace(string, '')
 
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
-                continue
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
 
-            parent_org = model.ParentOrganization(
-                nren_id=nren_dict[nren_name].id,
-                year=2022,
-                organization=value,
-            )
-            session.merge(parent_org)
-        session.commit()
+        parent_org = model.ParentOrganization(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            organization=value,
+        )
+        db.session.merge(parent_org)
+    db.session.commit()
 
 
-def transfer_nren_sub_org():
+def transfer_nren_sub_org(nren_dict):
     suborg_questions = [
         (OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE),
         (OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE),
@@ -298,140 +287,134 @@ def transfer_nren_sub_org():
         (OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE)
     ]
 
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        lookup = defaultdict(list)
-
-        for name, choice, role in suborg_questions:
-            _name_rows = query_question(name)
-            _choice_rows = query_question(choice)
-            _role_rows = list(query_question(role))
-            for _name, _choice in zip(_name_rows, _choice_rows):
-                nren_name = _name[0].upper()
-                suborg_name = _name[1].replace('"', '').strip()
-                role_choice = _choice[1].replace('"', '').strip()
-
-                if nren_name not in nren_dict:
-                    logger.info(f'{nren_name} unknown. Skipping.')
-                    continue
-
-                if role_choice.lower() == 'other':
-                    for _role in _role_rows:
-                        if _role[0] == _name[0]:
-                            role = _role[1].replace('"', '').strip()
-                            break
-                else:
-                    role = role_choice
-
-                lookup[nren_name].append((suborg_name, role))
-
-        for nren_name, suborgs in lookup.items():
-            for suborg_name, role in suborgs:
-                suborg = model.SubOrganization(
-                    nren_id=nren_dict[nren_name].id,
-                    year=2022,
-                    organization=suborg_name,
-                    role=role,
-                )
-                session.merge(suborg)
-        session.commit()
-
-
-def transfer_charging_structure():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        rows = query_question(ChargingStructure.charging_structure)
-        for row in rows:
-            nren_name = row[0].upper()
-            value = row[1].replace('"', '').strip()
+    lookup = defaultdict(list)
+
+    for name, choice, role in suborg_questions:
+        _name_rows = query_question(name)
+        _choice_rows = query_question(choice)
+        _role_rows = list(query_question(role))
+        for _name, _choice in zip(_name_rows, _choice_rows):
+            nren_name = _name[0].upper()
+            suborg_name = _name[1].replace('"', '').strip()
+            role_choice = _choice[1].replace('"', '').strip()
 
             if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+                logger.info(f'{nren_name} unknown. Skipping.')
                 continue
 
-            if "do not charge" in value:
-                charging_structure = FeeType.no_charge
-            elif "combination" in value:
-                charging_structure = FeeType.combination
-            elif "flat" in value:
-                charging_structure = FeeType.flat_fee
-            elif "usage-based" in value:
-                charging_structure = FeeType.usage_based_fee
-            elif "Other" in value:
-                charging_structure = FeeType.other
+            if role_choice.lower() == 'other':
+                for _role in _role_rows:
+                    if _role[0] == _name[0]:
+                        role = _role[1].replace('"', '').strip()
+                        break
             else:
-                charging_structure = None
+                role = role_choice
+
+            lookup[nren_name].append((suborg_name, role))
 
-            charging_structure = model.ChargingStructure(
+    for nren_name, suborgs in lookup.items():
+        for suborg_name, role in suborgs:
+            suborg = model.SubOrganization(
                 nren_id=nren_dict[nren_name].id,
                 year=2022,
-                fee_type=charging_structure,
+                organization=suborg_name,
+                role=role,
             )
-            session.merge(charging_structure)
-        session.commit()
-
-
-def transfer_ec_projects():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # delete all existing EC projects, in case something changed
-        session.query(model.ECProject).filter(
-            model.ECProject.year == 2022,
-        ).delete()
-
-        rows = query_question(ECQuestion.EC_PROJECT)
-        for row in rows:
-            nren_name = row[0].upper()
-
-            if nren_name not in nren_dict:
-                logger.info(f'{nren_name} unknown. Skipping.')
+            db.session.merge(suborg)
+    db.session.commit()
+
+
+def transfer_charging_structure(nren_dict):
+    rows = query_question(ChargingStructure.charging_structure)
+    for row in rows:
+        nren_name = row[0].upper()
+        value = row[1].replace('"', '').strip()
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+            continue
+
+        if "do not charge" in value:
+            charging_structure = FeeType.no_charge
+        elif "combination" in value:
+            charging_structure = FeeType.combination
+        elif "flat" in value:
+            charging_structure = FeeType.flat_fee
+        elif "usage-based" in value:
+            charging_structure = FeeType.usage_based_fee
+        elif "Other" in value:
+            charging_structure = FeeType.other
+        else:
+            charging_structure = None
+
+        charging_structure = model.ChargingStructure(
+            nren_id=nren_dict[nren_name].id,
+            year=2022,
+            fee_type=charging_structure,
+        )
+        db.session.merge(charging_structure)
+    db.session.commit()
+
+
+def transfer_ec_projects(nren_dict):
+    # delete all existing EC projects, in case something changed
+    db.session.query(model.ECProject).filter(
+        model.ECProject.year == 2022,
+    ).delete()
+
+    rows = query_question(ECQuestion.EC_PROJECT)
+    for row in rows:
+        nren_name = row[0].upper()
+
+        if nren_name not in nren_dict:
+            logger.info(f'{nren_name} unknown. Skipping.')
+            continue
+
+        try:
+            value = json.loads(row[1])
+        except json.decoder.JSONDecodeError:
+            logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
+            continue
+
+        for val in value:
+            if not val:
+                logger.info(f'Invalid EC project value for {nren_name}: {val}.')
                 continue
 
-            try:
-                value = json.loads(row[1])
-            except json.decoder.JSONDecodeError:
-                logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
-                continue
-
-            for val in value:
-                if not val:
-                    logger.info(f'Invalid EC project value for {nren_name}: {val}.')
-                    continue
-
-                # strip html entities/NBSP from val
-                val = html.unescape(val).replace('\xa0', ' ')
+            # strip html entities/NBSP from val
+            val = html.unescape(val).replace('\xa0', ' ')
 
-                # some answers include contract numbers, which we don't want here
-                val = val.split('(contract n')[0]
+            # some answers include contract numbers, which we don't want here
+            val = val.split('(contract n')[0]
 
-                ec_project = model.ECProject(
-                    nren_id=nren_dict[nren_name].id,
-                    year=2022,
-                    project=str(val).strip()
-                )
-                session.add(ec_project)
-        session.commit()
+            ec_project = model.ECProject(
+                nren_id=nren_dict[nren_name].id,
+                year=2022,
+                project=str(val).strip()
+            )
+            db.session.add(ec_project)
+    db.session.commit()
 
 
-def _cli(config):
+def _cli(config, app):
     helpers.init_db(config)
-    transfer_budget()
-    transfer_funding_sources()
-    transfer_staff_data()
-    transfer_nren_parent_org()
-    transfer_nren_sub_org()
-    transfer_charging_structure()
-    transfer_ec_projects()
+    with app.app_context():
+        nren_dict = helpers.get_uppercase_nren_dict()
+        transfer_budget(nren_dict)
+        transfer_funding_sources(nren_dict)
+        transfer_staff_data(nren_dict)
+        transfer_nren_parent_org(nren_dict)
+        transfer_nren_sub_org(nren_dict)
+        transfer_charging_structure(nren_dict)
+        transfer_ec_projects(nren_dict)
 
 
 @click.command()
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
-    _cli(app_config)
+    app = compendium_v2._create_app_with_db(app_config)
+    _cli(app_config, app)
 
 
 if __name__ == "__main__":
diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py
index ba0e4e2a..5f4cd8b2 100644
--- a/compendium_v2/publishers/survey_publisher_v1.py
+++ b/compendium_v2/publishers/survey_publisher_v1.py
@@ -11,11 +11,12 @@ import logging
 import math
 import click
 
+import compendium_v2
 from compendium_v2.environment import setup_logging
-from compendium_v2 import db, survey_db
+from compendium_v2 import survey_db
 from compendium_v2.background_task import parse_excel_data
 from compendium_v2.config import load
-from compendium_v2.db import model
+from compendium_v2.db import db, model
 from compendium_v2.survey_db import model as survey_model
 from compendium_v2.publishers import helpers
 
@@ -24,11 +25,8 @@ setup_logging()
 logger = logging.getLogger('survey-publisher-v1')
 
 
-def db_budget_migration():
-    with survey_db.session_scope() as survey_session, \
-            db.session_scope() as session:
-
-        nren_dict = helpers.get_uppercase_nren_dict(session)
+def db_budget_migration(nren_dict):
+    with survey_db.session_scope() as survey_session:
 
         # move data from Survey DB budget table
         data = survey_session.query(survey_model.Nrens)
@@ -49,7 +47,7 @@ def db_budget_migration():
                     budget=float(budget.budget),
                     year=year
                 )
-                session.merge(budget_entry)
+                db.session.merge(budget_entry)
 
         # Import the data from excel sheet to database
         exceldata = parse_excel_data.fetch_budget_excel_data()
@@ -63,165 +61,153 @@ def db_budget_migration():
                 logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
 
             budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year)
-            session.merge(budget_entry)
-        session.commit()
-
-
-def db_funding_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # Import the data to database
-        data = parse_excel_data.fetch_funding_excel_data()
-
-        for (abbrev, year, client_institution,
-             european_funding,
-             gov_public_bodies,
-             commercial, other) in data:
-
-            _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
-            total = sum(_data)
-            if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
-                logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
-
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
-
-            budget_entry = model.FundingSource(
-                nren=nren_dict[abbrev],
-                year=year,
-                client_institutions=client_institution,
-                european_funding=european_funding,
-                gov_public_bodies=gov_public_bodies,
-                commercial=commercial,
-                other=other)
-            session.merge(budget_entry)
-        session.commit()
-
-
-def db_charging_structure_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        # Import the data to database
-        data = parse_excel_data.fetch_charging_structure_excel_data()
-
-        for (abbrev, year, charging_structure) in data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
-
-            charging_structure_entry = model.ChargingStructure(
-                nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
-            session.merge(charging_structure_entry)
-        session.commit()
-
-
-def db_staffing_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        staff_data = parse_excel_data.fetch_staffing_excel_data()
-
-        nren_staff_map = {}
-        for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping staff data.')
-                continue
-
-            nren = nren_dict[abbrev]
+            db.session.merge(budget_entry)
+        db.session.commit()
+
+
+def db_funding_migration(nren_dict):
+    # Import the data to database
+    data = parse_excel_data.fetch_funding_excel_data()
+
+    for (abbrev, year, client_institution,
+            european_funding,
+            gov_public_bodies,
+            commercial, other) in data:
+
+        _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
+        total = sum(_data)
+        if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
+            logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
+
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
+
+        budget_entry = model.FundingSource(
+            nren=nren_dict[abbrev],
+            year=year,
+            client_institutions=client_institution,
+            european_funding=european_funding,
+            gov_public_bodies=gov_public_bodies,
+            commercial=commercial,
+            other=other)
+        db.session.merge(budget_entry)
+    db.session.commit()
+
+
+def db_charging_structure_migration(nren_dict):
+    # Import the data to database
+    data = parse_excel_data.fetch_charging_structure_excel_data()
+
+    for (abbrev, year, charging_structure) in data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
+
+        charging_structure_entry = model.ChargingStructure(
+            nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
+        db.session.merge(charging_structure_entry)
+    db.session.commit()
+
+
+def db_staffing_migration(nren_dict):
+    staff_data = parse_excel_data.fetch_staffing_excel_data()
+
+    nren_staff_map = {}
+    for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping staff data.')
+            continue
+
+        nren = nren_dict[abbrev]
+        nren_staff_map[(nren.id, year)] = model.NrenStaff(
+            nren=nren,
+            nren_id=nren.id,
+            year=year,
+            permanent_fte=permanent_fte,
+            subcontracted_fte=subcontracted_fte,
+            technical_fte=0,
+            non_technical_fte=0
+        )
+
+    function_data = parse_excel_data.fetch_staff_function_excel_data()
+    for (abbrev, year, technical_fte, non_technical_fte) in function_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping staff function data.')
+            continue
+
+        nren = nren_dict[abbrev]
+        if (nren.id, year) in nren_staff_map:
+            nren_staff_map[(nren.id, year)].technical_fte = technical_fte
+            nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
+        else:
             nren_staff_map[(nren.id, year)] = model.NrenStaff(
                 nren=nren,
                 nren_id=nren.id,
                 year=year,
-                permanent_fte=permanent_fte,
-                subcontracted_fte=subcontracted_fte,
-                technical_fte=0,
-                non_technical_fte=0
+                permanent_fte=0,
+                subcontracted_fte=0,
+                technical_fte=technical_fte,
+                non_technical_fte=non_technical_fte
             )
 
-        function_data = parse_excel_data.fetch_staff_function_excel_data()
-        for (abbrev, year, technical_fte, non_technical_fte) in function_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping staff function data.')
-                continue
-
-            nren = nren_dict[abbrev]
-            if (nren.id, year) in nren_staff_map:
-                nren_staff_map[(nren.id, year)].technical_fte = technical_fte
-                nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
-            else:
-                nren_staff_map[(nren.id, year)] = model.NrenStaff(
-                    nren=nren,
-                    nren_id=nren.id,
-                    year=year,
-                    permanent_fte=0,
-                    subcontracted_fte=0,
-                    technical_fte=technical_fte,
-                    non_technical_fte=non_technical_fte
-                )
-
-        for nren_staff_model in nren_staff_map.values():
-            employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
-            technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
-            if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
-                logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
-                               f' FTE do not equal across employed/technical categories ({employed} != {technical})')
-
-            session.merge(nren_staff_model)
+    for nren_staff_model in nren_staff_map.values():
+        employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
+        technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
+        if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
+            logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
+                           f' FTE do not equal across employed/technical categories ({employed} != {technical})')
 
-        session.commit()
+        db.session.merge(nren_staff_model)
 
+    db.session.commit()
 
-def db_ecprojects_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
-
-        ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
-        for (abbrev, year, project) in ecproject_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
 
-            nren = nren_dict[abbrev]
-            ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
-            session.merge(ecproject_entry)
-        session.commit()
+def db_ecprojects_migration(nren_dict):
+    ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
+    for (abbrev, year, project) in ecproject_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
 
+        nren = nren_dict[abbrev]
+        ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
+        db.session.merge(ecproject_entry)
+    db.session.commit()
 
-def db_organizations_migration():
-    with db.session_scope() as session:
-        nren_dict = helpers.get_uppercase_nren_dict(session)
 
-        organization_data = parse_excel_data.fetch_organization_excel_data()
-        for (abbrev, year, org) in organization_data:
-            if abbrev not in nren_dict:
-                logger.warning(f'{abbrev} unknown. Skipping.')
-                continue
+def db_organizations_migration(nren_dict):
+    organization_data = parse_excel_data.fetch_organization_excel_data()
+    for (abbrev, year, org) in organization_data:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
 
-            nren = nren_dict[abbrev]
-            org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
-            session.merge(org_entry)
-        session.commit()
+        nren = nren_dict[abbrev]
+        org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
+        db.session.merge(org_entry)
+    db.session.commit()
 
 
-def _cli(config):
+def _cli(config, app):
     helpers.init_db(config)
-    db_budget_migration()
-    db_funding_migration()
-    db_charging_structure_migration()
-    db_staffing_migration()
-    db_ecprojects_migration()
-    db_organizations_migration()
+    with app.app_context():
+        nren_dict = helpers.get_uppercase_nren_dict()
+        db_budget_migration(nren_dict)
+        db_funding_migration(nren_dict)
+        db_charging_structure_migration(nren_dict)
+        db_staffing_migration(nren_dict)
+        db_ecprojects_migration(nren_dict)
+        db_organizations_migration(nren_dict)
 
 
 @click.command()
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
+    app = compendium_v2._create_app_with_db(app_config)
     print("survey-publisher-v1 starting")
-    _cli(app_config)
+    _cli(app_config, app)
 
 
 if __name__ == "__main__":
diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py
index 763b3af3..1f4ff850 100644
--- a/compendium_v2/routes/budget.py
+++ b/compendium_v2/routes/budget.py
@@ -1,28 +1,17 @@
 import logging
 from typing import Any
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import BudgetEntry
 from compendium_v2.routes import common
 
-routes = Blueprint('budget', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('budget', __name__)
 logger = logging.getLogger(__name__)
 
-col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061',
-           '#bd7ebe', '#ffb55a', '#ffee65',
-           '#beb9db', '#fdcce5', '#8bd3c7']
-
 BUDGET_RESPONSE_SCHEMA = {
     '$schema': 'http://json-schema.org/draft-07/schema#',
 
@@ -58,15 +47,15 @@ def budget_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.BudgetEntry):
+    def _extract_data(entry: BudgetEntry):
         return {
             'NREN': entry.nren.name,
             'BUDGET': float(entry.budget),
             'BUDGET_YEAR': entry.year,
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                          for entry in session.query(model.BudgetEntry)],
-                         key=lambda d: (d['BUDGET_YEAR'], d['NREN']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))],
+        key=lambda d: (d['BUDGET_YEAR'], d['NREN'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py
index 0b2c2384..57e8114f 100644
--- a/compendium_v2/routes/charging.py
+++ b/compendium_v2/routes/charging.py
@@ -1,21 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('charging', __name__)
-
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ChargingStructure
+from compendium_v2.routes import common
 
 
+routes = Blueprint('charging', __name__)
 logger = logging.getLogger(__name__)
 
 CHARGING_STRUCTURE_RESPONSE_SCHEMA = {
@@ -53,16 +47,15 @@ def charging_structure_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.ChargingStructure):
+    def _extract_data(entry: ChargingStructure):
         return {
             'NREN': entry.nren.name,
             'YEAR': int(entry.year),
             'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None,
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                          for entry in session.query(model.ChargingStructure)
-                         .all()],
-                         key=lambda d: (d['NREN'], d['YEAR']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))],
+        key=lambda d: (d['NREN'], d['YEAR'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py
index 7114718d..b58d8931 100644
--- a/compendium_v2/routes/ec_projects.py
+++ b/compendium_v2/routes/ec_projects.py
@@ -1,22 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('ec-projects', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ECProject
+from compendium_v2.routes import common
 
 
+routes = Blueprint('ec-projects', __name__)
 logger = logging.getLogger(__name__)
 
 EC_PROJECTS_RESPONSE_SCHEMA = {
@@ -55,13 +48,12 @@ def ec_projects_view() -> Any:
     :return:
     """
 
-    def _extract_project(entry: model.ECProject):
+    def _extract_project(entry: ECProject):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
             'project': entry.project
         }
 
-    with db.session_scope() as session:
-        result = [_extract_project(project) for project in session.query(model.ECProject)]
+    result = [_extract_project(project) for project in db.session.scalars(select(ECProject))]
     return jsonify(result)
diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py
index ed0e26c8..c02bf136 100644
--- a/compendium_v2/routes/funding.py
+++ b/compendium_v2/routes/funding.py
@@ -1,22 +1,15 @@
 import logging
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
 from compendium_v2.routes import common
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import FundingSource
 from typing import Any
 
-routes = Blueprint('funding', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('funding', __name__)
 logger = logging.getLogger(__name__)
 
 FUNDING_RESPONSE_SCHEMA = {
@@ -59,7 +52,7 @@ def funding_source_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.FundingSource):
+    def _extract_data(entry: FundingSource):
         return {
             'NREN': entry.nren.name,
             'YEAR': entry.year,
@@ -70,8 +63,8 @@ def funding_source_view() -> Any:
             'OTHER': float(entry.other)
         }
 
-    with db.session_scope() as session:
-        entries = sorted([_extract_data(entry)
-                         for entry in session.query(model.FundingSource)],
-                         key=lambda d: (d['NREN'], d['YEAR']))
+    entries = sorted(
+        [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))],
+        key=lambda d: (d['NREN'], d['YEAR'])
+    )
     return jsonify(entries)
diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py
index 61a43354..8e8ebc8d 100644
--- a/compendium_v2/routes/organization.py
+++ b/compendium_v2/routes/organization.py
@@ -1,22 +1,15 @@
 import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('organization', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ParentOrganization, SubOrganization
+from compendium_v2.routes import common
 
 
+routes = Blueprint('organization', __name__)
 logger = logging.getLogger(__name__)
 
 ORGANIZATION_RESPONSE_SCHEMA = {
@@ -71,15 +64,14 @@ def parent_organization_view() -> Any:
     :return:
     """
 
-    def _extract_parent(entry: model.ParentOrganization):
+    def _extract_parent(entry: ParentOrganization):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
             'name': entry.organization
         }
 
-    with db.session_scope() as session:
-        result = [_extract_parent(org) for org in session.query(model.ParentOrganization)]
+    result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))]
     return jsonify(result)
 
 
@@ -98,7 +90,7 @@ def sub_organization_view() -> Any:
     :return:
     """
 
-    def _extract_sub(entry: model.SubOrganization):
+    def _extract_sub(entry: SubOrganization):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
@@ -106,6 +98,5 @@ def sub_organization_view() -> Any:
             'role': entry.role
         }
 
-    with db.session_scope() as session:
-        result = [_extract_sub(org) for org in session.query(model.SubOrganization)]
+    result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))]
     return jsonify(result)
diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py
index 73e79c48..b8478266 100644
--- a/compendium_v2/routes/staff.py
+++ b/compendium_v2/routes/staff.py
@@ -1,22 +1,15 @@
 import logging
 
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
 
-from compendium_v2 import db
+from compendium_v2.db import db
+from compendium_v2.db.model import NREN, NrenStaff
 from compendium_v2.routes import common
-from compendium_v2.db import model
 from typing import Any
 
-routes = Blueprint('staff', __name__)
-
-
-@routes.before_request
-def before_request():
-    config = current_app.config['CONFIG_PARAMS']
-    dsn_prn = config['SQLALCHEMY_DATABASE_URI']
-    db.init_db_model(dsn_prn)
-
 
+routes = Blueprint('staff', __name__)
 logger = logging.getLogger(__name__)
 
 STAFF_RESPONSE_SCHEMA = {
@@ -57,7 +50,7 @@ def staff_view() -> Any:
     :return:
     """
 
-    def _extract_data(entry: model.NrenStaff):
+    def _extract_data(entry: NrenStaff):
         return {
             'nren': entry.nren.name,
             'year': entry.year,
@@ -67,7 +60,6 @@ def staff_view() -> Any:
             'non_technical_fte': float(entry.non_technical_fte)
         }
 
-    with db.session_scope() as session:
-        entries = [_extract_data(entry) for entry in session.query(
-            model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())]
+    entries = [_extract_data(entry) for entry in db.session.scalars(
+        select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))]
     return jsonify(entries)
diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py
index ce48aaa0..1550ddcb 100644
--- a/compendium_v2/survey_db/__init__.py
+++ b/compendium_v2/survey_db/__init__.py
@@ -31,11 +31,6 @@ def session_scope(
         session.close()
 
 
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
-    return (f'postgresql://{db_username}:{db_password}'
-            f'@{db_hostname}:{port}/{db_name}')
-
-
 def init_db_model(dsn):
     global _SESSION_MAKER
 
diff --git a/requirements.txt b/requirements.txt
index 12ec2076..98804e02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ click~=8.1
 jsonschema~=4.17
 flask~=2.2
 flask-cors~=3.0
+flask-sqlalchemy~=3.0
 openpyxl~=3.1
 psycopg2-binary~=2.9
 SQLAlchemy~=2.0
diff --git a/setup.py b/setup.py
index 213288e5..d051ddd7 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ setup(
         'jsonschema~=4.17',
         'flask~=2.2',
         'flask-cors~=3.0',
+        'flask-sqlalchemy~=3.0',
         'openpyxl~=3.1',
         'psycopg2-binary~=2.9',
         'SQLAlchemy~=2.0',
diff --git a/test/conftest.py b/test/conftest.py
index 79186212..ba6b9a43 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,20 +1,15 @@
-import json
+import csv
 import os
-import tempfile
-import random
-
 import pytest
-import compendium_v2
-from compendium_v2 import db
-from compendium_v2.db import model
-from compendium_v2 import survey_db
-from compendium_v2.survey_db import model as survey_model
+import random
 
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import StaticPool
 
-import csv
+import compendium_v2
+from compendium_v2.db import db, model
+from compendium_v2.survey_db import model as survey_model
 
 
 def _test_data_csv(filename):
@@ -43,61 +38,29 @@ def mocked_survey_db(mocker):
 
 
 @pytest.fixture
-def mocked_db(mocker):
-    # cf. https://stackoverflow.com/a/33057675
-    engine = create_engine(
-        'sqlite://',
-        connect_args={'check_same_thread': False},
-        poolclass=StaticPool,
-        echo=False)
-    model.Base.metadata.create_all(engine)
-    mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine))
-    mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None)
-    mocker.patch('compendium_v2.migrate_database', lambda config: None)
-
-
-@pytest.fixture
-def test_budget_data():
-    with db.session_scope() as session:
+def test_budget_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("BudgetTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
             budget = row["budget"]
             year = row["year"]
 
-            session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
-
-    with survey_db.session_scope() as session:
-        data = _test_data_csv("BudgetTestData.csv")
-        nrens = set()
-        budgets_data = []
-        for row in data:
-            nren = row["nren"]
-            budget = row["budget"]
-            year = row["year"]
-            country_code = row["nren"]
-
-            nrens.add(nren)
-
-            budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code))
-
-        for nren in nrens:
-            session.add(survey_model.Nrens(abbreviation=nren, country_code=nren))
-
-        session.add_all(budgets_data)
+            db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
+        db.session.commit()
 
 
 @pytest.fixture
-def test_funding_source_data():
-    with db.session_scope() as session:
+def test_funding_source_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("FundingSourceTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -108,7 +71,7 @@ def test_funding_source_data():
             commercial = row["commercial"]
             other = row["other"]
 
-            session.add(
+            db.session.add(
                 model.FundingSource(
                     nren=nren, year=year,
                     client_institutions=client,
@@ -117,10 +80,11 @@ def test_funding_source_data():
                     commercial=commercial,
                     other=other)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_staff_data():
+def test_staff_data(app):
 
     # generator of random test data for 5 years and 100 nrens
     def _generate_rows():
@@ -135,12 +99,12 @@ def test_staff_data():
                     "non_technical_fte": random.randint(0, 100)
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         data = list(_generate_rows())
 
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -150,7 +114,7 @@ def test_staff_data():
             technical_fte = row["technical_fte"]
             non_technical_fte = row["non_technical_fte"]
 
-            session.add(
+            db.session.add(
                 model.NrenStaff(
                     nren=nren,
                     year=year,
@@ -160,30 +124,29 @@ def test_staff_data():
                     non_technical_fte=non_technical_fte
                 )
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def data_config_filename(dummy_config):
-    with tempfile.NamedTemporaryFile() as f:
-        f.write(json.dumps(dummy_config).encode('utf-8'))
-        f.flush()
-        yield f.name
+def app(dummy_config):
+    app = compendium_v2._create_app_with_db(dummy_config)
+    with app.app_context():
+        db.create_all()
+    yield app
 
 
 @pytest.fixture
-def client(data_config_filename, mocked_db, mocked_survey_db):
-    os.environ['SETTINGS_FILENAME'] = data_config_filename
-    with compendium_v2.create_app().test_client() as c:
-        yield c
+def client(app):
+    return app.test_client()
 
 
 @pytest.fixture
-def test_charging_structure_data():
-    with db.session_scope() as session:
+def test_charging_structure_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("ChargingStructureTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -192,15 +155,16 @@ def test_charging_structure_data():
             if fee_type == "null":
                 fee_type = None
 
-            session.add(
+            db.session.add(
                 model.ChargingStructure(
                     nren=nren, year=year,
                     fee_type=fee_type)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_organization_data():
+def test_organization_data(app):
     def _generate_sub_org_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -220,21 +184,21 @@ def test_organization_data():
                     'name': 'org' + str(year)
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         org_data = list(_generate_org_data())
         sub_org_data = list(_generate_sub_org_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for org in org_data:
             nren = nren_dict[org["nren"]]
             year = org["year"]
             name = org["name"]
 
-            session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
+            db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
 
         for sub_org in sub_org_data:
             nren = nren_dict[sub_org["nren"]]
@@ -242,13 +206,13 @@ def test_organization_data():
             name = sub_org["name"]
             role = sub_org["role"]
 
-            session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
+            db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
 
-        session.commit()
+        db.session.commit()
 
 
 @pytest.fixture
-def test_ec_project_data():
+def test_ec_project_data(app):
     def _generate_ec_project_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -264,19 +228,19 @@ def test_ec_project_data():
                         'project': 'ec_project2',
                     }
 
-    with db.session_scope() as session:
+    with app.app_context():
 
         ec_project_data = list(_generate_ec_project_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in ec_project_data)}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for ec_project in ec_project_data:
             nren = nren_dict[ec_project["nren"]]
             year = ec_project["year"]
             project = ec_project["project"]
 
-            session.add(model.ECProject(nren=nren, year=year, project=project))
+            db.session.add(model.ECProject(nren=nren, year=year, project=project))
 
-        session.commit()
+        db.session.commit()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index ec8802ed..433def2c 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -1,5 +1,4 @@
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db, model
 from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
     StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
 
@@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority"
         ]
 
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(app, mocker, dummy_config):
     global org_data
 
     def get_rows_as_tuples(*args, **kwargs):
@@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config):
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data)
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
 
-    with db.session_scope() as session:
-        nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+    nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
+    with app.app_context():
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app)
 
-    with db.session_scope() as session:
-        budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
+    with app.app_context():
+        budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
         assert len(budgets) == 3
         assert budgets[0].nren.name.lower() == 'nren1'
         assert budgets[0].budget == 100
 
-        funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
+        funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
         assert len(funding_sources) == 3
         assert funding_sources[0].nren.name.lower() == 'nren1'
         assert funding_sources[0].client_institutions == 10
@@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config):
         assert funding_sources[2].european_funding == 30
         assert funding_sources[2].other == 30
 
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
+        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
 
         assert len(staff_data) == 3
         assert staff_data[0].nren.name.lower() == 'nren1'
@@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config):
         assert staff_data[2].permanent_fte == 30
         assert staff_data[2].subcontracted_fte == 0
 
-        _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
+        _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
 
         assert len(_org_data) == 2
         assert _org_data[0].nren.name.lower() == 'nren1'
@@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config):
         assert _org_data[1].nren.name.lower() == 'nren3'
         assert _org_data[1].organization == 'Org3'
 
-        charging_structures = session.query(model.ChargingStructure).order_by(
+        charging_structures = db.session.query(model.ChargingStructure).order_by(
             model.ChargingStructure.nren_id.asc()).all()
         assert len(charging_structures) == 3
         assert charging_structures[0].nren.name.lower() == 'nren1'
@@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config):
         assert charging_structures[2].nren.name.lower() == 'nren3'
         assert charging_structures[2].fee_type == model.FeeType.other
 
-        _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
+        _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
 
         assert len(_ec_data) == 3
         assert _ec_data[0].nren.name.lower() == 'nren2'
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index c7ebc2d9..a6dbbaa1 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
 EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
 
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(mocked_survey_db, app, mocker, dummy_config):
     mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
 
-    with db.session_scope() as session:
+    with app.app_context():
         nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app)
 
-    with db.session_scope() as session:
-        budget_count = session.query(model.BudgetEntry.year).count()
+    with app.app_context():
+        budget_count = db.session.query(model.BudgetEntry.year).count()
         assert budget_count
-        funding_source_count = session.query(model.FundingSource.year).count()
+        funding_source_count = db.session.query(model.FundingSource.year).count()
         assert funding_source_count
-        charging_structure_count = session.query(model.ChargingStructure.year).count()
+        charging_structure_count = db.session.query(model.ChargingStructure.year).count()
         assert charging_structure_count
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
+        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
 
         # data should only be saved for the NRENs we have saved in the database
         staff_data_nrens = set([staff.nren.name for staff in staff_data])
@@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config):
         assert kifu_data[5].technical_fte == 133
         assert kifu_data[5].non_technical_fte == 45
 
-        ecproject_data = session.query(model.ECProject).all()
+        ecproject_data = db.session.query(model.ECProject).all()
         # test a couple of random entries
         surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
         assert len(surf2017) == 1
@@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config):
         assert len(kifu2019) == 4
         assert kifu2019[3].project == 'SuperHeroes for Science'
 
-        parent_data = session.query(model.ParentOrganization).all()
+        parent_data = db.session.query(model.ParentOrganization).all()
         # test a random entry
         asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
         assert len(asnet2021) == 1
-- 
GitLab


From 14411075c87043bcd91d2ced0ba1efdd5f150ade Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 15:50:41 +0200
Subject: [PATCH 2/4] use flask-migrate for the migrations

---
 README.md                                   |  6 +-
 compendium_v2/__init__.py                   | 13 ++---
 compendium_v2/alembic.ini                   | 10 ----
 compendium_v2/migrations/README             |  1 +
 compendium_v2/migrations/__init__.py        |  0
 compendium_v2/migrations/alembic.ini        | 50 ++++++++++++++++
 compendium_v2/migrations/env.py             | 63 ++++++++++++++++-----
 compendium_v2/migrations/migration_utils.py | 41 --------------
 requirements.txt                            |  1 +
 setup.py                                    |  1 +
 10 files changed, 111 insertions(+), 75 deletions(-)
 delete mode 100644 compendium_v2/alembic.ini
 create mode 100644 compendium_v2/migrations/README
 delete mode 100644 compendium_v2/migrations/__init__.py
 create mode 100644 compendium_v2/migrations/alembic.ini
 delete mode 100644 compendium_v2/migrations/migration_utils.py

diff --git a/README.md b/README.md
index 81b1b0b6..dd9a7dce 100644
--- a/README.md
+++ b/README.md
@@ -63,11 +63,13 @@ survey-publisher-2022
 ## Creating a db migration after editing the sqlalchemy models
 
 ```bash
-cd compendium_v2
-alembic revision --autogenerate -m "description"
+flask db migrate -m "description"
 ```
 
 Then go to the created migration file to make any necessary additions, for example to migrate data.
 Also see https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect
+Flask-migrate sets `compare_type=True` by default.
 
 Note that starting the application applies all upgrades.
+This also happens when running `flask db` commands such as `flask db downgrade`,
+so if you want to downgrade 2 or more versions you need to do so in one command, eg by specifying the revision number.
diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py
index 7a69bf9e..4717a486 100644
--- a/compendium_v2/__init__.py
+++ b/compendium_v2/__init__.py
@@ -6,15 +6,11 @@ import os
 
 from flask import Flask
 from flask_cors import CORS  # for debugging
+# the currently available stubs for flask_migrate are old (they depend on sqlalchemy 1.4 types)
+from flask_migrate import Migrate, upgrade  # type: ignore
 
 from compendium_v2 import config, environment
 from compendium_v2.db import db
-from compendium_v2.migrations import migration_utils
-
-
-def migrate_database(config: dict) -> None:
-    dsn = config['SQLALCHEMY_DATABASE_URI']
-    migration_utils.upgrade(dsn)
 
 
 def _create_app(app_config) -> Flask:
@@ -56,11 +52,14 @@ def create_app() -> Flask:
 
     app = _create_app_with_db(app_config)
 
+    Migrate(app, db, directory=os.path.join(os.path.dirname(__file__), 'migrations'))
+
     logging.info('Flask app initialized')
 
     environment.setup_logging()
 
     # run migrations on startup
-    migrate_database(app_config)
+    with app.app_context():
+        upgrade()
 
     return app
diff --git a/compendium_v2/alembic.ini b/compendium_v2/alembic.ini
deleted file mode 100644
index 2145863b..00000000
--- a/compendium_v2/alembic.ini
+++ /dev/null
@@ -1,10 +0,0 @@
-# A generic, single database configuration.
-
-# only needed for generating new revision scripts
-[alembic]
-# make sure the right line is un / commented depending on which schema you want
-# a migration for
-script_location = migrations
-# script_location = cachedb_migrations
-# change this to run migrations from the command line
-sqlalchemy.url = postgresql://compendium:compendium321@localhost:65000/compendium
diff --git a/compendium_v2/migrations/README b/compendium_v2/migrations/README
new file mode 100644
index 00000000..0e048441
--- /dev/null
+++ b/compendium_v2/migrations/README
@@ -0,0 +1 @@
+Single-database configuration for Flask.
diff --git a/compendium_v2/migrations/__init__.py b/compendium_v2/migrations/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/compendium_v2/migrations/alembic.ini b/compendium_v2/migrations/alembic.ini
new file mode 100644
index 00000000..ec9d45c2
--- /dev/null
+++ b/compendium_v2/migrations/alembic.ini
@@ -0,0 +1,50 @@
+# A generic, single database configuration.
+
+[alembic]
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic,flask_migrate
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[logger_flask_migrate]
+level = INFO
+handlers =
+qualname = flask_migrate
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py
index 5ea9c8d3..e2408681 100644
--- a/compendium_v2/migrations/env.py
+++ b/compendium_v2/migrations/env.py
@@ -1,10 +1,9 @@
 import logging
+from logging.config import fileConfig
 
-from sqlalchemy import engine_from_config
-from sqlalchemy import pool
+from flask import current_app
 
 from alembic import context
-from compendium_v2.db import metadata_obj
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -12,13 +11,34 @@ config = context.config
 
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
-logging.basicConfig(level=logging.INFO)
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+logger = logging.getLogger('alembic.env')
+
+
+def get_engine():
+    try:
+        # this works with Flask-SQLAlchemy<3 and Alchemical
+        return current_app.extensions['migrate'].db.get_engine()
+    except TypeError:
+        # this works with Flask-SQLAlchemy>=3
+        return current_app.extensions['migrate'].db.engine
+
+
+def get_engine_url():
+    try:
+        return get_engine().url.render_as_string(hide_password=False).replace(
+            '%', '%%')
+    except AttributeError:
+        return str(get_engine().url).replace('%', '%%')
+
 
 # add your model's MetaData object here
 # for 'autogenerate' support
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
-target_metadata = metadata_obj
+config.set_main_option('sqlalchemy.url', get_engine_url())
+target_db = current_app.extensions['migrate'].db
 
 # other values from the config, defined by the needs of env.py,
 # can be acquired:
@@ -26,6 +46,12 @@ target_metadata = metadata_obj
 # ... etc.
 
 
+def get_metadata():
+    if hasattr(target_db, 'metadatas'):
+        return target_db.metadatas[None]
+    return target_db.metadata
+
+
 def run_migrations_offline():
     """Run migrations in 'offline' mode.
 
@@ -40,10 +66,7 @@ def run_migrations_offline():
     """
     url = config.get_main_option("sqlalchemy.url")
     context.configure(
-        url=url,
-        target_metadata=target_metadata,
-        literal_binds=True,
-        dialect_opts={"paramstyle": "named"},
+        url=url, target_metadata=get_metadata(), literal_binds=True
     )
 
     with context.begin_transaction():
@@ -57,15 +80,25 @@ def run_migrations_online():
     and associate a connection with the context.
 
     """
-    connectable = engine_from_config(
-        config.get_section(config.config_ini_section),
-        prefix="sqlalchemy.",
-        poolclass=pool.NullPool,
-    )
+
+    # this callback is used to prevent an auto-migration from being generated
+    # when there are no changes to the schema
+    # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
+    def process_revision_directives(context, revision, directives):
+        if getattr(config.cmd_opts, 'autogenerate', False):
+            script = directives[0]
+            if script.upgrade_ops.is_empty():
+                directives[:] = []
+                logger.info('No changes in schema detected.')
+
+    connectable = get_engine()
 
     with connectable.connect() as connection:
         context.configure(
-            connection=connection, target_metadata=target_metadata
+            connection=connection,
+            target_metadata=get_metadata(),
+            process_revision_directives=process_revision_directives,
+            **current_app.extensions['migrate'].configure_args
         )
 
         with context.begin_transaction():
diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py
deleted file mode 100644
index f25f03c9..00000000
--- a/compendium_v2/migrations/migration_utils.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import logging
-import os
-
-from alembic.config import Config
-from alembic import command
-
-logger = logging.getLogger(__name__)
-DEFAULT_MIGRATIONS_DIRECTORY = os.path.dirname(__file__)
-
-
-def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY):
-    """
-    migrate db to head version
-
-    cf. https://stackoverflow.com/a/43530495,
-        https://stackoverflow.com/a/54402853
-
-    :param dsn: dsn string, passed to alembic
-    :param migrations_directory: full path to migrations directory
-        (default is this directory)
-    :return:
-    """
-    alembic_config = Config()
-    alembic_config.set_main_option('script_location', migrations_directory)
-    alembic_config.set_main_option('sqlalchemy.url', dsn)
-    command.upgrade(alembic_config, 'head')
-
-
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
-    return (f'postgresql://{db_username}:{db_password}'
-            f'@{db_hostname}:{port}/{db_name}')
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.DEBUG)
-    upgrade(postgresql_dsn(
-        db_username='compendium',
-        db_password='compendium321',
-        db_hostname='localhost',
-        db_name='compendium',
-        port=65000))
diff --git a/requirements.txt b/requirements.txt
index 98804e02..f49a7960 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ click~=8.1
 jsonschema~=4.17
 flask~=2.2
 flask-cors~=3.0
+flask-migrate~=4.0
 flask-sqlalchemy~=3.0
 openpyxl~=3.1
 psycopg2-binary~=2.9
diff --git a/setup.py b/setup.py
index 48a37f1b..52acb8cb 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ setup(
         'jsonschema~=4.17',
         'flask~=2.2',
         'flask-cors~=3.0',
+        'flask-migrate~=4.0',
         'flask-sqlalchemy~=3.0',
         'openpyxl~=3.1',
         'psycopg2-binary~=2.9',
-- 
GitLab


From 6b736da7efa507be8cbb92d42834b61a83093e3c Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 17:04:23 +0200
Subject: [PATCH 3/4] use sqlalchemy2 syntax everywhere

---
 .../publishers/survey_publisher_2022.py       | 14 +++++-----
 .../publishers/survey_publisher_v1.py         |  4 ++-
 test/test_survey_publisher_2022.py            | 27 ++++++++++++++-----
 test/test_survey_publisher_v1.py              | 14 +++++-----
 4 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py
index f5a5cf76..5e449c3e 100644
--- a/compendium_v2/publishers/survey_publisher_2022.py
+++ b/compendium_v2/publishers/survey_publisher_2022.py
@@ -13,7 +13,7 @@ import math
 import json
 import html
 
-from sqlalchemy import text
+from sqlalchemy import text, delete
 from collections import defaultdict
 
 import compendium_v2
@@ -228,10 +228,10 @@ def transfer_staff_data(nren_dict):
     for nren_name, nren_info in data.items():
         if sum([nren_info[question] for question in StaffQuestion]) == 0:
             logger.info(f'{nren_name} has no staff data. Deleting if exists.')
-            db.session.query(model.NrenStaff).filter(
+            db.session.execute(delete(model.NrenStaff).where(
                 model.NrenStaff.nren_id == nren_dict[nren_name].id,
-                model.NrenStaff.year == 2022,
-            ).delete()
+                model.NrenStaff.year == 2022
+            ))
             continue
 
         employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
@@ -364,9 +364,9 @@ def transfer_charging_structure(nren_dict):
 
 def transfer_ec_projects(nren_dict):
     # delete all existing EC projects, in case something changed
-    db.session.query(model.ECProject).filter(
-        model.ECProject.year == 2022,
-    ).delete()
+    db.session.execute(
+        delete(model.ECProject).where(model.ECProject.year == 2022)
+    )
 
     rows = query_question(ECQuestion.EC_PROJECT)
     for row in rows:
diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py
index dce1ba58..c55926bb 100644
--- a/compendium_v2/publishers/survey_publisher_v1.py
+++ b/compendium_v2/publishers/survey_publisher_v1.py
@@ -11,6 +11,8 @@ import logging
 import math
 import click
 
+from sqlalchemy import select
+
 import compendium_v2
 from compendium_v2.environment import setup_logging
 from compendium_v2 import survey_db
@@ -29,7 +31,7 @@ def db_budget_migration(nren_dict):
     with survey_db.session_scope() as survey_session:
 
         # move data from Survey DB budget table
-        data = survey_session.query(survey_model.Nrens)
+        data = survey_session.scalars(select(survey_model.Nrens))
         for nren in data:
             for budget in nren.budgets:
                 abbrev = nren.abbreviation.upper()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index 433def2c..3216caf9 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
 from compendium_v2.db import db, model
 from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
     StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
@@ -193,12 +195,16 @@ def test_publisher(app, mocker, dummy_config):
     _cli(dummy_config, app)
 
     with app.app_context():
-        budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
+        budgets = db.session.scalars(
+            select(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc())
+        ).all()
         assert len(budgets) == 3
         assert budgets[0].nren.name.lower() == 'nren1'
         assert budgets[0].budget == 100
 
-        funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
+        funding_sources = db.session.scalars(
+            select(model.FundingSource).order_by(model.FundingSource.nren_id.asc())
+        ).all()
         assert len(funding_sources) == 3
         assert funding_sources[0].nren.name.lower() == 'nren1'
         assert funding_sources[0].client_institutions == 10
@@ -215,7 +221,9 @@ def test_publisher(app, mocker, dummy_config):
         assert funding_sources[2].european_funding == 30
         assert funding_sources[2].other == 30
 
-        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
+        staff_data = db.session.scalars(
+            select(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc())
+        ).all()
 
         assert len(staff_data) == 3
         assert staff_data[0].nren.name.lower() == 'nren1'
@@ -236,7 +244,9 @@ def test_publisher(app, mocker, dummy_config):
         assert staff_data[2].permanent_fte == 30
         assert staff_data[2].subcontracted_fte == 0
 
-        _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
+        _org_data = db.session.scalars(
+            select(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc())
+        ).all()
 
         assert len(_org_data) == 2
         assert _org_data[0].nren.name.lower() == 'nren1'
@@ -245,8 +255,9 @@ def test_publisher(app, mocker, dummy_config):
         assert _org_data[1].nren.name.lower() == 'nren3'
         assert _org_data[1].organization == 'Org3'
 
-        charging_structures = db.session.query(model.ChargingStructure).order_by(
-            model.ChargingStructure.nren_id.asc()).all()
+        charging_structures = db.session.scalars(
+            select(model.ChargingStructure).order_by(model.ChargingStructure.nren_id.asc())
+        ).all()
         assert len(charging_structures) == 3
         assert charging_structures[0].nren.name.lower() == 'nren1'
         assert charging_structures[0].fee_type == model.FeeType.no_charge
@@ -255,7 +266,9 @@ def test_publisher(app, mocker, dummy_config):
         assert charging_structures[2].nren.name.lower() == 'nren3'
         assert charging_structures[2].fee_type == model.FeeType.other
 
-        _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
+        _ec_data = db.session.scalars(
+            select(model.ECProject).order_by(model.ECProject.nren_id.asc())
+        ).all()
 
         assert len(_ec_data) == 3
         assert _ec_data[0].nren.name.lower() == 'nren2'
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index a6dbbaa1..88def54b 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -1,5 +1,7 @@
 import os
 
+from sqlalchemy import select, func
+
 from compendium_v2 import db
 from compendium_v2.db import model
 from compendium_v2.publishers.survey_publisher_v1 import _cli
@@ -18,13 +20,13 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config):
     _cli(dummy_config, app)
 
     with app.app_context():
-        budget_count = db.session.query(model.BudgetEntry.year).count()
+        budget_count = db.session.scalar(select(func.count(model.BudgetEntry.year)))
         assert budget_count
-        funding_source_count = db.session.query(model.FundingSource.year).count()
+        funding_source_count = db.session.scalar(select(func.count(model.FundingSource.year)))
         assert funding_source_count
-        charging_structure_count = db.session.query(model.ChargingStructure.year).count()
+        charging_structure_count = db.session.scalar(select(func.count(model.ChargingStructure.year)))
         assert charging_structure_count
-        staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
+        staff_data = db.session.scalars(select(model.NrenStaff).order_by(model.NrenStaff.year.asc())).all()
 
         # data should only be saved for the NRENs we have saved in the database
         staff_data_nrens = set([staff.nren.name for staff in staff_data])
@@ -70,7 +72,7 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config):
         assert kifu_data[5].technical_fte == 133
         assert kifu_data[5].non_technical_fte == 45
 
-        ecproject_data = db.session.query(model.ECProject).all()
+        ecproject_data = db.session.scalars(select(model.ECProject)).all()
         # test a couple of random entries
         surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
         assert len(surf2017) == 1
@@ -84,7 +86,7 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config):
         assert len(kifu2019) == 4
         assert kifu2019[3].project == 'SuperHeroes for Science'
 
-        parent_data = db.session.query(model.ParentOrganization).all()
+        parent_data = db.session.scalars(select(model.ParentOrganization)).all()
         # test a random entry
         asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
         assert len(asnet2021) == 1
-- 
GitLab


From 0f775788d87de4966a37ce50fc3888bbdd3ba8ac Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 21:20:49 +0200
Subject: [PATCH 4/4] also use flask-sqlalchemy for the survey db

---
 compendium_v2/__init__.py                     |  5 ++
 compendium_v2/publishers/helpers.py           |  6 --
 .../publishers/survey_publisher_2022.py       | 15 ++--
 .../publishers/survey_publisher_v1.py         | 69 +++++++++----------
 compendium_v2/survey_db/__init__.py           | 40 -----------
 compendium_v2/survey_db/model.py              | 17 +++--
 test/conftest.py                              | 25 +++----
 test/test_survey_publisher_2022.py            |  8 +--
 test/test_survey_publisher_v1.py              |  8 +--
 9 files changed, 75 insertions(+), 118 deletions(-)

diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py
index 4717a486..de1b313f 100644
--- a/compendium_v2/__init__.py
+++ b/compendium_v2/__init__.py
@@ -33,6 +33,11 @@ def _create_app_with_db(app_config) -> Flask:
     # used by the tests and the publishers
     app = _create_app(app_config)
     app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI']
+
+    if 'SQLALCHEMY_BINDS' in app.config['CONFIG_PARAMS']:
+        # for the publishers
+        app.config['SQLALCHEMY_BINDS'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_BINDS']
+
     db.init_app(app)
     return app
 
diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py
index 43d1bf02..fb9fb40e 100644
--- a/compendium_v2/publishers/helpers.py
+++ b/compendium_v2/publishers/helpers.py
@@ -1,14 +1,8 @@
 from sqlalchemy import select
 
-from compendium_v2 import survey_db
 from compendium_v2.db import db, model
 
 
-def init_db(config):
-    dsn_survey = config['SURVEY_DATABASE_URI']
-    survey_db.init_db_model(dsn_survey)
-
-
 def get_uppercase_nren_dict():
     """
     :return: a dictionary of all known NRENs db entities keyed on the
diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py
index 5e449c3e..c213d265 100644
--- a/compendium_v2/publishers/survey_publisher_2022.py
+++ b/compendium_v2/publishers/survey_publisher_2022.py
@@ -20,7 +20,7 @@ import compendium_v2
 from compendium_v2.db.model import FeeType
 from compendium_v2.environment import setup_logging
 from compendium_v2.config import load
-from compendium_v2 import survey_db
+from compendium_v2.survey_db import model as survey_model
 from compendium_v2.db import db, model
 from compendium_v2.publishers import helpers
 
@@ -117,21 +117,18 @@ class ChargingStructure(enum.Enum):
 
 
 def query_budget():
-    with survey_db.session_scope() as survey:
-        return survey.execute(text(BUDGET_QUERY))
+    return db.session.execute(text(BUDGET_QUERY), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]})
 
 
 def query_funding_sources():
     for source in FundingSource:
         query = QUESTION_TEMPLATE_QUERY.format(source.value)
-        with survey_db.session_scope() as survey:
-            yield source, survey.execute(text(query))
+        yield source, db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]})
 
 
 def query_question(question: enum.Enum):
     query = QUESTION_TEMPLATE_QUERY.format(question.value)
-    with survey_db.session_scope() as survey:
-        return survey.execute(text(query))
+    return db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]})
 
 
 def transfer_budget(nren_dict):
@@ -404,7 +401,6 @@ def transfer_ec_projects(nren_dict):
 
 
 def _cli(config, app):
-    helpers.init_db(config)
     with app.app_context():
         nren_dict = helpers.get_uppercase_nren_dict()
         transfer_budget(nren_dict)
@@ -420,6 +416,9 @@ def _cli(config, app):
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
+
+    app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']}
+
     app = compendium_v2._create_app_with_db(app_config)
     _cli(app_config, app)
 
diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py
index c55926bb..cbb7a05e 100644
--- a/compendium_v2/publishers/survey_publisher_v1.py
+++ b/compendium_v2/publishers/survey_publisher_v1.py
@@ -15,7 +15,6 @@ from sqlalchemy import select
 
 import compendium_v2
 from compendium_v2.environment import setup_logging
-from compendium_v2 import survey_db
 from compendium_v2.background_task import parse_excel_data
 from compendium_v2.config import load
 from compendium_v2.db import db, model
@@ -28,49 +27,47 @@ logger = logging.getLogger('survey-publisher-v1')
 
 
 def db_budget_migration(nren_dict):
-    with survey_db.session_scope() as survey_session:
-
-        # move data from Survey DB budget table
-        data = survey_session.scalars(select(survey_model.Nrens))
-        for nren in data:
-            for budget in nren.budgets:
-                abbrev = nren.abbreviation.upper()
-                year = budget.year
-
-                if float(budget.budget) > 200:
-                    logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})')
-
-                if abbrev not in nren_dict:
-                    logger.warning(f'{abbrev} unknown. Skipping.')
-                    continue
-
-                budget_entry = model.BudgetEntry(
-                    nren=nren_dict[abbrev],
-                    nren_id=nren_dict[abbrev].id,
-                    budget=float(budget.budget),
-                    year=year
-                )
-                db.session.merge(budget_entry)
-
-        # Import the data from excel sheet to database
-        exceldata = parse_excel_data.fetch_budget_excel_data()
-
-        for abbrev, budget, year in exceldata:
+    # move data from Survey DB budget table
+    data = db.session.scalars(select(survey_model.Nrens))
+    for nren in data:
+        for budget in nren.budgets:
+            abbrev = nren.abbreviation.upper()
+            year = budget.year
+
+            if float(budget.budget) > 200:
+                logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})')
+
             if abbrev not in nren_dict:
                 logger.warning(f'{abbrev} unknown. Skipping.')
                 continue
 
-            if budget > 200:
-                logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
-
             budget_entry = model.BudgetEntry(
                 nren=nren_dict[abbrev],
                 nren_id=nren_dict[abbrev].id,
-                budget=budget,
+                budget=float(budget.budget),
                 year=year
             )
             db.session.merge(budget_entry)
-        db.session.commit()
+
+    # Import the data from excel sheet to database
+    exceldata = parse_excel_data.fetch_budget_excel_data()
+
+    for abbrev, budget, year in exceldata:
+        if abbrev not in nren_dict:
+            logger.warning(f'{abbrev} unknown. Skipping.')
+            continue
+
+        if budget > 200:
+            logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
+
+        budget_entry = model.BudgetEntry(
+            nren=nren_dict[abbrev],
+            nren_id=nren_dict[abbrev].id,
+            budget=budget,
+            year=year
+        )
+        db.session.merge(budget_entry)
+    db.session.commit()
 
 
 def db_funding_migration(nren_dict):
@@ -203,7 +200,6 @@ def db_organizations_migration(nren_dict):
 
 
 def _cli(config, app):
-    helpers.init_db(config)
     with app.app_context():
         nren_dict = helpers.get_uppercase_nren_dict()
         db_budget_migration(nren_dict)
@@ -218,6 +214,9 @@ def _cli(config, app):
 @click.option('--config', type=click.STRING, default='config.json')
 def cli(config):
     app_config = load(open(config, 'r'))
+
+    app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']}
+
     app = compendium_v2._create_app_with_db(app_config)
     print("survey-publisher-v1 starting")
     _cli(app_config, app)
diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py
index 1550ddcb..e69de29b 100644
--- a/compendium_v2/survey_db/__init__.py
+++ b/compendium_v2/survey_db/__init__.py
@@ -1,40 +0,0 @@
-import contextlib
-import logging
-from typing import Optional, Union, Callable, Iterator
-
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker, Session
-
-logger = logging.getLogger(__name__)
-_SESSION_MAKER: Union[None, sessionmaker] = None
-
-
-@contextlib.contextmanager
-def session_scope(
-        callback_before_close: Optional[Callable] = None) -> Iterator[Session]:
-    # best practice is to keep session scope separate from data processing
-    # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html
-
-    assert _SESSION_MAKER
-    session = _SESSION_MAKER()
-    try:
-        yield session
-        session.commit()
-        if callback_before_close:
-            callback_before_close()
-    except SQLAlchemyError:
-        logger.error('caught sql layer exception, rolling back')
-        session.rollback()
-        raise  # re-raise, will be handled by main consumer
-    finally:
-        session.close()
-
-
-def init_db_model(dsn):
-    global _SESSION_MAKER
-
-    # cf. https://docs.sqlalchemy.org/en
-    #        /latest/orm/extensions/automap.html
-    engine = create_engine(dsn, pool_size=10)
-    _SESSION_MAKER = sessionmaker(bind=engine)
diff --git a/compendium_v2/survey_db/model.py b/compendium_v2/survey_db/model.py
index a908aa20..605cf2d5 100644
--- a/compendium_v2/survey_db/model.py
+++ b/compendium_v2/survey_db/model.py
@@ -1,17 +1,23 @@
 import logging
 from typing import List, Optional
 
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 from sqlalchemy.schema import ForeignKey
 
+from compendium_v2.db import db
+
+
 logger = logging.getLogger(__name__)
 
+SURVEY_DB_BIND = 'survey'
 
-class Base(DeclarativeBase):
-    pass
+# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet.
+# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
+# mypy: disable-error-code="name-defined"
 
 
-class Budgets(Base):
+class Budgets(db.Model):
+    __bind_key__ = SURVEY_DB_BIND
     __tablename__ = 'budgets'
     id: Mapped[int] = mapped_column(primary_key=True)
     budget: Mapped[Optional[str]]
@@ -20,7 +26,8 @@ class Budgets(Base):
     nren: Mapped[Optional['Nrens']] = relationship(back_populates='budgets')
 
 
-class Nrens(Base):
+class Nrens(db.Model):
+    __bind_key__ = SURVEY_DB_BIND
     __tablename__ = 'nrens'
     id: Mapped[int] = mapped_column(primary_key=True)
     abbreviation: Mapped[Optional[str]]
diff --git a/test/conftest.py b/test/conftest.py
index ba6b9a43..be9c16e4 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -3,10 +3,6 @@ import os
 import pytest
 import random
 
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import StaticPool
-
 import compendium_v2
 from compendium_v2.db import db, model
 from compendium_v2.survey_db import model as survey_model
@@ -25,18 +21,6 @@ def dummy_config():
     }
 
 
-@pytest.fixture
-def mocked_survey_db(mocker):
-    engine = create_engine(
-        'sqlite://',
-        connect_args={'check_same_thread': False},
-        poolclass=StaticPool,
-        echo=False)
-    survey_model.Base.metadata.create_all(engine)
-    mocker.patch('compendium_v2.survey_db._SESSION_MAKER', sessionmaker(bind=engine))
-    mocker.patch('compendium_v2.survey_db.init_db_model', lambda dsn: None)
-
-
 @pytest.fixture
 def test_budget_data(app):
     with app.app_context():
@@ -129,6 +113,15 @@ def test_staff_data(app):
 
 @pytest.fixture
 def app(dummy_config):
+    app = compendium_v2._create_app_with_db(dummy_config)
+    with app.app_context():
+        db.create_all(bind_key=None)
+    yield app
+
+
+@pytest.fixture
+def app_with_survey_db(dummy_config):
+    dummy_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: dummy_config['SURVEY_DATABASE_URI']}
     app = compendium_v2._create_app_with_db(dummy_config)
     with app.app_context():
         db.create_all()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index 3216caf9..72267691 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -110,7 +110,7 @@ org_dataKTU,"NOC, administrative authority"
         ]
 
 
-def test_publisher(app, mocker, dummy_config):
+def test_publisher(app_with_survey_db, mocker, dummy_config):
     global org_data
 
     def get_rows_as_tuples(*args, **kwargs):
@@ -188,13 +188,13 @@ def test_publisher(app, mocker, dummy_config):
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
 
     nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-    with app.app_context():
+    with app_with_survey_db.app_context():
         db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
         db.session.commit()
 
-    _cli(dummy_config, app)
+    _cli(dummy_config, app_with_survey_db)
 
-    with app.app_context():
+    with app_with_survey_db.app_context():
         budgets = db.session.scalars(
             select(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc())
         ).all()
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index 88def54b..496cf71b 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -9,17 +9,17 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
 EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
 
 
-def test_publisher(mocked_survey_db, app, mocker, dummy_config):
+def test_publisher(app_with_survey_db, mocker, dummy_config):
     mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
 
-    with app.app_context():
+    with app_with_survey_db.app_context():
         nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
         db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
         db.session.commit()
 
-    _cli(dummy_config, app)
+    _cli(dummy_config, app_with_survey_db)
 
-    with app.app_context():
+    with app_with_survey_db.app_context():
         budget_count = db.session.scalar(select(func.count(model.BudgetEntry.year)))
         assert budget_count
         funding_source_count = db.session.scalar(select(func.count(model.FundingSource.year)))
-- 
GitLab