From 6e62575f75375af493a76a394264a293cb80c7b8 Mon Sep 17 00:00:00 2001 From: Remco Tukker <remco.tukker@geant.org> Date: Thu, 4 May 2023 10:00:14 +0200 Subject: [PATCH] use flask-sqlalchemy for the main db --- compendium_v2/__init__.py | 12 +- compendium_v2/db/__init__.py | 50 +- compendium_v2/db/model.py | 48 +- compendium_v2/migrations/env.py | 2 +- compendium_v2/migrations/migration_utils.py | 8 +- compendium_v2/publishers/helpers.py | 13 +- .../publishers/survey_publisher_2022.py | 485 +++++++++--------- .../publishers/survey_publisher_v1.py | 272 +++++----- compendium_v2/routes/budget.py | 31 +- compendium_v2/routes/charging.py | 29 +- compendium_v2/routes/ec_projects.py | 24 +- compendium_v2/routes/funding.py | 27 +- compendium_v2/routes/organization.py | 29 +- compendium_v2/routes/staff.py | 24 +- compendium_v2/survey_db/__init__.py | 5 - requirements.txt | 1 + setup.py | 1 + test/conftest.py | 122 ++--- test/test_survey_publisher_2022.py | 28 +- test/test_survey_publisher_v1.py | 23 +- 20 files changed, 545 insertions(+), 689 deletions(-) diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py index 622565f2..7a69bf9e 100644 --- a/compendium_v2/__init__.py +++ b/compendium_v2/__init__.py @@ -8,7 +8,7 @@ from flask import Flask from flask_cors import CORS # for debugging from compendium_v2 import config, environment - +from compendium_v2.db import db from compendium_v2.migrations import migration_utils @@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask: return app +def _create_app_with_db(app_config) -> Flask: + # used by the tests and the publishers + app = _create_app(app_config) + app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI'] + db.init_app(app) + return app + + def create_app() -> Flask: """ overrides default settings with those found @@ -46,7 +54,7 @@ def create_app() -> Flask: with open(os.environ['SETTINGS_FILENAME']) as f: app_config = config.load(f) - app = _create_app(app_config) + app = _create_app_with_db(app_config) logging.info('Flask app initialized') diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py index ce48aaa0..13719e0a 100644 --- a/compendium_v2/db/__init__.py +++ b/compendium_v2/db/__init__.py @@ -1,45 +1,17 @@ -import contextlib import logging -from typing import Optional, Union, Callable, Iterator -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker, Session +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData -logger = logging.getLogger(__name__) -_SESSION_MAKER: Union[None, sessionmaker] = None - - -@contextlib.contextmanager -def session_scope( - callback_before_close: Optional[Callable] = None) -> Iterator[Session]: - # best practice is to keep session scope separate from data processing - # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html - - assert _SESSION_MAKER - session = _SESSION_MAKER() - try: - yield session - session.commit() - if callback_before_close: - callback_before_close() - except SQLAlchemyError: - logger.error('caught sql layer exception, rolling back') - session.rollback() - raise # re-raise, will be handled by main consumer - finally: - session.close() - - -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') +logger = logging.getLogger(__name__) -def init_db_model(dsn): - global _SESSION_MAKER +metadata_obj = MetaData(naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +}) - # cf. https://docs.sqlalchemy.org/en - # /latest/orm/extensions/automap.html - engine = create_engine(dsn, pool_size=10) - _SESSION_MAKER = sessionmaker(bind=engine) +db = SQLAlchemy(metadata=metadata_obj) diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py index 67672743..2199f759 100644 --- a/compendium_v2/db/model.py +++ b/compendium_v2/db/model.py @@ -4,42 +4,34 @@ from enum import Enum from typing import Optional from typing_extensions import Annotated -from sqlalchemy import MetaData, String -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey +from compendium_v2.db import db -logger = logging.getLogger(__name__) - -convention = { - "ix": "ix_%(column_0_label)s", - "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_%(constraint_name)s", - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s", -} -metadata_obj = MetaData(naming_convention=convention) +logger = logging.getLogger(__name__) str128 = Annotated[str, 128] +str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)] +str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)] int_pk = Annotated[int, mapped_column(primary_key=True)] int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)] -class Base(DeclarativeBase): - metadata = metadata_obj - type_annotation_map = { - str128: String(128), - } +# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet. +# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140 +# mypy: disable-error-code="name-defined" -class NREN(Base): +class NREN(db.Model): __tablename__ = 'nren' id: Mapped[int_pk] name: Mapped[str128] -class BudgetEntry(Base): +class BudgetEntry(db.Model): __tablename__ = 'budgets' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -47,7 +39,7 @@ class BudgetEntry(Base): budget: Mapped[Decimal] -class FundingSource(Base): +class FundingSource(db.Model): __tablename__ = 'funding_source' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -67,7 +59,7 @@ class FeeType(Enum): other = "other" -class ChargingStructure(Base): +class ChargingStructure(db.Model): __tablename__ = 'charging_structure' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -75,7 +67,7 @@ class ChargingStructure(Base): fee_type: Mapped[Optional[FeeType]] -class NrenStaff(Base): +class NrenStaff(db.Model): __tablename__ = 'nren_staff' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -86,7 +78,7 @@ class NrenStaff(Base): non_technical_fte: Mapped[Decimal] -class ParentOrganization(Base): +class ParentOrganization(db.Model): __tablename__ = 'parent_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -94,18 +86,18 @@ class ParentOrganization(Base): organization: Mapped[str128] -class SubOrganization(Base): +class SubOrganization(db.Model): __tablename__ = 'sub_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - organization: Mapped[str128] = mapped_column(primary_key=True) + organization: Mapped[str128_pk] role: Mapped[str128] -class ECProject(Base): +class ECProject(db.Model): __tablename__ = 'ec_project' nren_id: Mapped[int_pk_fkNREN] - nren: Mapped[NREN] = relationship(NREN, lazy='joined') + nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - project: Mapped[str] = mapped_column(String(256), primary_key=True) + project: Mapped[str256_pk] diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py index 0307be33..5ea9c8d3 100644 --- a/compendium_v2/migrations/env.py +++ b/compendium_v2/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool from alembic import context -from compendium_v2.db.model import metadata_obj +from compendium_v2.db import metadata_obj # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py index 3b29b540..f25f03c9 100644 --- a/compendium_v2/migrations/migration_utils.py +++ b/compendium_v2/migrations/migration_utils.py @@ -1,7 +1,6 @@ import logging import os -from compendium_v2 import db from alembic.config import Config from alembic import command @@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY): command.upgrade(alembic_config, 'head') +def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): + return (f'postgresql://{db_username}:{db_password}' + f'@{db_hostname}:{port}/{db_name}') + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - upgrade(db.postgresql_dsn( + upgrade(postgresql_dsn( db_username='compendium', db_password='compendium321', db_hostname='localhost', diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py index d95ad4e1..43d1bf02 100644 --- a/compendium_v2/publishers/helpers.py +++ b/compendium_v2/publishers/helpers.py @@ -1,21 +1,20 @@ -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from sqlalchemy import select + +from compendium_v2 import survey_db +from compendium_v2.db import db, model def init_db(config): - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) dsn_survey = config['SURVEY_DATABASE_URI'] survey_db.init_db_model(dsn_survey) -def get_uppercase_nren_dict(session): +def get_uppercase_nren_dict(): """ - :param session: db session that is used to query the known NRENs :return: a dictionary of all known NRENs db entities keyed on the uppercased name """ - current_nrens = session.query(model.NREN).all() + current_nrens = db.session.scalars(select(model.NREN)) nren_dict = {nren.name.upper(): nren for nren in current_nrens} # add aliases that are used in the source data: nren_dict['ASNET'] = nren_dict['ASNET-AM'] diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py index 710c899d..ea88ee77 100644 --- a/compendium_v2/publishers/survey_publisher_2022.py +++ b/compendium_v2/publishers/survey_publisher_2022.py @@ -16,11 +16,12 @@ import html from sqlalchemy import text from collections import defaultdict +import compendium_v2 from compendium_v2.db.model import FeeType from compendium_v2.environment import setup_logging from compendium_v2.config import load -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from compendium_v2 import survey_db +from compendium_v2.db import db, model from compendium_v2.publishers import helpers setup_logging() @@ -133,163 +134,151 @@ def query_question(question: enum.Enum): return survey.execute(text(query)) -def transfer_budget(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_budget() - for row in rows: +def transfer_budget(nren_dict): + rows = query_budget() + for row in rows: + nren_name = row[0].upper() + _budget = row[1] + try: + budget = float(_budget.replace('"', '').replace(',', '')) + except ValueError: + logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') + continue + + if budget > 200: + logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + budget_entry = model.BudgetEntry( + nren=nren_dict[nren_name], + budget=budget, + year=2022, + ) + db.session.merge(budget_entry) + db.session.commit() + + +def transfer_funding_sources(nren_dict): + sourcedata = {} + for source, data in query_funding_sources(): + for row in data: nren_name = row[0].upper() - _budget = row[1] + _value = row[1] try: - budget = float(_budget.replace('"', '').replace(',', '')) + value = float(_value.replace('"', '').replace(',', '')) except ValueError: - logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') - continue - - if budget > 200: - logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue + name = source.name + logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') + value = 0 - budget_entry = model.BudgetEntry( - nren=nren_dict[nren_name], - budget=budget, - year=2022, + nren_info = sourcedata.setdefault( + nren_name, + {source_type: 0 for source_type in FundingSource} ) - session.merge(budget_entry) - session.commit() - - -def transfer_funding_sources(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - sourcedata = {} - for source, data in query_funding_sources(): - for row in data: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - name = source.name - logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') - value = 0 - - nren_info = sourcedata.setdefault( - nren_name, - {source_type: 0 for source_type in FundingSource} - ) - nren_info[source] = value - - for nren_name, nren_info in sourcedata.items(): - total = sum(nren_info.values()) - - if not math.isclose(total, 100, abs_tol=0.01): - logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + nren_info[source] = value + + for nren_name, nren_info in sourcedata.items(): + total = sum(nren_info.values()) + + if not math.isclose(total, 100, abs_tol=0.01): + logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + funding_source = model.FundingSource( + nren=nren_dict[nren_name], + year=2022, + client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], + european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], + gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], + commercial=nren_info[FundingSource.COMMERCIAL], + other=nren_info[FundingSource.OTHER], + ) + db.session.merge(funding_source) + db.session.commit() + + +def transfer_staff_data(nren_dict): + data = {} + for question in StaffQuestion: + rows = query_question(question) + for row in rows: + nren_name = row[0].upper() + _value = row[1] + try: + value = float(_value.replace('"', '').replace(',', '')) + except ValueError: + value = 0 if nren_name not in nren_dict: logger.info(f'{nren_name} unknown. Skipping.') continue - funding_source = model.FundingSource( - nren=nren_dict[nren_name], - year=2022, - client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], - european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], - gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], - commercial=nren_info[FundingSource.COMMERCIAL], - other=nren_info[FundingSource.OTHER], - ) - session.merge(funding_source) - session.commit() - - -def transfer_staff_data(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - data = {} - for question in StaffQuestion: - rows = query_question(question) - for row in rows: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - value = 0 - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - # initialize on first use, so we don't add data for nrens with no answers - data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ - question] = value - - for nren_name, nren_info in data.items(): - if sum([nren_info[question] for question in StaffQuestion]) == 0: - logger.info(f'{nren_name} has no staff data. Deleting if exists.') - session.query(model.NrenStaff).filter( - model.NrenStaff.nren_id == nren_dict[nren_name].id, - model.NrenStaff.year == 2022, - ).delete() - continue - - employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] - technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] - - if not math.isclose(employed, technical, abs_tol=0.01): - logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' - f' ({employed} != {technical})') - - staff_data = model.NrenStaff( - nren_id=nren_dict[nren_name].id, - year=2022, - permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], - subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], - technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], - non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], - ) - session.merge(staff_data) - session.commit() - - -def transfer_nren_parent_org(): + # initialize on first use, so we don't add data for nrens with no answers + data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ + question] = value + + for nren_name, nren_info in data.items(): + if sum([nren_info[question] for question in StaffQuestion]) == 0: + logger.info(f'{nren_name} has no staff data. Deleting if exists.') + db.session.query(model.NrenStaff).filter( + model.NrenStaff.nren_id == nren_dict[nren_name].id, + model.NrenStaff.year == 2022, + ).delete() + continue + + employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] + technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] + + if not math.isclose(employed, technical, abs_tol=0.01): + logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' + f' ({employed} != {technical})') + + staff_data = model.NrenStaff( + nren_id=nren_dict[nren_name].id, + year=2022, + permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], + subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], + technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], + non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], + ) + db.session.merge(staff_data) + db.session.commit() + + +def transfer_nren_parent_org(nren_dict): # clean up the data a bit by removing some strings strings_to_replace = [ 'We are affiliated to ' ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_question(OrgQuestion.PARENT_ORG_NAME) - for row in rows: - nren_name = row[0].upper() - value = str(row[1]).replace('"', '') + rows = query_question(OrgQuestion.PARENT_ORG_NAME) + for row in rows: + nren_name = row[0].upper() + value = str(row[1]).replace('"', '') - for string in strings_to_replace: - value = value.replace(string, '') + for string in strings_to_replace: + value = value.replace(string, '') - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue - parent_org = model.ParentOrganization( - nren_id=nren_dict[nren_name].id, - year=2022, - organization=value, - ) - session.merge(parent_org) - session.commit() + parent_org = model.ParentOrganization( + nren_id=nren_dict[nren_name].id, + year=2022, + organization=value, + ) + db.session.merge(parent_org) + db.session.commit() -def transfer_nren_sub_org(): +def transfer_nren_sub_org(nren_dict): suborg_questions = [ (OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE), (OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE), @@ -298,140 +287,134 @@ def transfer_nren_sub_org(): (OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE) ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - lookup = defaultdict(list) - - for name, choice, role in suborg_questions: - _name_rows = query_question(name) - _choice_rows = query_question(choice) - _role_rows = list(query_question(role)) - for _name, _choice in zip(_name_rows, _choice_rows): - nren_name = _name[0].upper() - suborg_name = _name[1].replace('"', '').strip() - role_choice = _choice[1].replace('"', '').strip() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - if role_choice.lower() == 'other': - for _role in _role_rows: - if _role[0] == _name[0]: - role = _role[1].replace('"', '').strip() - break - else: - role = role_choice - - lookup[nren_name].append((suborg_name, role)) - - for nren_name, suborgs in lookup.items(): - for suborg_name, role in suborgs: - suborg = model.SubOrganization( - nren_id=nren_dict[nren_name].id, - year=2022, - organization=suborg_name, - role=role, - ) - session.merge(suborg) - session.commit() - - -def transfer_charging_structure(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_question(ChargingStructure.charging_structure) - for row in rows: - nren_name = row[0].upper() - value = row[1].replace('"', '').strip() + lookup = defaultdict(list) + + for name, choice, role in suborg_questions: + _name_rows = query_question(name) + _choice_rows = query_question(choice) + _role_rows = list(query_question(role)) + for _name, _choice in zip(_name_rows, _choice_rows): + nren_name = _name[0].upper() + suborg_name = _name[1].replace('"', '').strip() + role_choice = _choice[1].replace('"', '').strip() if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping from charging structure.') + logger.info(f'{nren_name} unknown. Skipping.') continue - if "do not charge" in value: - charging_structure = FeeType.no_charge - elif "combination" in value: - charging_structure = FeeType.combination - elif "flat" in value: - charging_structure = FeeType.flat_fee - elif "usage-based" in value: - charging_structure = FeeType.usage_based_fee - elif "Other" in value: - charging_structure = FeeType.other + if role_choice.lower() == 'other': + for _role in _role_rows: + if _role[0] == _name[0]: + role = _role[1].replace('"', '').strip() + break else: - charging_structure = None + role = role_choice + + lookup[nren_name].append((suborg_name, role)) - charging_structure = model.ChargingStructure( + for nren_name, suborgs in lookup.items(): + for suborg_name, role in suborgs: + suborg = model.SubOrganization( nren_id=nren_dict[nren_name].id, year=2022, - fee_type=charging_structure, + organization=suborg_name, + role=role, ) - session.merge(charging_structure) - session.commit() - - -def transfer_ec_projects(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # delete all existing EC projects, in case something changed - session.query(model.ECProject).filter( - model.ECProject.year == 2022, - ).delete() - - rows = query_question(ECQuestion.EC_PROJECT) - for row in rows: - nren_name = row[0].upper() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') + db.session.merge(suborg) + db.session.commit() + + +def transfer_charging_structure(nren_dict): + rows = query_question(ChargingStructure.charging_structure) + for row in rows: + nren_name = row[0].upper() + value = row[1].replace('"', '').strip() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping from charging structure.') + continue + + if "do not charge" in value: + charging_structure = FeeType.no_charge + elif "combination" in value: + charging_structure = FeeType.combination + elif "flat" in value: + charging_structure = FeeType.flat_fee + elif "usage-based" in value: + charging_structure = FeeType.usage_based_fee + elif "Other" in value: + charging_structure = FeeType.other + else: + charging_structure = None + + charging_structure = model.ChargingStructure( + nren_id=nren_dict[nren_name].id, + year=2022, + fee_type=charging_structure, + ) + db.session.merge(charging_structure) + db.session.commit() + + +def transfer_ec_projects(nren_dict): + # delete all existing EC projects, in case something changed + db.session.query(model.ECProject).filter( + model.ECProject.year == 2022, + ).delete() + + rows = query_question(ECQuestion.EC_PROJECT) + for row in rows: + nren_name = row[0].upper() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + try: + value = json.loads(row[1]) + except json.decoder.JSONDecodeError: + logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') + continue + + for val in value: + if not val: + logger.info(f'Invalid EC project value for {nren_name}: {val}.') continue - try: - value = json.loads(row[1]) - except json.decoder.JSONDecodeError: - logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') - continue - - for val in value: - if not val: - logger.info(f'Invalid EC project value for {nren_name}: {val}.') - continue - - # strip html entities/NBSP from val - val = html.unescape(val).replace('\xa0', ' ') + # strip html entities/NBSP from val + val = html.unescape(val).replace('\xa0', ' ') - # some answers include contract numbers, which we don't want here - val = val.split('(contract n')[0] + # some answers include contract numbers, which we don't want here + val = val.split('(contract n')[0] - ec_project = model.ECProject( - nren_id=nren_dict[nren_name].id, - year=2022, - project=str(val).strip() - ) - session.add(ec_project) - session.commit() + ec_project = model.ECProject( + nren_id=nren_dict[nren_name].id, + year=2022, + project=str(val).strip() + ) + db.session.add(ec_project) + db.session.commit() -def _cli(config): +def _cli(config, app): helpers.init_db(config) - transfer_budget() - transfer_funding_sources() - transfer_staff_data() - transfer_nren_parent_org() - transfer_nren_sub_org() - transfer_charging_structure() - transfer_ec_projects() + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + transfer_budget(nren_dict) + transfer_funding_sources(nren_dict) + transfer_staff_data(nren_dict) + transfer_nren_parent_org(nren_dict) + transfer_nren_sub_org(nren_dict) + transfer_charging_structure(nren_dict) + transfer_ec_projects(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) - _cli(app_config) + app = compendium_v2._create_app_with_db(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py index ba0e4e2a..5f4cd8b2 100644 --- a/compendium_v2/publishers/survey_publisher_v1.py +++ b/compendium_v2/publishers/survey_publisher_v1.py @@ -11,11 +11,12 @@ import logging import math import click +import compendium_v2 from compendium_v2.environment import setup_logging -from compendium_v2 import db, survey_db +from compendium_v2 import survey_db from compendium_v2.background_task import parse_excel_data from compendium_v2.config import load -from compendium_v2.db import model +from compendium_v2.db import db, model from compendium_v2.survey_db import model as survey_model from compendium_v2.publishers import helpers @@ -24,11 +25,8 @@ setup_logging() logger = logging.getLogger('survey-publisher-v1') -def db_budget_migration(): - with survey_db.session_scope() as survey_session, \ - db.session_scope() as session: - - nren_dict = helpers.get_uppercase_nren_dict(session) +def db_budget_migration(nren_dict): + with survey_db.session_scope() as survey_session: # move data from Survey DB budget table data = survey_session.query(survey_model.Nrens) @@ -49,7 +47,7 @@ def db_budget_migration(): budget=float(budget.budget), year=year ) - session.merge(budget_entry) + db.session.merge(budget_entry) # Import the data from excel sheet to database exceldata = parse_excel_data.fetch_budget_excel_data() @@ -63,165 +61,153 @@ def db_budget_migration(): logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year) - session.merge(budget_entry) - session.commit() - - -def db_funding_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_funding_excel_data() - - for (abbrev, year, client_institution, - european_funding, - gov_public_bodies, - commercial, other) in data: - - _data = [client_institution, european_funding, gov_public_bodies, commercial, other] - total = sum(_data) - if not math.isclose(total, 100, abs_tol=0.01) and total != 0: - logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') - - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - budget_entry = model.FundingSource( - nren=nren_dict[abbrev], - year=year, - client_institutions=client_institution, - european_funding=european_funding, - gov_public_bodies=gov_public_bodies, - commercial=commercial, - other=other) - session.merge(budget_entry) - session.commit() - - -def db_charging_structure_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_charging_structure_excel_data() - - for (abbrev, year, charging_structure) in data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - charging_structure_entry = model.ChargingStructure( - nren=nren_dict[abbrev], year=year, fee_type=charging_structure) - session.merge(charging_structure_entry) - session.commit() - - -def db_staffing_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - staff_data = parse_excel_data.fetch_staffing_excel_data() - - nren_staff_map = {} - for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff data.') - continue - - nren = nren_dict[abbrev] + db.session.merge(budget_entry) + db.session.commit() + + +def db_funding_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_funding_excel_data() + + for (abbrev, year, client_institution, + european_funding, + gov_public_bodies, + commercial, other) in data: + + _data = [client_institution, european_funding, gov_public_bodies, commercial, other] + total = sum(_data) + if not math.isclose(total, 100, abs_tol=0.01) and total != 0: + logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') + + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + budget_entry = model.FundingSource( + nren=nren_dict[abbrev], + year=year, + client_institutions=client_institution, + european_funding=european_funding, + gov_public_bodies=gov_public_bodies, + commercial=commercial, + other=other) + db.session.merge(budget_entry) + db.session.commit() + + +def db_charging_structure_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_charging_structure_excel_data() + + for (abbrev, year, charging_structure) in data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + charging_structure_entry = model.ChargingStructure( + nren=nren_dict[abbrev], year=year, fee_type=charging_structure) + db.session.merge(charging_structure_entry) + db.session.commit() + + +def db_staffing_migration(nren_dict): + staff_data = parse_excel_data.fetch_staffing_excel_data() + + nren_staff_map = {} + for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping staff data.') + continue + + nren = nren_dict[abbrev] + nren_staff_map[(nren.id, year)] = model.NrenStaff( + nren=nren, + nren_id=nren.id, + year=year, + permanent_fte=permanent_fte, + subcontracted_fte=subcontracted_fte, + technical_fte=0, + non_technical_fte=0 + ) + + function_data = parse_excel_data.fetch_staff_function_excel_data() + for (abbrev, year, technical_fte, non_technical_fte) in function_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping staff function data.') + continue + + nren = nren_dict[abbrev] + if (nren.id, year) in nren_staff_map: + nren_staff_map[(nren.id, year)].technical_fte = technical_fte + nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte + else: nren_staff_map[(nren.id, year)] = model.NrenStaff( nren=nren, nren_id=nren.id, year=year, - permanent_fte=permanent_fte, - subcontracted_fte=subcontracted_fte, - technical_fte=0, - non_technical_fte=0 + permanent_fte=0, + subcontracted_fte=0, + technical_fte=technical_fte, + non_technical_fte=non_technical_fte ) - function_data = parse_excel_data.fetch_staff_function_excel_data() - for (abbrev, year, technical_fte, non_technical_fte) in function_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff function data.') - continue - - nren = nren_dict[abbrev] - if (nren.id, year) in nren_staff_map: - nren_staff_map[(nren.id, year)].technical_fte = technical_fte - nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte - else: - nren_staff_map[(nren.id, year)] = model.NrenStaff( - nren=nren, - nren_id=nren.id, - year=year, - permanent_fte=0, - subcontracted_fte=0, - technical_fte=technical_fte, - non_technical_fte=non_technical_fte - ) - - for nren_staff_model in nren_staff_map.values(): - employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte - technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte - if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: - logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' - f' FTE do not equal across employed/technical categories ({employed} != {technical})') - - session.merge(nren_staff_model) + for nren_staff_model in nren_staff_map.values(): + employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte + technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte + if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: + logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' + f' FTE do not equal across employed/technical categories ({employed} != {technical})') - session.commit() + db.session.merge(nren_staff_model) + db.session.commit() -def db_ecprojects_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - ecproject_data = parse_excel_data.fetch_ecproject_excel_data() - for (abbrev, year, project) in ecproject_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - nren = nren_dict[abbrev] - ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) - session.merge(ecproject_entry) - session.commit() +def db_ecprojects_migration(nren_dict): + ecproject_data = parse_excel_data.fetch_ecproject_excel_data() + for (abbrev, year, project) in ecproject_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + nren = nren_dict[abbrev] + ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) + db.session.merge(ecproject_entry) + db.session.commit() -def db_organizations_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - organization_data = parse_excel_data.fetch_organization_excel_data() - for (abbrev, year, org) in organization_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue +def db_organizations_migration(nren_dict): + organization_data = parse_excel_data.fetch_organization_excel_data() + for (abbrev, year, org) in organization_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue - nren = nren_dict[abbrev] - org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) - session.merge(org_entry) - session.commit() + nren = nren_dict[abbrev] + org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) + db.session.merge(org_entry) + db.session.commit() -def _cli(config): +def _cli(config, app): helpers.init_db(config) - db_budget_migration() - db_funding_migration() - db_charging_structure_migration() - db_staffing_migration() - db_ecprojects_migration() - db_organizations_migration() + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + db_budget_migration(nren_dict) + db_funding_migration(nren_dict) + db_charging_structure_migration(nren_dict) + db_staffing_migration(nren_dict) + db_ecprojects_migration(nren_dict) + db_organizations_migration(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) + app = compendium_v2._create_app_with_db(app_config) print("survey-publisher-v1 starting") - _cli(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py index 763b3af3..1f4ff850 100644 --- a/compendium_v2/routes/budget.py +++ b/compendium_v2/routes/budget.py @@ -1,28 +1,17 @@ import logging from typing import Any -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import BudgetEntry from compendium_v2.routes import common -routes = Blueprint('budget', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('budget', __name__) logger = logging.getLogger(__name__) -col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061', - '#bd7ebe', '#ffb55a', '#ffee65', - '#beb9db', '#fdcce5', '#8bd3c7'] - BUDGET_RESPONSE_SCHEMA = { '$schema': 'http://json-schema.org/draft-07/schema#', @@ -58,15 +47,15 @@ def budget_view() -> Any: :return: """ - def _extract_data(entry: model.BudgetEntry): + def _extract_data(entry: BudgetEntry): return { 'NREN': entry.nren.name, 'BUDGET': float(entry.budget), 'BUDGET_YEAR': entry.year, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.BudgetEntry)], - key=lambda d: (d['BUDGET_YEAR'], d['NREN'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))], + key=lambda d: (d['BUDGET_YEAR'], d['NREN']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py index 0b2c2384..57e8114f 100644 --- a/compendium_v2/routes/charging.py +++ b/compendium_v2/routes/charging.py @@ -1,21 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('charging', __name__) - +from flask import Blueprint, jsonify +from sqlalchemy import select -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ChargingStructure +from compendium_v2.routes import common +routes = Blueprint('charging', __name__) logger = logging.getLogger(__name__) CHARGING_STRUCTURE_RESPONSE_SCHEMA = { @@ -53,16 +47,15 @@ def charging_structure_view() -> Any: :return: """ - def _extract_data(entry: model.ChargingStructure): + def _extract_data(entry: ChargingStructure): return { 'NREN': entry.nren.name, 'YEAR': int(entry.year), 'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.ChargingStructure) - .all()], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py index 7114718d..b58d8931 100644 --- a/compendium_v2/routes/ec_projects.py +++ b/compendium_v2/routes/ec_projects.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('ec-projects', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ECProject +from compendium_v2.routes import common +routes = Blueprint('ec-projects', __name__) logger = logging.getLogger(__name__) EC_PROJECTS_RESPONSE_SCHEMA = { @@ -55,13 +48,12 @@ def ec_projects_view() -> Any: :return: """ - def _extract_project(entry: model.ECProject): + def _extract_project(entry: ECProject): return { 'nren': entry.nren.name, 'year': entry.year, 'project': entry.project } - with db.session_scope() as session: - result = [_extract_project(project) for project in session.query(model.ECProject)] + result = [_extract_project(project) for project in db.session.scalars(select(ECProject))] return jsonify(result) diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py index ed0e26c8..c02bf136 100644 --- a/compendium_v2/routes/funding.py +++ b/compendium_v2/routes/funding.py @@ -1,22 +1,15 @@ import logging -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db from compendium_v2.routes import common -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import FundingSource from typing import Any -routes = Blueprint('funding', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('funding', __name__) logger = logging.getLogger(__name__) FUNDING_RESPONSE_SCHEMA = { @@ -59,7 +52,7 @@ def funding_source_view() -> Any: :return: """ - def _extract_data(entry: model.FundingSource): + def _extract_data(entry: FundingSource): return { 'NREN': entry.nren.name, 'YEAR': entry.year, @@ -70,8 +63,8 @@ def funding_source_view() -> Any: 'OTHER': float(entry.other) } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.FundingSource)], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py index 61a43354..8e8ebc8d 100644 --- a/compendium_v2/routes/organization.py +++ b/compendium_v2/routes/organization.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('organization', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ParentOrganization, SubOrganization +from compendium_v2.routes import common +routes = Blueprint('organization', __name__) logger = logging.getLogger(__name__) ORGANIZATION_RESPONSE_SCHEMA = { @@ -71,15 +64,14 @@ def parent_organization_view() -> Any: :return: """ - def _extract_parent(entry: model.ParentOrganization): + def _extract_parent(entry: ParentOrganization): return { 'nren': entry.nren.name, 'year': entry.year, 'name': entry.organization } - with db.session_scope() as session: - result = [_extract_parent(org) for org in session.query(model.ParentOrganization)] + result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))] return jsonify(result) @@ -98,7 +90,7 @@ def sub_organization_view() -> Any: :return: """ - def _extract_sub(entry: model.SubOrganization): + def _extract_sub(entry: SubOrganization): return { 'nren': entry.nren.name, 'year': entry.year, @@ -106,6 +98,5 @@ def sub_organization_view() -> Any: 'role': entry.role } - with db.session_scope() as session: - result = [_extract_sub(org) for org in session.query(model.SubOrganization)] + result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))] return jsonify(result) diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py index 73e79c48..b8478266 100644 --- a/compendium_v2/routes/staff.py +++ b/compendium_v2/routes/staff.py @@ -1,22 +1,15 @@ import logging -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db +from compendium_v2.db import db +from compendium_v2.db.model import NREN, NrenStaff from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('staff', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('staff', __name__) logger = logging.getLogger(__name__) STAFF_RESPONSE_SCHEMA = { @@ -57,7 +50,7 @@ def staff_view() -> Any: :return: """ - def _extract_data(entry: model.NrenStaff): + def _extract_data(entry: NrenStaff): return { 'nren': entry.nren.name, 'year': entry.year, @@ -67,7 +60,6 @@ def staff_view() -> Any: 'non_technical_fte': float(entry.non_technical_fte) } - with db.session_scope() as session: - entries = [_extract_data(entry) for entry in session.query( - model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())] + entries = [_extract_data(entry) for entry in db.session.scalars( + select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))] return jsonify(entries) diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py index ce48aaa0..1550ddcb 100644 --- a/compendium_v2/survey_db/__init__.py +++ b/compendium_v2/survey_db/__init__.py @@ -31,11 +31,6 @@ def session_scope( session.close() -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') - - def init_db_model(dsn): global _SESSION_MAKER diff --git a/requirements.txt b/requirements.txt index 12ec2076..98804e02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ click~=8.1 jsonschema~=4.17 flask~=2.2 flask-cors~=3.0 +flask-sqlalchemy~=3.0 openpyxl~=3.1 psycopg2-binary~=2.9 SQLAlchemy~=2.0 diff --git a/setup.py b/setup.py index 213288e5..d051ddd7 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ setup( 'jsonschema~=4.17', 'flask~=2.2', 'flask-cors~=3.0', + 'flask-sqlalchemy~=3.0', 'openpyxl~=3.1', 'psycopg2-binary~=2.9', 'SQLAlchemy~=2.0', diff --git a/test/conftest.py b/test/conftest.py index 79186212..ba6b9a43 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,20 +1,15 @@ -import json +import csv import os -import tempfile -import random - import pytest -import compendium_v2 -from compendium_v2 import db -from compendium_v2.db import model -from compendium_v2 import survey_db -from compendium_v2.survey_db import model as survey_model +import random from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool -import csv +import compendium_v2 +from compendium_v2.db import db, model +from compendium_v2.survey_db import model as survey_model def _test_data_csv(filename): @@ -43,61 +38,29 @@ def mocked_survey_db(mocker): @pytest.fixture -def mocked_db(mocker): - # cf. https://stackoverflow.com/a/33057675 - engine = create_engine( - 'sqlite://', - connect_args={'check_same_thread': False}, - poolclass=StaticPool, - echo=False) - model.Base.metadata.create_all(engine) - mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine)) - mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None) - mocker.patch('compendium_v2.migrate_database', lambda config: None) - - -@pytest.fixture -def test_budget_data(): - with db.session_scope() as session: +def test_budget_data(app): + with app.app_context(): data = [row for row in _test_data_csv("BudgetTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] budget = row["budget"] year = row["year"] - session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year))) - - with survey_db.session_scope() as session: - data = _test_data_csv("BudgetTestData.csv") - nrens = set() - budgets_data = [] - for row in data: - nren = row["nren"] - budget = row["budget"] - year = row["year"] - country_code = row["nren"] - - nrens.add(nren) - - budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code)) - - for nren in nrens: - session.add(survey_model.Nrens(abbreviation=nren, country_code=nren)) - - session.add_all(budgets_data) + db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year))) + db.session.commit() @pytest.fixture -def test_funding_source_data(): - with db.session_scope() as session: +def test_funding_source_data(app): + with app.app_context(): data = [row for row in _test_data_csv("FundingSourceTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -108,7 +71,7 @@ def test_funding_source_data(): commercial = row["commercial"] other = row["other"] - session.add( + db.session.add( model.FundingSource( nren=nren, year=year, client_institutions=client, @@ -117,10 +80,11 @@ def test_funding_source_data(): commercial=commercial, other=other) ) + db.session.commit() @pytest.fixture -def test_staff_data(): +def test_staff_data(app): # generator of random test data for 5 years and 100 nrens def _generate_rows(): @@ -135,12 +99,12 @@ def test_staff_data(): "non_technical_fte": random.randint(0, 100) } - with db.session_scope() as session: + with app.app_context(): data = list(_generate_rows()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -150,7 +114,7 @@ def test_staff_data(): technical_fte = row["technical_fte"] non_technical_fte = row["non_technical_fte"] - session.add( + db.session.add( model.NrenStaff( nren=nren, year=year, @@ -160,30 +124,29 @@ def test_staff_data(): non_technical_fte=non_technical_fte ) ) + db.session.commit() @pytest.fixture -def data_config_filename(dummy_config): - with tempfile.NamedTemporaryFile() as f: - f.write(json.dumps(dummy_config).encode('utf-8')) - f.flush() - yield f.name +def app(dummy_config): + app = compendium_v2._create_app_with_db(dummy_config) + with app.app_context(): + db.create_all() + yield app @pytest.fixture -def client(data_config_filename, mocked_db, mocked_survey_db): - os.environ['SETTINGS_FILENAME'] = data_config_filename - with compendium_v2.create_app().test_client() as c: - yield c +def client(app): + return app.test_client() @pytest.fixture -def test_charging_structure_data(): - with db.session_scope() as session: +def test_charging_structure_data(app): + with app.app_context(): data = [row for row in _test_data_csv("ChargingStructureTestData.csv")] nren_names = set([row["nren"] for row in data]) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for row in data: nren = nren_dict[row["nren"]] @@ -192,15 +155,16 @@ def test_charging_structure_data(): if fee_type == "null": fee_type = None - session.add( + db.session.add( model.ChargingStructure( nren=nren, year=year, fee_type=fee_type) ) + db.session.commit() @pytest.fixture -def test_organization_data(): +def test_organization_data(app): def _generate_sub_org_data(): for nren in ["nren" + str(i) for i in range(1, 50)]: for year in range(2016, 2021): @@ -220,21 +184,21 @@ def test_organization_data(): 'name': 'org' + str(year) } - with db.session_scope() as session: + with app.app_context(): org_data = list(_generate_org_data()) sub_org_data = list(_generate_sub_org_data()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for org in org_data: nren = nren_dict[org["nren"]] year = org["year"] name = org["name"] - session.add(model.ParentOrganization(nren=nren, year=year, organization=name)) + db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name)) for sub_org in sub_org_data: nren = nren_dict[sub_org["nren"]] @@ -242,13 +206,13 @@ def test_organization_data(): name = sub_org["name"] role = sub_org["role"] - session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role)) + db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role)) - session.commit() + db.session.commit() @pytest.fixture -def test_ec_project_data(): +def test_ec_project_data(app): def _generate_ec_project_data(): for nren in ["nren" + str(i) for i in range(1, 50)]: for year in range(2016, 2021): @@ -264,19 +228,19 @@ def test_ec_project_data(): 'project': 'ec_project2', } - with db.session_scope() as session: + with app.app_context(): ec_project_data = list(_generate_ec_project_data()) nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in set(d['nren'] for d in ec_project_data)} - session.add_all(nren_dict.values()) + db.session.add_all(nren_dict.values()) for ec_project in ec_project_data: nren = nren_dict[ec_project["nren"]] year = ec_project["year"] project = ec_project["project"] - session.add(model.ECProject(nren=nren, year=year, project=project)) + db.session.add(model.ECProject(nren=nren, year=year, project=project)) - session.commit() + db.session.commit() diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py index ec8802ed..433def2c 100644 --- a/test/test_survey_publisher_2022.py +++ b/test/test_survey_publisher_2022.py @@ -1,5 +1,4 @@ -from compendium_v2 import db -from compendium_v2.db import model +from compendium_v2.db import db, model from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \ StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion @@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority" ] -def test_publisher(client, mocker, dummy_config): +def test_publisher(app, mocker, dummy_config): global org_data def get_rows_as_tuples(*args, **kwargs): @@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config): mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data) mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data) - with db.session_scope() as session: - nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] - session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] + with app.app_context(): + db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.commit() - _cli(dummy_config) + _cli(dummy_config, app) - with db.session_scope() as session: - budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() + with app.app_context(): + budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() assert len(budgets) == 3 assert budgets[0].nren.name.lower() == 'nren1' assert budgets[0].budget == 100 - funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() + funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() assert len(funding_sources) == 3 assert funding_sources[0].nren.name.lower() == 'nren1' assert funding_sources[0].client_institutions == 10 @@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config): assert funding_sources[2].european_funding == 30 assert funding_sources[2].other == 30 - staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() + staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() assert len(staff_data) == 3 assert staff_data[0].nren.name.lower() == 'nren1' @@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config): assert staff_data[2].permanent_fte == 30 assert staff_data[2].subcontracted_fte == 0 - _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() + _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() assert len(_org_data) == 2 assert _org_data[0].nren.name.lower() == 'nren1' @@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config): assert _org_data[1].nren.name.lower() == 'nren3' assert _org_data[1].organization == 'Org3' - charging_structures = session.query(model.ChargingStructure).order_by( + charging_structures = db.session.query(model.ChargingStructure).order_by( model.ChargingStructure.nren_id.asc()).all() assert len(charging_structures) == 3 assert charging_structures[0].nren.name.lower() == 'nren1' @@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config): assert charging_structures[2].nren.name.lower() == 'nren3' assert charging_structures[2].fee_type == model.FeeType.other - _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() + _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() assert len(_ec_data) == 3 assert _ec_data[0].nren.name.lower() == 'nren2' diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py index c7ebc2d9..a6dbbaa1 100644 --- a/test/test_survey_publisher_v1.py +++ b/test/test_survey_publisher_v1.py @@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx") -def test_publisher(client, mocker, dummy_config): +def test_publisher(mocked_survey_db, app, mocker, dummy_config): mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE) - with db.session_scope() as session: + with app.app_context(): nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] - session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) + db.session.commit() - _cli(dummy_config) + _cli(dummy_config, app) - with db.session_scope() as session: - budget_count = session.query(model.BudgetEntry.year).count() + with app.app_context(): + budget_count = db.session.query(model.BudgetEntry.year).count() assert budget_count - funding_source_count = session.query(model.FundingSource.year).count() + funding_source_count = db.session.query(model.FundingSource.year).count() assert funding_source_count - charging_structure_count = session.query(model.ChargingStructure.year).count() + charging_structure_count = db.session.query(model.ChargingStructure.year).count() assert charging_structure_count - staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() + staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() # data should only be saved for the NRENs we have saved in the database staff_data_nrens = set([staff.nren.name for staff in staff_data]) @@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config): assert kifu_data[5].technical_fte == 133 assert kifu_data[5].non_technical_fte == 45 - ecproject_data = session.query(model.ECProject).all() + ecproject_data = db.session.query(model.ECProject).all() # test a couple of random entries surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017] assert len(surf2017) == 1 @@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config): assert len(kifu2019) == 4 assert kifu2019[3].project == 'SuperHeroes for Science' - parent_data = session.query(model.ParentOrganization).all() + parent_data = db.session.query(model.ParentOrganization).all() # test a random entry asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021] assert len(asnet2021) == 1 -- GitLab