From 6e62575f75375af493a76a394264a293cb80c7b8 Mon Sep 17 00:00:00 2001 From: Remco Tukker <remco.tukker@geant.org> Date: Thu, 4 May 2023 10:00:14 +0200 Subject: [PATCH 1/4] use flask-sqlalchemy for the main db --- compendium_v2/__init__.py | 12 +- compendium_v2/db/__init__.py | 50 +- compendium_v2/db/model.py | 48 +- compendium_v2/migrations/env.py | 2 +- compendium_v2/migrations/migration_utils.py | 8 +- compendium_v2/publishers/helpers.py | 13 +- .../publishers/survey_publisher_2022.py | 485 +++++++++--------- .../publishers/survey_publisher_v1.py | 272 +++++----- compendium_v2/routes/budget.py | 31 +- compendium_v2/routes/charging.py | 29 +- compendium_v2/routes/ec_projects.py | 24 +- compendium_v2/routes/funding.py | 27 +- compendium_v2/routes/organization.py | 29 +- compendium_v2/routes/staff.py | 24 +- compendium_v2/survey_db/__init__.py | 5 - requirements.txt | 1 + setup.py | 1 + test/conftest.py | 122 ++--- test/test_survey_publisher_2022.py | 28 +- test/test_survey_publisher_v1.py | 23 +- 20 files changed, 545 insertions(+), 689 deletions(-) diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py index 622565f2..7a69bf9e 100644 --- a/compendium_v2/__init__.py +++ b/compendium_v2/__init__.py @@ -8,7 +8,7 @@ from flask import Flask from flask_cors import CORS # for debugging from compendium_v2 import config, environment - +from compendium_v2.db import db from compendium_v2.migrations import migration_utils @@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask: return app +def _create_app_with_db(app_config) -> Flask: + # used by the tests and the publishers + app = _create_app(app_config) + app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI'] + db.init_app(app) + return app + + def create_app() -> Flask: """ overrides default settings with those found @@ -46,7 +54,7 @@ def create_app() -> Flask: with open(os.environ['SETTINGS_FILENAME']) as f: app_config = config.load(f) - app = _create_app(app_config) + app = _create_app_with_db(app_config) logging.info('Flask app initialized') diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py index ce48aaa0..13719e0a 100644 --- a/compendium_v2/db/__init__.py +++ b/compendium_v2/db/__init__.py @@ -1,45 +1,17 @@ -import contextlib import logging -from typing import Optional, Union, Callable, Iterator -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker, Session +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData -logger = logging.getLogger(__name__) -_SESSION_MAKER: Union[None, sessionmaker] = None - - -@contextlib.contextmanager -def session_scope( - callback_before_close: Optional[Callable] = None) -> Iterator[Session]: - # best practice is to keep session scope separate from data processing - # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html - - assert _SESSION_MAKER - session = _SESSION_MAKER() - try: - yield session - session.commit() - if callback_before_close: - callback_before_close() - except SQLAlchemyError: - logger.error('caught sql layer exception, rolling back') - session.rollback() - raise # re-raise, will be handled by main consumer - finally: - session.close() - - -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') +logger = logging.getLogger(__name__) -def init_db_model(dsn): - global _SESSION_MAKER +metadata_obj = MetaData(naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +}) - # cf. https://docs.sqlalchemy.org/en - # /latest/orm/extensions/automap.html - engine = create_engine(dsn, pool_size=10) - _SESSION_MAKER = sessionmaker(bind=engine) +db = SQLAlchemy(metadata=metadata_obj) diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py index 67672743..2199f759 100644 --- a/compendium_v2/db/model.py +++ b/compendium_v2/db/model.py @@ -4,42 +4,34 @@ from enum import Enum from typing import Optional from typing_extensions import Annotated -from sqlalchemy import MetaData, String -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey +from compendium_v2.db import db -logger = logging.getLogger(__name__) - -convention = { - "ix": "ix_%(column_0_label)s", - "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_%(constraint_name)s", - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s", -} -metadata_obj = MetaData(naming_convention=convention) +logger = logging.getLogger(__name__) str128 = Annotated[str, 128] +str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)] +str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)] int_pk = Annotated[int, mapped_column(primary_key=True)] int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)] -class Base(DeclarativeBase): - metadata = metadata_obj - type_annotation_map = { - str128: String(128), - } +# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet. +# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140 +# mypy: disable-error-code="name-defined" -class NREN(Base): +class NREN(db.Model): __tablename__ = 'nren' id: Mapped[int_pk] name: Mapped[str128] -class BudgetEntry(Base): +class BudgetEntry(db.Model): __tablename__ = 'budgets' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -47,7 +39,7 @@ class BudgetEntry(Base): budget: Mapped[Decimal] -class FundingSource(Base): +class FundingSource(db.Model): __tablename__ = 'funding_source' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -67,7 +59,7 @@ class FeeType(Enum): other = "other" -class ChargingStructure(Base): +class ChargingStructure(db.Model): __tablename__ = 'charging_structure' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -75,7 +67,7 @@ class ChargingStructure(Base): fee_type: Mapped[Optional[FeeType]] -class NrenStaff(Base): +class NrenStaff(db.Model): __tablename__ = 'nren_staff' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -86,7 +78,7 @@ class NrenStaff(Base): non_technical_fte: Mapped[Decimal] -class ParentOrganization(Base): +class ParentOrganization(db.Model): __tablename__ = 'parent_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -94,18 +86,18 @@ class ParentOrganization(Base): organization: Mapped[str128] -class SubOrganization(Base): +class SubOrganization(db.Model): __tablename__ = 'sub_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - organization: Mapped[str128] = mapped_column(primary_key=True) + organization: Mapped[str128_pk] role: Mapped[str128] -class ECProject(Base): +class ECProject(db.Model): __tablename__ = 'ec_project' nren_id: Mapped[int_pk_fkNREN] - nren: Mapped[NREN] = relationship(NREN, lazy='joined') + nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - project: Mapped[str] = mapped_column(String(256), primary_key=True) + project: Mapped[str256_pk] diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py index 0307be33..5ea9c8d3 100644 --- a/compendium_v2/migrations/env.py +++ b/compendium_v2/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool from alembic import context -from compendium_v2.db.model import metadata_obj +from compendium_v2.db import metadata_obj # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py index 3b29b540..f25f03c9 100644 --- a/compendium_v2/migrations/migration_utils.py +++ b/compendium_v2/migrations/migration_utils.py @@ -1,7 +1,6 @@ import logging import os -from compendium_v2 import db from alembic.config import Config from alembic import command @@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY): command.upgrade(alembic_config, 'head') +def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): + return (f'postgresql://{db_username}:{db_password}' + f'@{db_hostname}:{port}/{db_name}') + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - upgrade(db.postgresql_dsn( + upgrade(postgresql_dsn( db_username='compendium', db_password='compendium321', db_hostname='localhost', diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py index d95ad4e1..43d1bf02 100644 --- a/compendium_v2/publishers/helpers.py +++ b/compendium_v2/publishers/helpers.py @@ -1,21 +1,20 @@ -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from sqlalchemy import select + +from compendium_v2 import survey_db +from compendium_v2.db import db, model def init_db(config): - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) dsn_survey = config['SURVEY_DATABASE_URI'] survey_db.init_db_model(dsn_survey) -def get_uppercase_nren_dict(session): +def get_uppercase_nren_dict(): """ - :param session: db session that is used to query the known NRENs :return: a dictionary of all known NRENs db entities keyed on the uppercased name """ - current_nrens = session.query(model.NREN).all() + current_nrens = db.session.scalars(select(model.NREN)) nren_dict = {nren.name.upper(): nren for nren in current_nrens} # add aliases that are used in the source data: nren_dict['ASNET'] = nren_dict['ASNET-AM'] diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py index 710c899d..ea88ee77 100644 --- a/compendium_v2/publishers/survey_publisher_2022.py +++ b/compendium_v2/publishers/survey_publisher_2022.py @@ -16,11 +16,12 @@ import html from sqlalchemy import text from collections import defaultdict +import compendium_v2 from compendium_v2.db.model import FeeType from compendium_v2.environment import setup_logging from compendium_v2.config import load -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from compendium_v2 import survey_db +from compendium_v2.db import db, model from compendium_v2.publishers import helpers setup_logging() @@ -133,163 +134,151 @@ def query_question(question: enum.Enum): return survey.execute(text(query)) -def transfer_budget(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_budget() - for row in rows: +def transfer_budget(nren_dict): + rows = query_budget() + for row in rows: + nren_name = row[0].upper() + _budget = row[1] + try: + budget = float(_budget.replace('"', '').replace(',', '')) + except ValueError: + logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') + continue + + if budget > 200: + logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + budget_entry = model.BudgetEntry( + nren=nren_dict[nren_name], + budget=budget, + year=2022, + ) + db.session.merge(budget_entry) + db.session.commit() + + +def transfer_funding_sources(nren_dict): + sourcedata = {} + for source, data in query_funding_sources(): + for row in data: nren_name = row[0].upper() - _budget = row[1] + _value = row[1] try: - budget = float(_budget.replace('"', '').replace(',', '')) + value = float(_value.replace('"', '').replace(',', '')) except ValueError: - logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') - continue - - if budget > 200: - logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue + name = source.name + logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') + value = 0 - budget_entry = model.BudgetEntry( - nren=nren_dict[nren_name], - budget=budget, - year=2022, + nren_info = sourcedata.setdefault( + nren_name, + {source_type: 0 for source_type in FundingSource} ) - session.merge(budget_entry) - session.commit() - - -def transfer_funding_sources(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - sourcedata = {} - for source, data in query_funding_sources(): - for row in data: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - name = source.name - logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') - value = 0 - - nren_info = sourcedata.setdefault( - nren_name, - {source_type: 0 for source_type in FundingSource} - ) - nren_info[source] = value - - for nren_name, nren_info in sourcedata.items(): - total = sum(nren_info.values()) - - if not math.isclose(total, 100, abs_tol=0.01): - logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + nren_info[source] = value + + for nren_name, nren_info in sourcedata.items(): + total = sum(nren_info.values()) + + if not math.isclose(total, 100, abs_tol=0.01): + logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + funding_source = model.FundingSource( + nren=nren_dict[nren_name], + year=2022, + client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], + european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], + gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], + commercial=nren_info[FundingSource.COMMERCIAL], + other=nren_info[FundingSource.OTHER], + ) + db.session.merge(funding_source) + db.session.commit() + + +def transfer_staff_data(nren_dict): + data = {} + for question in StaffQuestion: + rows = query_question(question) + for row in rows: + nren_name = row[0].upper() + _value = row[1] + try: + value = float(_value.replace('"', '').replace(',', '')) + except ValueError: + value = 0 if nren_name not in nren_dict: logger.info(f'{nren_name} unknown. Skipping.') continue - funding_source = model.FundingSource( - nren=nren_dict[nren_name], - year=2022, - client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], - european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], - gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], - commercial=nren_info[FundingSource.COMMERCIAL], - other=nren_info[FundingSource.OTHER], - ) - session.merge(funding_source) - session.commit() - - -def transfer_staff_data(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - data = {} - for question in StaffQuestion: - rows = query_question(question) - for row in rows: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - value = 0 - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - # initialize on first use, so we don't add data for nrens with no answers - data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ - question] = value - - for nren_name, nren_info in data.items(): - if sum([nren_info[question] for question in StaffQuestion]) == 0: - logger.info(f'{nren_name} has no staff data. Deleting if exists.') - session.query(model.NrenStaff).filter( - model.NrenStaff.nren_id == nren_dict[nren_name].id, - model.NrenStaff.year == 2022, - ).delete() - continue - - employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] - technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] - - if not math.isclose(employed, technical, abs_tol=0.01): - logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' - f' ({employed} != {technical})') - - staff_data = model.NrenStaff( - nren_id=nren_dict[nren_name].id, - year=2022, - permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], - subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], - technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], - non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], - ) - session.merge(staff_data) - session.commit() - - -def transfer_nren_parent_org(): + # initialize on first use, so we don't add data for nrens with no answers + data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ + question] = value + + for nren_name, nren_info in data.items(): + if sum([nren_info[question] for question in StaffQuestion]) == 0: + logger.info(f'{nren_name} has no staff data. Deleting if exists.') + db.session.query(model.NrenStaff).filter( + model.NrenStaff.nren_id == nren_dict[nren_name].id, + model.NrenStaff.year == 2022, + ).delete() + continue + + employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] + technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] + + if not math.isclose(employed, technical, abs_tol=0.01): + logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' + f' ({employed} != {technical})') + + staff_data = model.NrenStaff( + nren_id=nren_dict[nren_name].id, + year=2022, + permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], + subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], + technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], + non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], + ) + db.session.merge(staff_data) + db.session.commit() + + +def transfer_nren_parent_org(nren_dict): # clean up the data a bit by removing some strings strings_to_replace = [ 'We are affiliated to ' ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_question(OrgQuestion.PARENT_ORG_NAME) - for row in rows: - nren_name = row[0].upper() - value = str(row[1]).replace('"', '') + rows = query_question(OrgQuestion.PARENT_ORG_NAME) + for row in rows: + nren_name = row[0].upper() + value = str(row[1]).replace('"', '') - for string in strings_to_replace: - value = value.replace(string, '') + for string in strings_to_replace: + value = value.replace(string, '') - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue - parent_org = model.ParentOrganization( - nren_id=nren_dict[nren_name].id, - year=2022, - organization=value, - ) - session.merge(parent_org) - session.commit() + parent_org = model.ParentOrganization( + nren_id=nren_dict[nren_name].id, + year=2022, + organization=value, + ) + db.session.merge(parent_org) + db.session.commit() -def transfer_nren_sub_org(): +def transfer_nren_sub_org(nren_dict): suborg_questions = [ (OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE), (OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE), @@ -298,140 +287,134 @@ def transfer_nren_sub_org(): (OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE) ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - lookup = defaultdict(list) - - for name, choice, role in suborg_questions: - _name_rows = query_question(name) - _choice_rows = query_question(choice) - _role_rows = list(query_question(role)) - for _name, _choice in zip(_name_rows, _choice_rows): - nren_name = _name[0].upper() - suborg_name = _name[1].replace('"', '').strip() - role_choice = _choice[1].replace('"', '').strip() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - if role_choice.lower() == 'other': - for _role in _role_rows: - if _role[0] == _name[0]: - role = _role[1].replace('"', '').strip() - break - else: - role = role_choice - - lookup[nren_name].append((suborg_name, role)) - - for nren_name, suborgs in lookup.items(): - for suborg_name, role in suborgs: - suborg = model.SubOrganization( - nren_id=nren_dict[nren_name].id, - year=2022, - organization=suborg_name, - role=role, - ) - session.merge(suborg) - session.commit() - - -def transfer_charging_structure(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_question(ChargingStructure.charging_structure) - for row in rows: - nren_name = row[0].upper() - value = row[1].replace('"', '').strip() + lookup = defaultdict(list) + + for name, choice, role in suborg_questions: + _name_rows = query_question(name) + _choice_rows = query_question(choice) + _role_rows = list(query_question(role)) + for _name, _choice in zip(_name_rows, _choice_rows): + nren_name = _name[0].upper() + suborg_name = _name[1].replace('"', '').strip() + role_choice = _choice[1].replace('"', '').strip() if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping from charging structure.') + logger.info(f'{nren_name} unknown. Skipping.') continue - if "do not charge" in value: - charging_structure = FeeType.no_charge - elif "combination" in value: - charging_structure = FeeType.combination - elif "flat" in value: - charging_structure = FeeType.flat_fee - elif "usage-based" in value: - charging_structure = FeeType.usage_based_fee - elif "Other" in value: - charging_structure = FeeType.other + if role_choice.lower() == 'other': + for _role in _role_rows: + if _role[0] == _name[0]: + role = _role[1].replace('"', '').strip() + break else: - charging_structure = None + role = role_choice + + lookup[nren_name].append((suborg_name, role)) - charging_structure = model.ChargingStructure( + for nren_name, suborgs in lookup.items(): + for suborg_name, role in suborgs: + suborg = model.SubOrganization( nren_id=nren_dict[nren_name].id, year=2022, - fee_type=charging_structure, + organization=suborg_name, + role=role, ) - session.merge(charging_structure) - session.commit() - - -def transfer_ec_projects(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # delete all existing EC projects, in case something changed - session.query(model.ECProject).filter( - model.ECProject.year == 2022, - ).delete() - - rows = query_question(ECQuestion.EC_PROJECT) - for row in rows: - nren_name = row[0].upper() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') + db.session.merge(suborg) + db.session.commit() + + +def transfer_charging_structure(nren_dict): + rows = query_question(ChargingStructure.charging_structure) + for row in rows: + nren_name = row[0].upper() + value = row[1].replace('"', '').strip() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping from charging structure.') + continue + + if "do not charge" in value: + charging_structure = FeeType.no_charge + elif "combination" in value: + charging_structure = FeeType.combination + elif "flat" in value: + charging_structure = FeeType.flat_fee + elif "usage-based" in value: + charging_structure = FeeType.usage_based_fee + elif "Other" in value: + charging_structure = FeeType.other + else: + charging_structure = None + + charging_structure = model.ChargingStructure( + nren_id=nren_dict[nren_name].id, + year=2022, + fee_type=charging_structure, + ) + db.session.merge(charging_structure) + db.session.commit() + + +def transfer_ec_projects(nren_dict): + # delete all existing EC projects, in case something changed + db.session.query(model.ECProject).filter( + model.ECProject.year == 2022, + ).delete() + + rows = query_question(ECQuestion.EC_PROJECT) + for row in rows: + nren_name = row[0].upper() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + try: + value = json.loads(row[1]) + except json.decoder.JSONDecodeError: + logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') + continue + + for val in value: + if not val: + logger.info(f'Invalid EC project value for {nren_name}: {val}.') continue - try: - value = json.loads(row[1]) - except json.decoder.JSONDecodeError: - logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') - continue - - for val in value: - if not val: - logger.info(f'Invalid EC project value for {nren_name}: {val}.') - continue - - # strip html entities/NBSP from val - val = html.unescape(val).replace('\xa0', ' ') + # strip html entities/NBSP from val + val = html.unescape(val).replace('\xa0', ' ') - # some answers include contract numbers, which we don't want here - val = val.split('(contract n')[0] + # some answers include contract numbers, which we don't want here + val = val.split('(contract n')[0] - ec_project = model.ECProject( - nren_id=nren_dict[nren_name].id, - year=2022, - project=str(val).strip() - ) - session.add(ec_project) - session.commit() + ec_project = model.ECProject( + nren_id=nren_dict[nren_name].id, + year=2022, + project=str(val).strip() + ) + db.session.add(ec_project) + db.session.commit() -def _cli(config): +def _cli(config, app): helpers.init_db(config) - transfer_budget() - transfer_funding_sources() - transfer_staff_data() - transfer_nren_parent_org() - transfer_nren_sub_org() - transfer_charging_structure() - transfer_ec_projects() + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + transfer_budget(nren_dict) + transfer_funding_sources(nren_dict) + transfer_staff_data(nren_dict) + transfer_nren_parent_org(nren_dict) + transfer_nren_sub_org(nren_dict) + transfer_charging_structure(nren_dict) + transfer_ec_projects(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) - _cli(app_config) + app = compendium_v2._create_app_with_db(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py index ba0e4e2a..5f4cd8b2 100644 --- a/compendium_v2/publishers/survey_publisher_v1.py +++ b/compendium_v2/publishers/survey_publisher_v1.py @@ -11,11 +11,12 @@ import logging import math import click +import compendium_v2 from compendium_v2.environment import setup_logging -from compendium_v2 import db, survey_db +from compendium_v2 import survey_db from compendium_v2.background_task import parse_excel_data from compendium_v2.config import load -from compendium_v2.db import model +from compendium_v2.db import db, model from compendium_v2.survey_db import model as survey_model from compendium_v2.publishers import helpers @@ -24,11 +25,8 @@ setup_logging() logger = logging.getLogger('survey-publisher-v1') -def db_budget_migration(): - with survey_db.session_scope() as survey_session, \ - db.session_scope() as session: - - nren_dict = helpers.get_uppercase_nren_dict(session) +def db_budget_migration(nren_dict): + with survey_db.session_scope() as survey_session: # move data from Survey DB budget table data = survey_session.query(survey_model.Nrens) @@ -49,7 +47,7 @@ def db_budget_migration(): budget=float(budget.budget), year=year ) - session.merge(budget_entry) + db.session.merge(budget_entry) # Import the data from excel sheet to database exceldata = parse_excel_data.fetch_budget_excel_data() @@ -63,165 +61,153 @@ def db_budget_migration(): logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year) - session.merge(budget_entry) - session.commit() - - -def db_funding_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_funding_excel_data() - - for (abbrev, year, client_institution, - european_funding, - gov_public_bodies, - commercial, other) in data: - - _data = [client_institution, european_funding, gov_public_bodies, commercial, other] - total = sum(_data) - if not math.isclose(total, 100, abs_tol=0.01) and total != 0: - logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') - - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - budget_entry = model.FundingSource( - nren=nren_dict[abbrev], - year=year, - client_institutions=client_institution, - european_funding=european_funding, - gov_public_bodies=gov_public_bodies, - commercial=commercial, - other=other) - session.merge(budget_entry) - session.commit() - - -def db_charging_structure_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_charging_structure_excel_data() - - for (abbrev, year, charging_structure) in data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - charging_structure_entry = model.ChargingStructure( - nren=nren_dict[abbrev], year=year, fee_type=charging_structure) - session.merge(charging_structure_entry) - session.commit() - - -def db_staffing_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - staff_data = parse_excel_data.fetch_staffing_excel_data() - - nren_staff_map = {} - for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff data.') - continue - - nren = nren_dict[abbrev] + db.session.merge(budget_entry) + db.session.commit() + + +def db_funding_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_funding_excel_data() + + for (abbrev, year, client_institution, + european_funding, + gov_public_bodies, + commercial, other) in data: + + _data = [client_institution, european_funding, gov_public_bodies, commercial, other] + total = sum(_data) + if not math.isclose(total, 100, abs_tol=0.01) and total != 0: + logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') + + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + budget_entry = model.FundingSource( + nren=nren_dict[abbrev], + year=year, + client_institutions=client_institution, + european_funding=european_funding, + gov_public_bodies=gov_public_bodies, + commercial=commercial, + other=other) + db.session.merge(budget_entry) + db.session.commit() + + +def db_charging_structure_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_charging_structure_excel_data() + + for (abbrev, year, charging_structure) in data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + charging_structure_entry = model.ChargingStructure( + nren=nren_dict[abbrev], year=year, fee_type=charging_structure) + db.session.merge(charging_structure_entry) + db.session.commit() + + +def db_staffing_migration(nren_dict): + staff_data = parse_excel_data.fetch_staffing_excel_data() + + nren_staff_map = {} + for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping staff data.') + continue + + nren = nren_dict[abbrev] + nren_staff_map[(nren.id, year)] = model.NrenStaff( + nren=nren, + nren_id=nren.id, + year=year, + permanent_fte=permanent_fte, + subcontracted_fte=subcontracted_fte, + technical_fte=0, + non_technical_fte=0 + ) + + function_data = parse_excel_data.fetch_staff_function_excel_data() + for (abbrev, year, technical_fte, non_technical_fte) in function_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping staff function data.') + continue + + nren = nren_dict[abbrev] + if (nren.id, year) in nren_staff_map: + nren_staff_map[(nren.id, year)].technical_fte = technical_fte + nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte + else: nren_staff_map[(nren.id, year)] = model.NrenStaff( nren=nren, nren_id=nren.id, year=year, - permanent_fte=permanent_fte, - subcontracted_fte=subcontracted_fte, - technical_fte=0, - non_technical_fte=0 + permanent_fte=0, + subcontracted_fte=0, + technical_fte=technical_fte, + non_technical_fte=non_technical_fte ) - function_data = parse_excel_data.fetch_staff_function_excel_data() - for (abbrev, year, technical_fte, non_technical_fte) in function_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff function data.') - continue - - nren = nren_dict[abbrev] - if (nren.id, year) in nren_staff_map: - nren_staff_map[(nren.id, year)].technical_fte = technical_fte - nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte - else: - nren_staff_map[(nren.id, year)] = model.NrenStaff( - nren=nren, - nren_id=nren.id, - year=year, - permanent_fte=0, - subcontracted_fte=0, - technical_fte=technical_fte, - non_technical_fte=non_technical_fte - ) - - for nren_staff_model in nren_staff_map.values(): - employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte - technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte - if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: - logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' - f' FTE do not equal across employed/technical categories ({employed} != {technical})') - - session.merge(nren_staff_model) + for nren_staff_model in nren_staff_map.values(): + employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte + technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte + if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: + logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' + f' FTE do not equal across employed/technical categories ({employed} != {technical})') - session.commit() + db.session.merge(nren_staff_model) + db.session.commit() -def db_ecprojects_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - ecproject_data = parse_excel_data.fetch_ecproject_excel_data() - for (abbrev, year, project) in ecproject_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - nren = nren_dict[abbrev] - ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) - session.merge(ecproject_entry) - session.commit() +def db_ecprojects_migration(nren_dict): + ecproject_data = parse_excel_data.fetch_ecproject_excel_data() + for (abbrev, year, project) in ecproject_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + nren = nren_dict[abbrev] + ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) + db.session.merge(ecproject_entry) + db.session.commit() -def db_organizations_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - organization_data = parse_excel_data.fetch_organization_excel_data() - for (abbrev, year, org) in organization_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue +def db_organizations_migration(nren_dict): + organization_data = parse_excel_data.fetch_organization_excel_data() + for (abbrev, year, org) in organization_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue - nren = nren_dict[abbrev] - org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) - session.merge(org_entry) - session.commit() + nren = nren_dict[abbrev] + org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) + db.session.merge(org_entry) + db.session.commit() -def _cli(config): +def _cli(config, app): helpers.init_db(config) - db_budget_migration() - db_funding_migration() - db_charging_structure_migration() - db_staffing_migration() - db_ecprojects_migration() - db_organizations_migration() + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + db_budget_migration(nren_dict) + db_funding_migration(nren_dict) + db_charging_structure_migration(nren_dict) + db_staffing_migration(nren_dict) + db_ecprojects_migration(nren_dict) + db_organizations_migration(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) + app = compendium_v2._create_app_with_db(app_config) print("survey-publisher-v1 starting") - _cli(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py index 763b3af3..1f4ff850 100644 --- a/compendium_v2/routes/budget.py +++ b/compendium_v2/routes/budget.py @@ -1,28 +1,17 @@ import logging from typing import Any -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import BudgetEntry from compendium_v2.routes import common -routes = Blueprint('budget', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('budget', __name__) logger = logging.getLogger(__name__) -col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061', - '#bd7ebe', '#ffb55a', '#ffee65', - '#beb9db', '#fdcce5', '#8bd3c7'] - BUDGET_RESPONSE_SCHEMA = { '$schema': 'http://json-schema.org/draft-07/schema#', @@ -58,15 +47,15 @@ def budget_view() -> Any: :return: """ - def _extract_data(entry: model.BudgetEntry): + def _extract_data(entry: BudgetEntry): return { 'NREN': entry.nren.name, 'BUDGET': float(entry.budget), 'BUDGET_YEAR': entry.year, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.BudgetEntry)], - key=lambda d: (d['BUDGET_YEAR'], d['NREN'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))], + key=lambda d: (d['BUDGET_YEAR'], d['NREN']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py index 0b2c2384..57e8114f 100644 --- a/compendium_v2/routes/charging.py +++ b/compendium_v2/routes/charging.py @@ -1,21 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('charging', __name__) - +from flask import Blueprint, jsonify +from sqlalchemy import select -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ChargingStructure +from compendium_v2.routes import common +routes = Blueprint('charging', __name__) logger = logging.getLogger(__name__) CHARGING_STRUCTURE_RESPONSE_SCHEMA = { @@ -53,16 +47,15 @@ def charging_structure_view() -> Any: :return: """ - def _extract_data(entry: model.ChargingStructure): + def _extract_data(entry: ChargingStructure): return { 'NREN': entry.nren.name, 'YEAR': int(entry.year), 'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.ChargingStructure) - .all()], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py index 7114718d..b58d8931 100644 --- a/compendium_v2/routes/ec_projects.py +++ b/compendium_v2/routes/ec_projects.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('ec-projects', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ECProject +from compendium_v2.routes import common +routes = Blueprint('ec-projects', __name__) logger = logging.getLogger(__name__) EC_PROJECTS_RESPONSE_SCHEMA = { @@ -55,13 +48,12 @@ def ec_projects_view() -> Any: :return: """ - def _extract_project(entry: model.ECProject): + def _extract_project(entry: ECProject): return { 'nren': entry.nren.name, 'year': entry.year, 'project': entry.project } - with db.session_scope() as session: - result = [_extract_project(project) for project in session.query(model.ECProject)] + result = [_extract_project(project) for project in db.session.scalars(select(ECProject))] return jsonify(result) diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py index ed0e26c8..c02bf136 100644 --- a/compendium_v2/routes/funding.py +++ b/compendium_v2/routes/funding.py @@ -1,22 +1,15 @@ import logging -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db from compendium_v2.routes import common -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import FundingSource from typing import Any -routes = Blueprint('funding', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('funding', __name__) logger = logging.getLogger(__name__) FUNDING_RESPONSE_SCHEMA = { @@ -59,7 +52,7 @@ def funding_source_view() -> Any: :return: """ - def _extract_data(entry: model.FundingSource): + def _extract_data(entry: FundingSource): return { 'NREN': entry.nren.name, 'YEAR': entry.year, @@ -70,8 +63,8 @@ def funding_source_view() -> Any: 'OTHER': float(entry.other) } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.FundingSource)], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py index 61a43354..8e8ebc8d 100644 --- a/compendium_v2/routes/organization.py +++ b/compendium_v2/routes/organization.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('organization', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ParentOrganization, SubOrganization +from compendium_v2.routes import common +routes = Blueprint('organization', __name__) logger = logging.getLogger(__name__) ORGANIZATION_RESPONSE_SCHEMA = { @@ -71,15 +64,14 @@ def parent_organization_view() -> Any: :return: """ - def _extract_parent(entry: model.ParentOrganization): + def _extract_parent(entry: ParentOrganization): return { 'nren': entry.nren.name, 'year': entry.year, 'name': entry.organization } - with db.session_scope() as session: - result = [_extract_parent(org) for org in session.query(model.ParentOrganization)] + result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))] return jsonify(result) @@ -98,7 +90,7 @@ def sub_organization_view() -> Any: :return: """ - def _extract_sub(entry: model.SubOrganization): + def _extract_sub(entry: SubOrganization): return { 'nren': entry.nren.name, 'year': entry.year, @@ -106,6 +98,5 @@ def sub_organization_view() -> Any: 'role': entry.role } - with db.session_scope() as session: - result = [_extract_sub(org) for org in session.query(model.SubOrganization)] + result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))] return jsonify(result) diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py index 73e79c48..b8478266 100644 --- a/compendium_v2/routes/staff.py +++ b/compendium_v2/routes/staff.py @@ -1,22 +1,15 @@ import logging -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db +from compendium_v2.db import db +from compendium_v2.db.model import NREN, NrenStaff from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('staff', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('staff', __name__) logger = logging.getLogger(__name__) STAFF_RESPONSE_SCHEMA = { @@ -57,7 +50,7 @@ def staff_view() -> Any: :return: """ - def _extract_data(entry: model.NrenStaff): + def _extract_data(entry: NrenStaff): return { 'nren': entry.nren.name, 'year': entry.year, @@ -67,7 +60,6 @@ def staff_view() -> Any: 'non_technical_fte': float(entry.non_technical_fte) } - with db.session_scope() as session: - entries = [_extract_data(entry) for entry in session.query( - model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())] + entries = [_extract_data(entry) for entry in db.session.scalars( + select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))] return jsonify(entries) diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py index ce48aaa0..1550ddcb 100644 --- a/compendium_v2/survey_db/__init__.py +++ b/compendium_v2/survey_db/__init__.py @@ -31,11 +31,6 @@ def session_scope( session.close() -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') - - def init_db_model(dsn): global _SESSION_MAKER diff --git a/requirements.txt b/requirements.txt index 12ec2076..98804e02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ click~=8.1 jsonschema~=4.17 flask~=2.2 flask-cors~=3.0 +flask-sqlalchemy~=3.0 openpyxl~=3.1 psycopg2-binary~=2.9 SQLAlchemy~=2.0 diff --git a/setup.py b/setup.py index 213288e5..d051ddd7 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ setup( 'jsonschema~=4.17', 'flask~=2.2', 'flask-cors~=3.0', + 'flask-sqlalchemy~=3.0', 'openpyxl~=3.1', 'psycopg2-binary~=2.9', 'SQLAlchemy~=2.0', diff --git a/test/conftest.py b/test/conftest.py index 79186212..ba6b9a43 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,20 +1,15 @@ -import json +import csv import os -import tempfile -import random - import pytest -import compendium_v2 -from compendium_v2 import db -from compendium_v2.db import model -from compendium_v2 import survey_db -from compendium_v2.survey_db import model as survey_model +import random from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool -import csv +import compendium_v2 +from compendium_v2.db import db, model +from compendium_v2.survey_db import model as survey_model def _test_data_csv(filename): @@ -43,61 +38,29 @@ def mocked_survey_db(mocker): @pytest.fixture -def mocked_db(mocker): - # cf. https://stackoverflow.com/a/33057675 - engine = create_engine( - 'sqlite://', - connect_args={'check_same_thread': False}, - poolclass=StaticPool, - echo=False) - model.Base.metadata.create_all(engine) - mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine)) - mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None) - mocker.patch('compendium_v2.migrate_database', lambda config: None) - - -@pytest.fixture -def test_budget_data(): - with db.session_scope() as session: +def test_budget_data(app): + with app.app_context(): data = [row for row in _test_data_csv("BudgetTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] budget = row["budget"] year = row["year"] - session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year))) - - with survey_db.session_scope() as session: - data = _test_data_csv("BudgetTestData.csv") - nrens = set() - budgets_data = [] - for row in data: - nren = row["nren"] - budget = row["budget"] - year = row["year"] - country_code = row["nren"] - - nrens.add(nren) - - budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code)) - - for nren in nrens: - session.add(survey_model.Nrens(abbreviation=nren, country_code=nren)) - - session.add_all(budgets_data) + db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year))) + db.session.commit() @pytest.fixture -def test_funding_source_data(): - with db.session_scope() as session: +def test_funding_source_data(app): + with app.app_context(): data = [row for row in _test_data_csv("FundingSourceTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -108,7 +71,7 @@ def test_funding_source_data(): commercial = row["commercial"] other = row["other"] - session.add( + db.session.add( model.FundingSource( nren=nren, year=year, client_institutions=client, @@ -117,10 +80,11 @@ def test_funding_source_data(): commercial=commercial, other=other) ) + db.session.commit() @pytest.fixture -def test_staff_data(): +def test_staff_data(app): # generator of random test data for 5 years and 100 nrens def _generate_rows(): @@ -135,12 +99,12 @@ def test_staff_data(): "non_technical_fte": random.randint(0, 100) } - with db.session_scope() as session: + with app.app_context(): data = list(_generate_rows()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -150,7 +114,7 @@ def test_staff_data(): technical_fte = row["technical_fte"] non_technical_fte = row["non_technical_fte"] - session.add( + db.session.add( model.NrenStaff( nren=nren, year=year, @@ -160,30 +124,29 @@ def test_staff_data(): non_technical_fte=non_technical_fte ) ) + db.session.commit() @pytest.fixture -def data_config_filename(dummy_config): - with tempfile.NamedTemporaryFile() as f: - f.write(json.dumps(dummy_config).encode('utf-8')) - f.flush() - yield f.name +def app(dummy_config): + app = compendium_v2._create_app_with_db(dummy_config) + with app.app_context(): + db.create_all() + yield app @pytest.fixture -def client(data_config_filename, mocked_db, mocked_survey_db): - os.environ['SETTINGS_FILENAME'] = data_config_filename - with compendium_v2.create_app().test_client() as c: - yield c +def client(app): + return app.test_client() @pytest.fixture -def test_charging_structure_data(): - with db.session_scope() as session: +def test_charging_structure_data(app): + with app.app_context(): data = [row for row in _test_data_csv("ChargingStructureTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -192,15 +155,16 @@ def test_charging_structure_data(): if fee_type == "null": fee_type = None - session.add( + db.session.add( model.ChargingStructure( nren=nren, year=year, fee_type=fee_type) ) + db.session.commit() @pytest.fixture -def test_organization_data(): +def test_organization_data(app): def _generate_sub_org_data(): for nren in ["nren" + str(i) for i in range(1, 50)]: for year in range(2016, 2021): @@ -220,21 +184,21 @@ def test_organization_data(): 'name': 'org' + str(year) } - with db.session_scope() as session: + with app.app_context(): org_data = list(_generate_org_data()) sub_org_data = list(_generate_sub_org_data()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for org in org_data: nren = nren_dict[org["nren"]] year = org["year"] name = org["name"] - session.add(model.ParentOrganization(nren=nren, year=year, organization=name)) + db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name)) for sub_org in sub_org_data: nren = nren_dict[sub_org["nren"]] @@ -242,13 +206,13 @@ def test_organization_data(): name = sub_org["name"] role = sub_org["role"] - session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role)) + db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role)) - session.commit() + db.session.commit() @pytest.fixture -def test_ec_project_data(): +def test_ec_project_data(app): def _generate_ec_project_data(): for nren in ["nren" + str(i) for i in range(1, 50)]: for year in range(2016, 2021): @@ -264,19 +228,19 @@ def test_ec_project_data(): 'project': 'ec_project2', } - with db.session_scope() as session: + with app.app_context(): ec_project_data = list(_generate_ec_project_data()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in set(d['nren'] for d in ec_project_data)} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for ec_project in ec_project_data: nren = nren_dict[ec_project["nren"]] year = ec_project["year"] project = ec_project["project"] - session.add(model.ECProject(nren=nren, year=year, project=project)) + db.session.add(model.ECProject(nren=nren, year=year, project=project)) - session.commit() + db.session.commit() diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py index ec8802ed..433def2c 100644 --- a/test/test_survey_publisher_2022.py +++ b/test/test_survey_publisher_2022.py @@ -1,5 +1,4 @@ -from compendium_v2 import db -from compendium_v2.db import model +from compendium_v2.db import db, model from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \ StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion @@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority" ] -def test_publisher(client, mocker, dummy_config): +def test_publisher(app, mocker, dummy_config): global org_data def get_rows_as_tuples(*args, **kwargs): @@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config): mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data) mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data) - with db.session_scope() as session: - nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] - session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] + with app.app_context(): + db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.commit() - _cli(dummy_config) + _cli(dummy_config, app) - with db.session_scope() as session: - budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() + with app.app_context(): + budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() assert len(budgets) == 3 assert budgets[0].nren.name.lower() == 'nren1' assert budgets[0].budget == 100 - funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() + funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() assert len(funding_sources) == 3 assert funding_sources[0].nren.name.lower() == 'nren1' assert funding_sources[0].client_institutions == 10 @@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config): assert funding_sources[2].european_funding == 30 assert funding_sources[2].other == 30 - staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() + staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() assert len(staff_data) == 3 assert staff_data[0].nren.name.lower() == 'nren1' @@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config): assert staff_data[2].permanent_fte == 30 assert staff_data[2].subcontracted_fte == 0 - _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() + _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() assert len(_org_data) == 2 assert _org_data[0].nren.name.lower() == 'nren1' @@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config): assert _org_data[1].nren.name.lower() == 'nren3' assert _org_data[1].organization == 'Org3' - charging_structures = session.query(model.ChargingStructure).order_by( + charging_structures = db.session.query(model.ChargingStructure).order_by( model.ChargingStructure.nren_id.asc()).all() assert len(charging_structures) == 3 assert charging_structures[0].nren.name.lower() == 'nren1' @@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config): assert charging_structures[2].nren.name.lower() == 'nren3' assert charging_structures[2].fee_type == model.FeeType.other - _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() + _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() assert len(_ec_data) == 3 assert _ec_data[0].nren.name.lower() == 'nren2' diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py index c7ebc2d9..a6dbbaa1 100644 --- a/test/test_survey_publisher_v1.py +++ b/test/test_survey_publisher_v1.py @@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx") -def test_publisher(client, mocker, dummy_config): +def test_publisher(mocked_survey_db, app, mocker, dummy_config): mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE) - with db.session_scope() as session: + with app.app_context(): nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] - session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.commit() - _cli(dummy_config) + _cli(dummy_config, app) - with db.session_scope() as session: - budget_count = session.query(model.BudgetEntry.year).count() + with app.app_context(): + budget_count = db.session.query(model.BudgetEntry.year).count() assert budget_count - funding_source_count = session.query(model.FundingSource.year).count() + funding_source_count = db.session.query(model.FundingSource.year).count() assert funding_source_count - charging_structure_count = session.query(model.ChargingStructure.year).count() + charging_structure_count = db.session.query(model.ChargingStructure.year).count() assert charging_structure_count - staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() + staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() # data should only be saved for the NRENs we have saved in the database staff_data_nrens = set([staff.nren.name for staff in staff_data]) @@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config): assert kifu_data[5].technical_fte == 133 assert kifu_data[5].non_technical_fte == 45 - ecproject_data = session.query(model.ECProject).all() + ecproject_data = db.session.query(model.ECProject).all() # test a couple of random entries surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017] assert len(surf2017) == 1 @@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config): assert len(kifu2019) == 4 assert kifu2019[3].project == 'SuperHeroes for Science' - parent_data = session.query(model.ParentOrganization).all() + parent_data = db.session.query(model.ParentOrganization).all() # test a random entry asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021] assert len(asnet2021) == 1 -- GitLab From 14411075c87043bcd91d2ced0ba1efdd5f150ade Mon Sep 17 00:00:00 2001 From: Remco Tukker <remco.tukker@geant.org> Date: Thu, 4 May 2023 15:50:41 +0200 Subject: [PATCH 2/4] use flask-migrate for the migrations --- README.md | 6 +- compendium_v2/__init__.py | 13 ++--- compendium_v2/alembic.ini | 10 ---- compendium_v2/migrations/README | 1 + compendium_v2/migrations/__init__.py | 0 compendium_v2/migrations/alembic.ini | 50 ++++++++++++++++ compendium_v2/migrations/env.py | 63 ++++++++++++++++----- compendium_v2/migrations/migration_utils.py | 41 -------------- requirements.txt | 1 + setup.py | 1 + 10 files changed, 111 insertions(+), 75 deletions(-) delete mode 100644 compendium_v2/alembic.ini create mode 100644 compendium_v2/migrations/README delete mode 100644 compendium_v2/migrations/__init__.py create mode 100644 compendium_v2/migrations/alembic.ini delete mode 100644 compendium_v2/migrations/migration_utils.py diff --git a/README.md b/README.md index 81b1b0b6..dd9a7dce 100644 --- a/README.md +++ b/README.md @@ -63,11 +63,13 @@ survey-publisher-2022 ## Creating a db migration after editing the sqlalchemy models ```bash -cd compendium_v2 -alembic revision --autogenerate -m "description" +flask db migrate -m "description" ``` Then go to the created migration file to make any necessary additions, for example to migrate data. Also see https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect +Flask-migrate sets `compare_type=True` by default. Note that starting the application applies all upgrades. +This also happens when running `flask db` commands such as `flask db downgrade`, +so if you want to downgrade 2 or more versions you need to do so in one command, eg by specifying the revision number. diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py index 7a69bf9e..4717a486 100644 --- a/compendium_v2/__init__.py +++ b/compendium_v2/__init__.py @@ -6,15 +6,11 @@ import os from flask import Flask from flask_cors import CORS # for debugging +# the currently available stubs for flask_migrate are old (they depend on sqlalchemy 1.4 types) +from flask_migrate import Migrate, upgrade # type: ignore from compendium_v2 import config, environment from compendium_v2.db import db -from compendium_v2.migrations import migration_utils - - -def migrate_database(config: dict) -> None: - dsn = config['SQLALCHEMY_DATABASE_URI'] - migration_utils.upgrade(dsn) def _create_app(app_config) -> Flask: @@ -56,11 +52,14 @@ def create_app() -> Flask: app = _create_app_with_db(app_config) + Migrate(app, db, directory=os.path.join(os.path.dirname(__file__), 'migrations')) + logging.info('Flask app initialized') environment.setup_logging() # run migrations on startup - migrate_database(app_config) + with app.app_context(): + upgrade() return app diff --git a/compendium_v2/alembic.ini b/compendium_v2/alembic.ini deleted file mode 100644 index 2145863b..00000000 --- a/compendium_v2/alembic.ini +++ /dev/null @@ -1,10 +0,0 @@ -# A generic, single database configuration. - -# only needed for generating new revision scripts -[alembic] -# make sure the right line is un / commented depending on which schema you want -# a migration for -script_location = migrations -# script_location = cachedb_migrations -# change this to run migrations from the command line -sqlalchemy.url = postgresql://compendium:compendium321@localhost:65000/compendium diff --git a/compendium_v2/migrations/README b/compendium_v2/migrations/README new file mode 100644 index 00000000..0e048441 --- /dev/null +++ b/compendium_v2/migrations/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/compendium_v2/migrations/__init__.py b/compendium_v2/migrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/compendium_v2/migrations/alembic.ini b/compendium_v2/migrations/alembic.ini new file mode 100644 index 00000000..ec9d45c2 --- /dev/null +++ b/compendium_v2/migrations/alembic.ini @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py index 5ea9c8d3..e2408681 100644 --- a/compendium_v2/migrations/env.py +++ b/compendium_v2/migrations/env.py @@ -1,10 +1,9 @@ import logging +from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool +from flask import current_app from alembic import context -from compendium_v2.db import metadata_obj # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -12,13 +11,34 @@ config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. -logging.basicConfig(level=logging.INFO) +if config.config_file_name is not None: + fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine() + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = metadata_obj +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db # other values from the config, defined by the needs of env.py, # can be acquired: @@ -26,6 +46,12 @@ target_metadata = metadata_obj # ... etc. +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -40,10 +66,7 @@ def run_migrations_offline(): """ url = config.get_main_option("sqlalchemy.url") context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, + url=url, target_metadata=get_metadata(), literal_binds=True ) with context.begin_transaction(): @@ -57,15 +80,25 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = get_engine() with connectable.connect() as connection: context.configure( - connection=connection, target_metadata=target_metadata + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args ) with context.begin_transaction(): diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py deleted file mode 100644 index f25f03c9..00000000 --- a/compendium_v2/migrations/migration_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -import os - -from alembic.config import Config -from alembic import command - -logger = logging.getLogger(__name__) -DEFAULT_MIGRATIONS_DIRECTORY = os.path.dirname(__file__) - - -def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY): - """ - migrate db to head version - - cf. https://stackoverflow.com/a/43530495, - https://stackoverflow.com/a/54402853 - - :param dsn: dsn string, passed to alembic - :param migrations_directory: full path to migrations directory - (default is this directory) - :return: - """ - alembic_config = Config() - alembic_config.set_main_option('script_location', migrations_directory) - alembic_config.set_main_option('sqlalchemy.url', dsn) - command.upgrade(alembic_config, 'head') - - -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - upgrade(postgresql_dsn( - db_username='compendium', - db_password='compendium321', - db_hostname='localhost', - db_name='compendium', - port=65000)) diff --git a/requirements.txt b/requirements.txt index 98804e02..f49a7960 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ click~=8.1 jsonschema~=4.17 flask~=2.2 flask-cors~=3.0 +flask-migrate~=4.0 flask-sqlalchemy~=3.0 openpyxl~=3.1 psycopg2-binary~=2.9 diff --git a/setup.py b/setup.py index 48a37f1b..52acb8cb 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ setup( 'jsonschema~=4.17', 'flask~=2.2', 'flask-cors~=3.0', + 'flask-migrate~=4.0', 'flask-sqlalchemy~=3.0', 'openpyxl~=3.1', 'psycopg2-binary~=2.9', -- GitLab From 6b736da7efa507be8cbb92d42834b61a83093e3c Mon Sep 17 00:00:00 2001 From: Remco Tukker <remco.tukker@geant.org> Date: Thu, 4 May 2023 17:04:23 +0200 Subject: [PATCH 3/4] use sqlalchemy2 syntax everywhere --- .../publishers/survey_publisher_2022.py | 14 +++++----- .../publishers/survey_publisher_v1.py | 4 ++- test/test_survey_publisher_2022.py | 27 ++++++++++++++----- test/test_survey_publisher_v1.py | 14 +++++----- 4 files changed, 38 insertions(+), 21 deletions(-) diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py index f5a5cf76..5e449c3e 100644 --- a/compendium_v2/publishers/survey_publisher_2022.py +++ b/compendium_v2/publishers/survey_publisher_2022.py @@ -13,7 +13,7 @@ import math import json import html -from sqlalchemy import text +from sqlalchemy import text, delete from collections import defaultdict import compendium_v2 @@ -228,10 +228,10 @@ def transfer_staff_data(nren_dict): for nren_name, nren_info in data.items(): if sum([nren_info[question] for question in StaffQuestion]) == 0: logger.info(f'{nren_name} has no staff data. Deleting if exists.') - db.session.query(model.NrenStaff).filter( + db.session.execute(delete(model.NrenStaff).where( model.NrenStaff.nren_id == nren_dict[nren_name].id, - model.NrenStaff.year == 2022, - ).delete() + model.NrenStaff.year == 2022 + )) continue employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] @@ -364,9 +364,9 @@ def transfer_charging_structure(nren_dict): def transfer_ec_projects(nren_dict): # delete all existing EC projects, in case something changed - db.session.query(model.ECProject).filter( - model.ECProject.year == 2022, - ).delete() + db.session.execute( + delete(model.ECProject).where(model.ECProject.year == 2022) + ) rows = query_question(ECQuestion.EC_PROJECT) for row in rows: diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py index dce1ba58..c55926bb 100644 --- a/compendium_v2/publishers/survey_publisher_v1.py +++ b/compendium_v2/publishers/survey_publisher_v1.py @@ -11,6 +11,8 @@ import logging import math import click +from sqlalchemy import select + import compendium_v2 from compendium_v2.environment import setup_logging from compendium_v2 import survey_db @@ -29,7 +31,7 @@ def db_budget_migration(nren_dict): with survey_db.session_scope() as survey_session: # move data from Survey DB budget table - data = survey_session.query(survey_model.Nrens) + data = survey_session.scalars(select(survey_model.Nrens)) for nren in data: for budget in nren.budgets: abbrev = nren.abbreviation.upper() diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py index 433def2c..3216caf9 100644 --- a/test/test_survey_publisher_2022.py +++ b/test/test_survey_publisher_2022.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from compendium_v2.db import db, model from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \ StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion @@ -193,12 +195,16 @@ def test_publisher(app, mocker, dummy_config): _cli(dummy_config, app) with app.app_context(): - budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() + budgets = db.session.scalars( + select(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()) + ).all() assert len(budgets) == 3 assert budgets[0].nren.name.lower() == 'nren1' assert budgets[0].budget == 100 - funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() + funding_sources = db.session.scalars( + select(model.FundingSource).order_by(model.FundingSource.nren_id.asc()) + ).all() assert len(funding_sources) == 3 assert funding_sources[0].nren.name.lower() == 'nren1' assert funding_sources[0].client_institutions == 10 @@ -215,7 +221,9 @@ def test_publisher(app, mocker, dummy_config): assert funding_sources[2].european_funding == 30 assert funding_sources[2].other == 30 - staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() + staff_data = db.session.scalars( + select(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()) + ).all() assert len(staff_data) == 3 assert staff_data[0].nren.name.lower() == 'nren1' @@ -236,7 +244,9 @@ def test_publisher(app, mocker, dummy_config): assert staff_data[2].permanent_fte == 30 assert staff_data[2].subcontracted_fte == 0 - _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() + _org_data = db.session.scalars( + select(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()) + ).all() assert len(_org_data) == 2 assert _org_data[0].nren.name.lower() == 'nren1' @@ -245,8 +255,9 @@ def test_publisher(app, mocker, dummy_config): assert _org_data[1].nren.name.lower() == 'nren3' assert _org_data[1].organization == 'Org3' - charging_structures = db.session.query(model.ChargingStructure).order_by( - model.ChargingStructure.nren_id.asc()).all() + charging_structures = db.session.scalars( + select(model.ChargingStructure).order_by(model.ChargingStructure.nren_id.asc()) + ).all() assert len(charging_structures) == 3 assert charging_structures[0].nren.name.lower() == 'nren1' assert charging_structures[0].fee_type == model.FeeType.no_charge @@ -255,7 +266,9 @@ def test_publisher(app, mocker, dummy_config): assert charging_structures[2].nren.name.lower() == 'nren3' assert charging_structures[2].fee_type == model.FeeType.other - _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() + _ec_data = db.session.scalars( + select(model.ECProject).order_by(model.ECProject.nren_id.asc()) + ).all() assert len(_ec_data) == 3 assert _ec_data[0].nren.name.lower() == 'nren2' diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py index a6dbbaa1..88def54b 100644 --- a/test/test_survey_publisher_v1.py +++ b/test/test_survey_publisher_v1.py @@ -1,5 +1,7 @@ import os +from sqlalchemy import select, func + from compendium_v2 import db from compendium_v2.db import model from compendium_v2.publishers.survey_publisher_v1 import _cli @@ -18,13 +20,13 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config): _cli(dummy_config, app) with app.app_context(): - budget_count = db.session.query(model.BudgetEntry.year).count() + budget_count = db.session.scalar(select(func.count(model.BudgetEntry.year))) assert budget_count - funding_source_count = db.session.query(model.FundingSource.year).count() + funding_source_count = db.session.scalar(select(func.count(model.FundingSource.year))) assert funding_source_count - charging_structure_count = db.session.query(model.ChargingStructure.year).count() + charging_structure_count = db.session.scalar(select(func.count(model.ChargingStructure.year))) assert charging_structure_count - staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() + staff_data = db.session.scalars(select(model.NrenStaff).order_by(model.NrenStaff.year.asc())).all() # data should only be saved for the NRENs we have saved in the database staff_data_nrens = set([staff.nren.name for staff in staff_data]) @@ -70,7 +72,7 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config): assert kifu_data[5].technical_fte == 133 assert kifu_data[5].non_technical_fte == 45 - ecproject_data = db.session.query(model.ECProject).all() + ecproject_data = db.session.scalars(select(model.ECProject)).all() # test a couple of random entries surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017] assert len(surf2017) == 1 @@ -84,7 +86,7 @@ def test_publisher(mocked_survey_db, app, mocker, dummy_config): assert len(kifu2019) == 4 assert kifu2019[3].project == 'SuperHeroes for Science' - parent_data = db.session.query(model.ParentOrganization).all() + parent_data = db.session.scalars(select(model.ParentOrganization)).all() # test a random entry asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021] assert len(asnet2021) == 1 -- GitLab From 0f775788d87de4966a37ce50fc3888bbdd3ba8ac Mon Sep 17 00:00:00 2001 From: Remco Tukker <remco.tukker@geant.org> Date: Thu, 4 May 2023 21:20:49 +0200 Subject: [PATCH 4/4] also use flask-sqlalchemy for the survey db --- compendium_v2/__init__.py | 5 ++ compendium_v2/publishers/helpers.py | 6 -- .../publishers/survey_publisher_2022.py | 15 ++-- .../publishers/survey_publisher_v1.py | 69 +++++++++---------- compendium_v2/survey_db/__init__.py | 40 ----------- compendium_v2/survey_db/model.py | 17 +++-- test/conftest.py | 25 +++---- test/test_survey_publisher_2022.py | 8 +-- test/test_survey_publisher_v1.py | 8 +-- 9 files changed, 75 insertions(+), 118 deletions(-) diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py index 4717a486..de1b313f 100644 --- a/compendium_v2/__init__.py +++ b/compendium_v2/__init__.py @@ -33,6 +33,11 @@ def _create_app_with_db(app_config) -> Flask: # used by the tests and the publishers app = _create_app(app_config) app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI'] + + if 'SQLALCHEMY_BINDS' in app.config['CONFIG_PARAMS']: + # for the publishers + app.config['SQLALCHEMY_BINDS'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_BINDS'] + db.init_app(app) return app diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py index 43d1bf02..fb9fb40e 100644 --- a/compendium_v2/publishers/helpers.py +++ b/compendium_v2/publishers/helpers.py @@ -1,14 +1,8 @@ from sqlalchemy import select -from compendium_v2 import survey_db from compendium_v2.db import db, model -def init_db(config): - dsn_survey = config['SURVEY_DATABASE_URI'] - survey_db.init_db_model(dsn_survey) - - def get_uppercase_nren_dict(): """ :return: a dictionary of all known NRENs db entities keyed on the diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py index 5e449c3e..c213d265 100644 --- a/compendium_v2/publishers/survey_publisher_2022.py +++ b/compendium_v2/publishers/survey_publisher_2022.py @@ -20,7 +20,7 @@ import compendium_v2 from compendium_v2.db.model import FeeType from compendium_v2.environment import setup_logging from compendium_v2.config import load -from compendium_v2 import survey_db +from compendium_v2.survey_db import model as survey_model from compendium_v2.db import db, model from compendium_v2.publishers import helpers @@ -117,21 +117,18 @@ class ChargingStructure(enum.Enum): def query_budget(): - with survey_db.session_scope() as survey: - return survey.execute(text(BUDGET_QUERY)) + return db.session.execute(text(BUDGET_QUERY), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) def query_funding_sources(): for source in FundingSource: query = QUESTION_TEMPLATE_QUERY.format(source.value) - with survey_db.session_scope() as survey: - yield source, survey.execute(text(query)) + yield source, db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) def query_question(question: enum.Enum): query = QUESTION_TEMPLATE_QUERY.format(question.value) - with survey_db.session_scope() as survey: - return survey.execute(text(query)) + return db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) def transfer_budget(nren_dict): @@ -404,7 +401,6 @@ def transfer_ec_projects(nren_dict): def _cli(config, app): - helpers.init_db(config) with app.app_context(): nren_dict = helpers.get_uppercase_nren_dict() transfer_budget(nren_dict) @@ -420,6 +416,9 @@ def _cli(config, app): @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) + + app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']} + app = compendium_v2._create_app_with_db(app_config) _cli(app_config, app) diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py index c55926bb..cbb7a05e 100644 --- a/compendium_v2/publishers/survey_publisher_v1.py +++ b/compendium_v2/publishers/survey_publisher_v1.py @@ -15,7 +15,6 @@ from sqlalchemy import select import compendium_v2 from compendium_v2.environment import setup_logging -from compendium_v2 import survey_db from compendium_v2.background_task import parse_excel_data from compendium_v2.config import load from compendium_v2.db import db, model @@ -28,49 +27,47 @@ logger = logging.getLogger('survey-publisher-v1') def db_budget_migration(nren_dict): - with survey_db.session_scope() as survey_session: - - # move data from Survey DB budget table - data = survey_session.scalars(select(survey_model.Nrens)) - for nren in data: - for budget in nren.budgets: - abbrev = nren.abbreviation.upper() - year = budget.year - - if float(budget.budget) > 200: - logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})') - - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - budget_entry = model.BudgetEntry( - nren=nren_dict[abbrev], - nren_id=nren_dict[abbrev].id, - budget=float(budget.budget), - year=year - ) - db.session.merge(budget_entry) - - # Import the data from excel sheet to database - exceldata = parse_excel_data.fetch_budget_excel_data() - - for abbrev, budget, year in exceldata: + # move data from Survey DB budget table + data = db.session.scalars(select(survey_model.Nrens)) + for nren in data: + for budget in nren.budgets: + abbrev = nren.abbreviation.upper() + year = budget.year + + if float(budget.budget) > 200: + logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})') + if abbrev not in nren_dict: logger.warning(f'{abbrev} unknown. Skipping.') continue - if budget > 200: - logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') - budget_entry = model.BudgetEntry( nren=nren_dict[abbrev], nren_id=nren_dict[abbrev].id, - budget=budget, + budget=float(budget.budget), year=year ) db.session.merge(budget_entry) - db.session.commit() + + # Import the data from excel sheet to database + exceldata = parse_excel_data.fetch_budget_excel_data() + + for abbrev, budget, year in exceldata: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + if budget > 200: + logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') + + budget_entry = model.BudgetEntry( + nren=nren_dict[abbrev], + nren_id=nren_dict[abbrev].id, + budget=budget, + year=year + ) + db.session.merge(budget_entry) + db.session.commit() def db_funding_migration(nren_dict): @@ -203,7 +200,6 @@ def db_organizations_migration(nren_dict): def _cli(config, app): - helpers.init_db(config) with app.app_context(): nren_dict = helpers.get_uppercase_nren_dict() db_budget_migration(nren_dict) @@ -218,6 +214,9 @@ def _cli(config, app): @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) + + app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']} + app = compendium_v2._create_app_with_db(app_config) print("survey-publisher-v1 starting") _cli(app_config, app) diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py index 1550ddcb..e69de29b 100644 --- a/compendium_v2/survey_db/__init__.py +++ b/compendium_v2/survey_db/__init__.py @@ -1,40 +0,0 @@ -import contextlib -import logging -from typing import Optional, Union, Callable, Iterator - -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker, Session - -logger = logging.getLogger(__name__) -_SESSION_MAKER: Union[None, sessionmaker] = None - - -@contextlib.contextmanager -def session_scope( - callback_before_close: Optional[Callable] = None) -> Iterator[Session]: - # best practice is to keep session scope separate from data processing - # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html - - assert _SESSION_MAKER - session = _SESSION_MAKER() - try: - yield session - session.commit() - if callback_before_close: - callback_before_close() - except SQLAlchemyError: - logger.error('caught sql layer exception, rolling back') - session.rollback() - raise # re-raise, will be handled by main consumer - finally: - session.close() - - -def init_db_model(dsn): - global _SESSION_MAKER - - # cf. https://docs.sqlalchemy.org/en - # /latest/orm/extensions/automap.html - engine = create_engine(dsn, pool_size=10) - _SESSION_MAKER = sessionmaker(bind=engine) diff --git a/compendium_v2/survey_db/model.py b/compendium_v2/survey_db/model.py index a908aa20..605cf2d5 100644 --- a/compendium_v2/survey_db/model.py +++ b/compendium_v2/survey_db/model.py @@ -1,17 +1,23 @@ import logging from typing import List, Optional -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey +from compendium_v2.db import db + + logger = logging.getLogger(__name__) +SURVEY_DB_BIND = 'survey' -class Base(DeclarativeBase): - pass +# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet. +# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140 +# mypy: disable-error-code="name-defined" -class Budgets(Base): +class Budgets(db.Model): + __bind_key__ = SURVEY_DB_BIND __tablename__ = 'budgets' id: Mapped[int] = mapped_column(primary_key=True) budget: Mapped[Optional[str]] @@ -20,7 +26,8 @@ class Budgets(Base): nren: Mapped[Optional['Nrens']] = relationship(back_populates='budgets') -class Nrens(Base): +class Nrens(db.Model): + __bind_key__ = SURVEY_DB_BIND __tablename__ = 'nrens' id: Mapped[int] = mapped_column(primary_key=True) abbreviation: Mapped[Optional[str]] diff --git a/test/conftest.py b/test/conftest.py index ba6b9a43..be9c16e4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,10 +3,6 @@ import os import pytest import random -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import StaticPool - import compendium_v2 from compendium_v2.db import db, model from compendium_v2.survey_db import model as survey_model @@ -25,18 +21,6 @@ def dummy_config(): } -@pytest.fixture -def mocked_survey_db(mocker): - engine = create_engine( - 'sqlite://', - connect_args={'check_same_thread': False}, - poolclass=StaticPool, - echo=False) - survey_model.Base.metadata.create_all(engine) - mocker.patch('compendium_v2.survey_db._SESSION_MAKER', sessionmaker(bind=engine)) - mocker.patch('compendium_v2.survey_db.init_db_model', lambda dsn: None) - - @pytest.fixture def test_budget_data(app): with app.app_context(): @@ -129,6 +113,15 @@ def test_staff_data(app): @pytest.fixture def app(dummy_config): + app = compendium_v2._create_app_with_db(dummy_config) + with app.app_context(): + db.create_all(bind_key=None) + yield app + + +@pytest.fixture +def app_with_survey_db(dummy_config): + dummy_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: dummy_config['SURVEY_DATABASE_URI']} app = compendium_v2._create_app_with_db(dummy_config) with app.app_context(): db.create_all() diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py index 3216caf9..72267691 100644 --- a/test/test_survey_publisher_2022.py +++ b/test/test_survey_publisher_2022.py @@ -110,7 +110,7 @@ org_dataKTU,"NOC, administrative authority" ] -def test_publisher(app, mocker, dummy_config): +def test_publisher(app_with_survey_db, mocker, dummy_config): global org_data def get_rows_as_tuples(*args, **kwargs): @@ -188,13 +188,13 @@ def test_publisher(app, mocker, dummy_config): mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data) nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] - with app.app_context(): + with app_with_survey_db.app_context(): db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) db.session.commit() - _cli(dummy_config, app) + _cli(dummy_config, app_with_survey_db) - with app.app_context(): + with app_with_survey_db.app_context(): budgets = db.session.scalars( select(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()) ).all() diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py index 88def54b..496cf71b 100644 --- a/test/test_survey_publisher_v1.py +++ b/test/test_survey_publisher_v1.py @@ -9,17 +9,17 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx") -def test_publisher(mocked_survey_db, app, mocker, dummy_config): +def test_publisher(app_with_survey_db, mocker, dummy_config): mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE) - with app.app_context(): + with app_with_survey_db.app_context(): nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) db.session.commit() - _cli(dummy_config, app) + _cli(dummy_config, app_with_survey_db) - with app.app_context(): + with app_with_survey_db.app_context(): budget_count = db.session.scalar(select(func.count(model.BudgetEntry.year))) assert budget_count funding_source_count = db.session.scalar(select(func.count(model.FundingSource.year))) -- GitLab