diff --git a/README.md b/README.md index 81b1b0b67f3762753a5688ce4fabd4e0edefcc62..dd9a7dce6b1534fbd04577e8af437a232c3e6334 100644 --- a/README.md +++ b/README.md @@ -63,11 +63,13 @@ survey-publisher-2022 ## Creating a db migration after editing the sqlalchemy models ```bash -cd compendium_v2 -alembic revision --autogenerate -m "description" +flask db migrate -m "description" ``` Then go to the created migration file to make any necessary additions, for example to migrate data. Also see https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect +Flask-migrate sets `compare_type=True` by default. Note that starting the application applies all upgrades. +This also happens when running `flask db` commands such as `flask db downgrade`, +so if you want to downgrade 2 or more versions you need to do so in one command, eg by specifying the revision number. diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py index 622565f2f7a3a5d39c1df9b562e11bd336dfc230..de1b313f5fe6f9d47c02339b240c875f978a306e 100644 --- a/compendium_v2/__init__.py +++ b/compendium_v2/__init__.py @@ -6,15 +6,11 @@ import os from flask import Flask from flask_cors import CORS # for debugging +# the currently available stubs for flask_migrate are old (they depend on sqlalchemy 1.4 types) +from flask_migrate import Migrate, upgrade # type: ignore from compendium_v2 import config, environment - -from compendium_v2.migrations import migration_utils - - -def migrate_database(config: dict) -> None: - dsn = config['SQLALCHEMY_DATABASE_URI'] - migration_utils.upgrade(dsn) +from compendium_v2.db import db def _create_app(app_config) -> Flask: @@ -33,6 +29,19 @@ def _create_app(app_config) -> Flask: return app +def _create_app_with_db(app_config) -> Flask: + # used by the tests and the publishers + app = _create_app(app_config) + app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI'] + + if 'SQLALCHEMY_BINDS' in app.config['CONFIG_PARAMS']: + # for the publishers + app.config['SQLALCHEMY_BINDS'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_BINDS'] + + db.init_app(app) + return app + + def create_app() -> Flask: """ overrides default settings with those found @@ -46,13 +55,16 @@ def create_app() -> Flask: with open(os.environ['SETTINGS_FILENAME']) as f: app_config = config.load(f) - app = _create_app(app_config) + app = _create_app_with_db(app_config) + + Migrate(app, db, directory=os.path.join(os.path.dirname(__file__), 'migrations')) logging.info('Flask app initialized') environment.setup_logging() # run migrations on startup - migrate_database(app_config) + with app.app_context(): + upgrade() return app diff --git a/compendium_v2/alembic.ini b/compendium_v2/alembic.ini deleted file mode 100644 index 2145863baa95551a588dd229ee525abd06f32742..0000000000000000000000000000000000000000 --- a/compendium_v2/alembic.ini +++ /dev/null @@ -1,10 +0,0 @@ -# A generic, single database configuration. 
- -# only needed for generating new revision scripts -[alembic] -# make sure the right line is un / commented depending on which schema you want -# a migration for -script_location = migrations -# script_location = cachedb_migrations -# change this to run migrations from the command line -sqlalchemy.url = postgresql://compendium:compendium321@localhost:65000/compendium diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py index ce48aaa07a269ae2c20235b31c15f58cf8ff0956..13719e0a80d28d997ecf3f22b687c8cc1049a118 100644 --- a/compendium_v2/db/__init__.py +++ b/compendium_v2/db/__init__.py @@ -1,45 +1,17 @@ -import contextlib import logging -from typing import Optional, Union, Callable, Iterator -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker, Session +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData -logger = logging.getLogger(__name__) -_SESSION_MAKER: Union[None, sessionmaker] = None - - -@contextlib.contextmanager -def session_scope( - callback_before_close: Optional[Callable] = None) -> Iterator[Session]: - # best practice is to keep session scope separate from data processing - # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html - - assert _SESSION_MAKER - session = _SESSION_MAKER() - try: - yield session - session.commit() - if callback_before_close: - callback_before_close() - except SQLAlchemyError: - logger.error('caught sql layer exception, rolling back') - session.rollback() - raise # re-raise, will be handled by main consumer - finally: - session.close() - - -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') +logger = logging.getLogger(__name__) -def init_db_model(dsn): - global _SESSION_MAKER +metadata_obj = MetaData(naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +}) - # cf. 
https://docs.sqlalchemy.org/en - # /latest/orm/extensions/automap.html - engine = create_engine(dsn, pool_size=10) - _SESSION_MAKER = sessionmaker(bind=engine) +db = SQLAlchemy(metadata=metadata_obj) diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py index 676727437984fcf4b0160d91060dfcc343c376c8..2199f759ccdd1dfa06bfec3af7955bc77d25bbe3 100644 --- a/compendium_v2/db/model.py +++ b/compendium_v2/db/model.py @@ -4,42 +4,34 @@ from enum import Enum from typing import Optional from typing_extensions import Annotated -from sqlalchemy import MetaData, String -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey +from compendium_v2.db import db -logger = logging.getLogger(__name__) - -convention = { - "ix": "ix_%(column_0_label)s", - "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_%(constraint_name)s", - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s", -} -metadata_obj = MetaData(naming_convention=convention) +logger = logging.getLogger(__name__) str128 = Annotated[str, 128] +str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)] +str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)] int_pk = Annotated[int, mapped_column(primary_key=True)] int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)] -class Base(DeclarativeBase): - metadata = metadata_obj - type_annotation_map = { - str128: String(128), - } +# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet. +# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140 +# mypy: disable-error-code="name-defined" -class NREN(Base): +class NREN(db.Model): __tablename__ = 'nren' id: Mapped[int_pk] name: Mapped[str128] -class BudgetEntry(Base): +class BudgetEntry(db.Model): __tablename__ = 'budgets' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -47,7 +39,7 @@ class BudgetEntry(Base): budget: Mapped[Decimal] -class FundingSource(Base): +class FundingSource(db.Model): __tablename__ = 'funding_source' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -67,7 +59,7 @@ class FeeType(Enum): other = "other" -class ChargingStructure(Base): +class ChargingStructure(db.Model): __tablename__ = 'charging_structure' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -75,7 +67,7 @@ class ChargingStructure(Base): fee_type: Mapped[Optional[FeeType]] -class NrenStaff(Base): +class NrenStaff(db.Model): __tablename__ = 'nren_staff' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -86,7 +78,7 @@ class NrenStaff(Base): non_technical_fte: Mapped[Decimal] -class ParentOrganization(Base): +class ParentOrganization(db.Model): __tablename__ = 'parent_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') @@ -94,18 +86,18 @@ class ParentOrganization(Base): organization: Mapped[str128] -class SubOrganization(Base): +class SubOrganization(db.Model): __tablename__ = 'sub_organization' nren_id: Mapped[int_pk_fkNREN] nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - organization: Mapped[str128] = mapped_column(primary_key=True) + organization: Mapped[str128_pk] role: Mapped[str128] -class ECProject(Base): +class ECProject(db.Model): __tablename__ = 
'ec_project' nren_id: Mapped[int_pk_fkNREN] - nren: Mapped[NREN] = relationship(NREN, lazy='joined') + nren: Mapped[NREN] = relationship(lazy='joined') year: Mapped[int_pk] - project: Mapped[str] = mapped_column(String(256), primary_key=True) + project: Mapped[str256_pk] diff --git a/compendium_v2/migrations/README b/compendium_v2/migrations/README new file mode 100644 index 0000000000000000000000000000000000000000..0e048441597444a7e2850d6d7c4ce15550f79bda --- /dev/null +++ b/compendium_v2/migrations/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/compendium_v2/migrations/__init__.py b/compendium_v2/migrations/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/compendium_v2/migrations/alembic.ini b/compendium_v2/migrations/alembic.ini new file mode 100644 index 0000000000000000000000000000000000000000..ec9d45c26a6bb54e833fd4e6ce2de29343894f4b --- /dev/null +++ b/compendium_v2/migrations/alembic.ini @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py index 0307be3377e7eb332145d5822b4fb126b6f96e59..e2408681ba289bd300144dddf1c47f754ba931d1 100644 --- a/compendium_v2/migrations/env.py +++ b/compendium_v2/migrations/env.py @@ -1,10 +1,9 @@ import logging +from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool +from flask import current_app from alembic import context -from compendium_v2.db.model import metadata_obj # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -12,13 +11,34 @@ config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. 
-logging.basicConfig(level=logging.INFO) +if config.config_file_name is not None: + fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine() + except TypeError: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = metadata_obj +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db # other values from the config, defined by the needs of env.py, # can be acquired: @@ -26,6 +46,12 @@ target_metadata = metadata_obj # ... etc. +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -40,10 +66,7 @@ def run_migrations_offline(): """ url = config.get_main_option("sqlalchemy.url") context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, + url=url, target_metadata=get_metadata(), literal_binds=True ) with context.begin_transaction(): @@ -57,15 +80,25 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = get_engine() with connectable.connect() as connection: context.configure( - connection=connection, target_metadata=target_metadata + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args ) with context.begin_transaction(): diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py deleted file mode 100644 index 3b29b540ed3382d8da1416a70d5bf9be5a01e1bc..0000000000000000000000000000000000000000 --- a/compendium_v2/migrations/migration_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging -import os - -from compendium_v2 import db -from alembic.config import Config -from alembic import command - -logger = logging.getLogger(__name__) -DEFAULT_MIGRATIONS_DIRECTORY = os.path.dirname(__file__) - - -def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY): - """ - migrate db to head version - - cf. 
https://stackoverflow.com/a/43530495, - https://stackoverflow.com/a/54402853 - - :param dsn: dsn string, passed to alembic - :param migrations_directory: full path to migrations directory - (default is this directory) - :return: - """ - alembic_config = Config() - alembic_config.set_main_option('script_location', migrations_directory) - alembic_config.set_main_option('sqlalchemy.url', dsn) - command.upgrade(alembic_config, 'head') - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - upgrade(db.postgresql_dsn( - db_username='compendium', - db_password='compendium321', - db_hostname='localhost', - db_name='compendium', - port=65000)) diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py index d95ad4e172bca6637bbebb3f1da03e90ab9836da..fb9fb40e98ac42bbf4dc95b8e373cf324a251532 100644 --- a/compendium_v2/publishers/helpers.py +++ b/compendium_v2/publishers/helpers.py @@ -1,21 +1,14 @@ -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from sqlalchemy import select +from compendium_v2.db import db, model -def init_db(config): - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - dsn_survey = config['SURVEY_DATABASE_URI'] - survey_db.init_db_model(dsn_survey) - -def get_uppercase_nren_dict(session): +def get_uppercase_nren_dict(): """ - :param session: db session that is used to query the known NRENs :return: a dictionary of all known NRENs db entities keyed on the uppercased name """ - current_nrens = session.query(model.NREN).all() + current_nrens = db.session.scalars(select(model.NREN)) nren_dict = {nren.name.upper(): nren for nren in current_nrens} # add aliases that are used in the source data: nren_dict['ASNET'] = nren_dict['ASNET-AM'] diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py index 9b2db8c4855877ba9a7ca286986bf73fb4c4ba81..c213d2651f0b53f1e8811ec9caeb8faae6904ef1 100644 --- a/compendium_v2/publishers/survey_publisher_2022.py +++ b/compendium_v2/publishers/survey_publisher_2022.py @@ -13,14 +13,15 @@ import math import json import html -from sqlalchemy import text +from sqlalchemy import text, delete from collections import defaultdict +import compendium_v2 from compendium_v2.db.model import FeeType from compendium_v2.environment import setup_logging from compendium_v2.config import load -from compendium_v2 import db, survey_db -from compendium_v2.db import model +from compendium_v2.survey_db import model as survey_model +from compendium_v2.db import db, model from compendium_v2.publishers import helpers setup_logging() @@ -116,184 +117,169 @@ class ChargingStructure(enum.Enum): def query_budget(): - with survey_db.session_scope() as survey: - return survey.execute(text(BUDGET_QUERY)) + return db.session.execute(text(BUDGET_QUERY), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) def query_funding_sources(): for source in FundingSource: query = QUESTION_TEMPLATE_QUERY.format(source.value) - with survey_db.session_scope() as survey: - yield source, survey.execute(text(query)) + yield source, db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) def query_question(question: enum.Enum): query = QUESTION_TEMPLATE_QUERY.format(question.value) - with survey_db.session_scope() as survey: - return survey.execute(text(query)) - - -def transfer_budget(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = 
query_budget() - for row in rows: + return db.session.execute(text(query), bind_arguments={'bind': db.engines[survey_model.SURVEY_DB_BIND]}) + + +def transfer_budget(nren_dict): + rows = query_budget() + for row in rows: + nren_name = row[0].upper() + _budget = row[1] + try: + budget = float(_budget.replace('"', '').replace(',', '')) + except ValueError: + logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') + continue + + if budget > 200: + logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + budget_entry = model.BudgetEntry( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + budget=budget, + year=2022, + ) + db.session.merge(budget_entry) + db.session.commit() + + +def transfer_funding_sources(nren_dict): + sourcedata = {} + for source, data in query_funding_sources(): + for row in data: nren_name = row[0].upper() - _budget = row[1] + _value = row[1] try: - budget = float(_budget.replace('"', '').replace(',', '')) + value = float(_value.replace('"', '').replace(',', '')) except ValueError: - logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))') - continue + name = source.name + logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') + value = 0 - if budget > 200: - logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})') - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - budget_entry = model.BudgetEntry( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - budget=budget, - year=2022, + nren_info = sourcedata.setdefault( + nren_name, + {source_type: 0 for source_type in FundingSource} ) - session.merge(budget_entry) - session.commit() - - -def transfer_funding_sources(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - sourcedata = {} - for source, data in query_funding_sources(): - for row in data: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - name = source.name - logger.info(f'{nren_name} has invalid value for {name}. ({_value}))') - value = 0 - - nren_info = sourcedata.setdefault( - nren_name, - {source_type: 0 for source_type in FundingSource} - ) - nren_info[source] = value - - for nren_name, nren_info in sourcedata.items(): - total = sum(nren_info.values()) - - if not math.isclose(total, 100, abs_tol=0.01): - logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + nren_info[source] = value + + for nren_name, nren_info in sourcedata.items(): + total = sum(nren_info.values()) + + if not math.isclose(total, 100, abs_tol=0.01): + logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})') + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. 
Skipping.') + continue + + funding_source = model.FundingSource( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + year=2022, + client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], + european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], + gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], + commercial=nren_info[FundingSource.COMMERCIAL], + other=nren_info[FundingSource.OTHER], + ) + db.session.merge(funding_source) + db.session.commit() + + +def transfer_staff_data(nren_dict): + data = {} + for question in StaffQuestion: + rows = query_question(question) + for row in rows: + nren_name = row[0].upper() + _value = row[1] + try: + value = float(_value.replace('"', '').replace(',', '')) + except ValueError: + value = 0 if nren_name not in nren_dict: logger.info(f'{nren_name} unknown. Skipping.') continue - funding_source = model.FundingSource( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - year=2022, - client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS], - european_funding=nren_info[FundingSource.EUROPEAN_FUNDING], - gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES], - commercial=nren_info[FundingSource.COMMERCIAL], - other=nren_info[FundingSource.OTHER], - ) - session.merge(funding_source) - session.commit() - - -def transfer_staff_data(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - data = {} - for question in StaffQuestion: - rows = query_question(question) - for row in rows: - nren_name = row[0].upper() - _value = row[1] - try: - value = float(_value.replace('"', '').replace(',', '')) - except ValueError: - value = 0 - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - # initialize on first use, so we don't add data for nrens with no answers - data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ - question] = value - - for nren_name, nren_info in data.items(): - if sum([nren_info[question] for question in StaffQuestion]) == 0: - logger.info(f'{nren_name} has no staff data. Deleting if exists.') - session.query(model.NrenStaff).filter( - model.NrenStaff.nren_id == nren_dict[nren_name].id, - model.NrenStaff.year == 2022, - ).delete() - continue - - employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] - technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] - - if not math.isclose(employed, technical, abs_tol=0.01): - logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' - f' ({employed} != {technical})') - - staff_data = model.NrenStaff( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - year=2022, - permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], - subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], - technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], - non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], - ) - session.merge(staff_data) - session.commit() - - -def transfer_nren_parent_org(): + # initialize on first use, so we don't add data for nrens with no answers + data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[ + question] = value + + for nren_name, nren_info in data.items(): + if sum([nren_info[question] for question in StaffQuestion]) == 0: + logger.info(f'{nren_name} has no staff data. 
Deleting if exists.') + db.session.execute(delete(model.NrenStaff).where( + model.NrenStaff.nren_id == nren_dict[nren_name].id, + model.NrenStaff.year == 2022 + )) + continue + + employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE] + technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE] + + if not math.isclose(employed, technical, abs_tol=0.01): + logger.info(f'{nren_name} FTE do not equal across employed/technical categories.' + f' ({employed} != {technical})') + + staff_data = model.NrenStaff( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + year=2022, + permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE], + subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE], + technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE], + non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE], + ) + db.session.merge(staff_data) + db.session.commit() + + +def transfer_nren_parent_org(nren_dict): # clean up the data a bit by removing some strings strings_to_replace = [ 'We are affiliated to ' ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) + rows = query_question(OrgQuestion.PARENT_ORG_NAME) + for row in rows: + nren_name = row[0].upper() + value = str(row[1]).replace('"', '') - rows = query_question(OrgQuestion.PARENT_ORG_NAME) - for row in rows: - nren_name = row[0].upper() - value = str(row[1]).replace('"', '') + for string in strings_to_replace: + value = value.replace(string, '') - for string in strings_to_replace: - value = value.replace(string, '') + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - parent_org = model.ParentOrganization( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - year=2022, - organization=value, - ) - session.merge(parent_org) - session.commit() + parent_org = model.ParentOrganization( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + year=2022, + organization=value, + ) + db.session.merge(parent_org) + db.session.commit() -def transfer_nren_sub_org(): +def transfer_nren_sub_org(nren_dict): suborg_questions = [ (OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE), (OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE), @@ -302,143 +288,139 @@ def transfer_nren_sub_org(): (OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE) ] - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - lookup = defaultdict(list) - - for name, choice, role in suborg_questions: - _name_rows = query_question(name) - _choice_rows = query_question(choice) - _role_rows = list(query_question(role)) - for _name, _choice in zip(_name_rows, _choice_rows): - nren_name = _name[0].upper() - suborg_name = _name[1].replace('"', '').strip() - role_choice = _choice[1].replace('"', '').strip() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. 
Skipping.') - continue - - if role_choice.lower() == 'other': - for _role in _role_rows: - if _role[0] == _name[0]: - role = _role[1].replace('"', '').strip() - break - else: - role = role_choice - - lookup[nren_name].append((suborg_name, role)) - - for nren_name, suborgs in lookup.items(): - for suborg_name, role in suborgs: - suborg = model.SubOrganization( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - year=2022, - organization=suborg_name, - role=role, - ) - session.merge(suborg) - session.commit() - - -def transfer_charging_structure(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - rows = query_question(ChargingStructure.charging_structure) - for row in rows: - nren_name = row[0].upper() - value = row[1].replace('"', '').strip() + lookup = defaultdict(list) + + for name, choice, role in suborg_questions: + _name_rows = query_question(name) + _choice_rows = query_question(choice) + _role_rows = list(query_question(role)) + for _name, _choice in zip(_name_rows, _choice_rows): + nren_name = _name[0].upper() + suborg_name = _name[1].replace('"', '').strip() + role_choice = _choice[1].replace('"', '').strip() if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping from charging structure.') + logger.info(f'{nren_name} unknown. Skipping.') continue - if "do not charge" in value: - charging_structure = FeeType.no_charge - elif "combination" in value: - charging_structure = FeeType.combination - elif "flat" in value: - charging_structure = FeeType.flat_fee - elif "usage-based" in value: - charging_structure = FeeType.usage_based_fee - elif "Other" in value: - charging_structure = FeeType.other + if role_choice.lower() == 'other': + for _role in _role_rows: + if _role[0] == _name[0]: + role = _role[1].replace('"', '').strip() + break else: - charging_structure = None + role = role_choice - charging_structure = model.ChargingStructure( + lookup[nren_name].append((suborg_name, role)) + + for nren_name, suborgs in lookup.items(): + for suborg_name, role in suborgs: + suborg = model.SubOrganization( nren=nren_dict[nren_name], nren_id=nren_dict[nren_name].id, year=2022, - fee_type=charging_structure, + organization=suborg_name, + role=role, ) - session.merge(charging_structure) - session.commit() - - -def transfer_ec_projects(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # delete all existing EC projects, in case something changed - session.query(model.ECProject).filter( - model.ECProject.year == 2022, - ).delete() - - rows = query_question(ECQuestion.EC_PROJECT) - for row in rows: - nren_name = row[0].upper() - - if nren_name not in nren_dict: - logger.info(f'{nren_name} unknown. Skipping.') - continue - - try: - value = json.loads(row[1]) - except json.decoder.JSONDecodeError: - logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') + db.session.merge(suborg) + db.session.commit() + + +def transfer_charging_structure(nren_dict): + rows = query_question(ChargingStructure.charging_structure) + for row in rows: + nren_name = row[0].upper() + value = row[1].replace('"', '').strip() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. 
Skipping from charging structure.') + continue + + if "do not charge" in value: + charging_structure = FeeType.no_charge + elif "combination" in value: + charging_structure = FeeType.combination + elif "flat" in value: + charging_structure = FeeType.flat_fee + elif "usage-based" in value: + charging_structure = FeeType.usage_based_fee + elif "Other" in value: + charging_structure = FeeType.other + else: + charging_structure = None + + charging_structure = model.ChargingStructure( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + year=2022, + fee_type=charging_structure, + ) + db.session.merge(charging_structure) + db.session.commit() + + +def transfer_ec_projects(nren_dict): + # delete all existing EC projects, in case something changed + db.session.execute( + delete(model.ECProject).where(model.ECProject.year == 2022) + ) + + rows = query_question(ECQuestion.EC_PROJECT) + for row in rows: + nren_name = row[0].upper() + + if nren_name not in nren_dict: + logger.info(f'{nren_name} unknown. Skipping.') + continue + + try: + value = json.loads(row[1]) + except json.decoder.JSONDecodeError: + logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.') + continue + + for val in value: + if not val: + logger.info(f'Invalid EC project value for {nren_name}: {val}.') continue - for val in value: - if not val: - logger.info(f'Invalid EC project value for {nren_name}: {val}.') - continue - - # strip html entities/NBSP from val - val = html.unescape(val).replace('\xa0', ' ') + # strip html entities/NBSP from val + val = html.unescape(val).replace('\xa0', ' ') - # some answers include contract numbers, which we don't want here - val = val.split('(contract n')[0] + # some answers include contract numbers, which we don't want here + val = val.split('(contract n')[0] - ec_project = model.ECProject( - nren=nren_dict[nren_name], - nren_id=nren_dict[nren_name].id, - year=2022, - project=str(val).strip() - ) - session.add(ec_project) - session.commit() + ec_project = model.ECProject( + nren=nren_dict[nren_name], + nren_id=nren_dict[nren_name].id, + year=2022, + project=str(val).strip() + ) + db.session.add(ec_project) + db.session.commit() -def _cli(config): - helpers.init_db(config) - transfer_budget() - transfer_funding_sources() - transfer_staff_data() - transfer_nren_parent_org() - transfer_nren_sub_org() - transfer_charging_structure() - transfer_ec_projects() +def _cli(config, app): + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + transfer_budget(nren_dict) + transfer_funding_sources(nren_dict) + transfer_staff_data(nren_dict) + transfer_nren_parent_org(nren_dict) + transfer_nren_sub_org(nren_dict) + transfer_charging_structure(nren_dict) + transfer_ec_projects(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) - _cli(app_config) + + app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']} + + app = compendium_v2._create_app_with_db(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py index 90236ac1b40e000dc92e7696fc95354564ac8701..cbb7a05e4d90783bc0d486d1e6f5f65af35aa42d 100644 --- a/compendium_v2/publishers/survey_publisher_v1.py +++ b/compendium_v2/publishers/survey_publisher_v1.py @@ -11,11 +11,13 @@ import logging import math import click +from sqlalchemy import select + +import 
compendium_v2 from compendium_v2.environment import setup_logging -from compendium_v2 import db, survey_db from compendium_v2.background_task import parse_excel_data from compendium_v2.config import load -from compendium_v2.db import model +from compendium_v2.db import db, model from compendium_v2.survey_db import model as survey_model from compendium_v2.publishers import helpers @@ -24,215 +26,200 @@ setup_logging() logger = logging.getLogger('survey-publisher-v1') -def db_budget_migration(): - with survey_db.session_scope() as survey_session, \ - db.session_scope() as session: - - nren_dict = helpers.get_uppercase_nren_dict(session) - - # move data from Survey DB budget table - data = survey_session.query(survey_model.Nrens) - for nren in data: - for budget in nren.budgets: - abbrev = nren.abbreviation.upper() - year = budget.year - - if float(budget.budget) > 200: - logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})') +def db_budget_migration(nren_dict): + # move data from Survey DB budget table + data = db.session.scalars(select(survey_model.Nrens)) + for nren in data: + for budget in nren.budgets: + abbrev = nren.abbreviation.upper() + year = budget.year - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue + if float(budget.budget) > 200: + logger.warning(f'Incorrect Data: {abbrev} has budget set >200M EUR for {year}. ({budget.budget})') - budget_entry = model.BudgetEntry( - nren=nren_dict[abbrev], - nren_id=nren_dict[abbrev].id, - budget=float(budget.budget), - year=year - ) - session.merge(budget_entry) - - # Import the data from excel sheet to database - exceldata = parse_excel_data.fetch_budget_excel_data() - - for abbrev, budget, year in exceldata: if abbrev not in nren_dict: logger.warning(f'{abbrev} unknown. Skipping.') continue - if budget > 200: - logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') - budget_entry = model.BudgetEntry( nren=nren_dict[abbrev], nren_id=nren_dict[abbrev].id, - budget=budget, + budget=float(budget.budget), year=year ) - session.merge(budget_entry) - session.commit() - - -def db_funding_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_funding_excel_data() - - for (abbrev, year, client_institution, - european_funding, - gov_public_bodies, - commercial, other) in data: - - _data = [client_institution, european_funding, gov_public_bodies, commercial, other] - total = sum(_data) - if not math.isclose(total, 100, abs_tol=0.01) and total != 0: - logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') - - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - - budget_entry = model.FundingSource( - nren=nren_dict[abbrev], - nren_id=nren_dict[abbrev].id, - year=year, - client_institutions=client_institution, - european_funding=european_funding, - gov_public_bodies=gov_public_bodies, - commercial=commercial, - other=other) - session.merge(budget_entry) - session.commit() - - -def db_charging_structure_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - # Import the data to database - data = parse_excel_data.fetch_charging_structure_excel_data() - - for (abbrev, year, charging_structure) in data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. 
Skipping.') - continue - - charging_structure_entry = model.ChargingStructure( - nren=nren_dict[abbrev], - nren_id=nren_dict[abbrev].id, - year=year, - fee_type=charging_structure - ) - session.merge(charging_structure_entry) - session.commit() - - -def db_staffing_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - staff_data = parse_excel_data.fetch_staffing_excel_data() - - nren_staff_map = {} - for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff data.') - continue - - nren = nren_dict[abbrev] + db.session.merge(budget_entry) + + # Import the data from excel sheet to database + exceldata = parse_excel_data.fetch_budget_excel_data() + + for abbrev, budget, year in exceldata: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + if budget > 200: + logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') + + budget_entry = model.BudgetEntry( + nren=nren_dict[abbrev], + nren_id=nren_dict[abbrev].id, + budget=budget, + year=year + ) + db.session.merge(budget_entry) + db.session.commit() + + +def db_funding_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_funding_excel_data() + + for (abbrev, year, client_institution, + european_funding, + gov_public_bodies, + commercial, other) in data: + + _data = [client_institution, european_funding, gov_public_bodies, commercial, other] + total = sum(_data) + if not math.isclose(total, 100, abs_tol=0.01) and total != 0: + logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') + + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + budget_entry = model.FundingSource( + nren=nren_dict[abbrev], + nren_id=nren_dict[abbrev].id, + year=year, + client_institutions=client_institution, + european_funding=european_funding, + gov_public_bodies=gov_public_bodies, + commercial=commercial, + other=other) + db.session.merge(budget_entry) + db.session.commit() + + +def db_charging_structure_migration(nren_dict): + # Import the data to database + data = parse_excel_data.fetch_charging_structure_excel_data() + + for (abbrev, year, charging_structure) in data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + + charging_structure_entry = model.ChargingStructure( + nren=nren_dict[abbrev], + nren_id=nren_dict[abbrev].id, + year=year, + fee_type=charging_structure + ) + db.session.merge(charging_structure_entry) + db.session.commit() + + +def db_staffing_migration(nren_dict): + staff_data = parse_excel_data.fetch_staffing_excel_data() + + nren_staff_map = {} + for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping staff data.') + continue + + nren = nren_dict[abbrev] + nren_staff_map[(nren.id, year)] = model.NrenStaff( + nren=nren, + nren_id=nren.id, + year=year, + permanent_fte=permanent_fte, + subcontracted_fte=subcontracted_fte, + technical_fte=0, + non_technical_fte=0 + ) + + function_data = parse_excel_data.fetch_staff_function_excel_data() + for (abbrev, year, technical_fte, non_technical_fte) in function_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. 
Skipping staff function data.') + continue + + nren = nren_dict[abbrev] + if (nren.id, year) in nren_staff_map: + nren_staff_map[(nren.id, year)].technical_fte = technical_fte + nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte + else: nren_staff_map[(nren.id, year)] = model.NrenStaff( nren=nren, nren_id=nren.id, year=year, - permanent_fte=permanent_fte, - subcontracted_fte=subcontracted_fte, - technical_fte=0, - non_technical_fte=0 + permanent_fte=0, + subcontracted_fte=0, + technical_fte=technical_fte, + non_technical_fte=non_technical_fte ) - function_data = parse_excel_data.fetch_staff_function_excel_data() - for (abbrev, year, technical_fte, non_technical_fte) in function_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping staff function data.') - continue + for nren_staff_model in nren_staff_map.values(): + employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte + technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte + if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: + logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' + f' FTE do not equal across employed/technical categories ({employed} != {technical})') - nren = nren_dict[abbrev] - if (nren.id, year) in nren_staff_map: - nren_staff_map[(nren.id, year)].technical_fte = technical_fte - nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte - else: - nren_staff_map[(nren.id, year)] = model.NrenStaff( - nren=nren, - nren_id=nren.id, - year=year, - permanent_fte=0, - subcontracted_fte=0, - technical_fte=technical_fte, - non_technical_fte=non_technical_fte - ) - - for nren_staff_model in nren_staff_map.values(): - employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte - technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte - if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0: - logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:' - f' FTE do not equal across employed/technical categories ({employed} != {technical})') - - session.merge(nren_staff_model) - - session.commit() - - -def db_ecprojects_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) - - ecproject_data = parse_excel_data.fetch_ecproject_excel_data() - for (abbrev, year, project) in ecproject_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue + db.session.merge(nren_staff_model) - nren = nren_dict[abbrev] - ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) - session.merge(ecproject_entry) - session.commit() + db.session.commit() -def db_organizations_migration(): - with db.session_scope() as session: - nren_dict = helpers.get_uppercase_nren_dict(session) +def db_ecprojects_migration(nren_dict): + ecproject_data = parse_excel_data.fetch_ecproject_excel_data() + for (abbrev, year, project) in ecproject_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. 
Skipping.') + continue + + nren = nren_dict[abbrev] + ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) + db.session.merge(ecproject_entry) + db.session.commit() - organization_data = parse_excel_data.fetch_organization_excel_data() - for (abbrev, year, org) in organization_data: - if abbrev not in nren_dict: - logger.warning(f'{abbrev} unknown. Skipping.') - continue - nren = nren_dict[abbrev] - org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) - session.merge(org_entry) - session.commit() +def db_organizations_migration(nren_dict): + organization_data = parse_excel_data.fetch_organization_excel_data() + for (abbrev, year, org) in organization_data: + if abbrev not in nren_dict: + logger.warning(f'{abbrev} unknown. Skipping.') + continue + nren = nren_dict[abbrev] + org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) + db.session.merge(org_entry) + db.session.commit() -def _cli(config): - helpers.init_db(config) - db_budget_migration() - db_funding_migration() - db_charging_structure_migration() - db_staffing_migration() - db_ecprojects_migration() - db_organizations_migration() + +def _cli(config, app): + with app.app_context(): + nren_dict = helpers.get_uppercase_nren_dict() + db_budget_migration(nren_dict) + db_funding_migration(nren_dict) + db_charging_structure_migration(nren_dict) + db_staffing_migration(nren_dict) + db_ecprojects_migration(nren_dict) + db_organizations_migration(nren_dict) @click.command() @click.option('--config', type=click.STRING, default='config.json') def cli(config): app_config = load(open(config, 'r')) + + app_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: app_config['SURVEY_DATABASE_URI']} + + app = compendium_v2._create_app_with_db(app_config) print("survey-publisher-v1 starting") - _cli(app_config) + _cli(app_config, app) if __name__ == "__main__": diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py index 763b3af37256811613edd300ff6a354f0c7b2e2a..1f4ff8509e63ad8cbcb4f030c578fe2621d9bdd8 100644 --- a/compendium_v2/routes/budget.py +++ b/compendium_v2/routes/budget.py @@ -1,28 +1,17 @@ import logging from typing import Any -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import BudgetEntry from compendium_v2.routes import common -routes = Blueprint('budget', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('budget', __name__) logger = logging.getLogger(__name__) -col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061', - '#bd7ebe', '#ffb55a', '#ffee65', - '#beb9db', '#fdcce5', '#8bd3c7'] - BUDGET_RESPONSE_SCHEMA = { '$schema': 'http://json-schema.org/draft-07/schema#', @@ -58,15 +47,15 @@ def budget_view() -> Any: :return: """ - def _extract_data(entry: model.BudgetEntry): + def _extract_data(entry: BudgetEntry): return { 'NREN': entry.nren.name, 'BUDGET': float(entry.budget), 'BUDGET_YEAR': entry.year, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.BudgetEntry)], - key=lambda d: (d['BUDGET_YEAR'], d['NREN'])) + entries = sorted( + [_extract_data(entry) for entry in 
db.session.scalars(select(BudgetEntry))], + key=lambda d: (d['BUDGET_YEAR'], d['NREN']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py index 0b2c238414bda0f4460762bb48a27464a2101f3c..57e8114fcd64f07c3ffbd7a09cc9d59900b86c8d 100644 --- a/compendium_v2/routes/charging.py +++ b/compendium_v2/routes/charging.py @@ -1,21 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('charging', __name__) - +from flask import Blueprint, jsonify +from sqlalchemy import select -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ChargingStructure +from compendium_v2.routes import common +routes = Blueprint('charging', __name__) logger = logging.getLogger(__name__) CHARGING_STRUCTURE_RESPONSE_SCHEMA = { @@ -53,16 +47,15 @@ def charging_structure_view() -> Any: :return: """ - def _extract_data(entry: model.ChargingStructure): + def _extract_data(entry: ChargingStructure): return { 'NREN': entry.nren.name, 'YEAR': int(entry.year), 'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None, } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.ChargingStructure) - .all()], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py index 7114718dadf3ffdf614f96e475c9201803cdd549..b58d8931cbe881f7d2d77a9468e6d6afb1126855 100644 --- a/compendium_v2/routes/ec_projects.py +++ b/compendium_v2/routes/ec_projects.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('ec-projects', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ECProject +from compendium_v2.routes import common +routes = Blueprint('ec-projects', __name__) logger = logging.getLogger(__name__) EC_PROJECTS_RESPONSE_SCHEMA = { @@ -55,13 +48,12 @@ def ec_projects_view() -> Any: :return: """ - def _extract_project(entry: model.ECProject): + def _extract_project(entry: ECProject): return { 'nren': entry.nren.name, 'year': entry.year, 'project': entry.project } - with db.session_scope() as session: - result = [_extract_project(project) for project in session.query(model.ECProject)] + result = [_extract_project(project) for project in db.session.scalars(select(ECProject))] return jsonify(result) diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py index ed0e26c8809b0b3f8187af5a8afe25ed2ec16010..c02bf136e8ec2173dcd1ce3525ffc6375e05e167 100644 --- a/compendium_v2/routes/funding.py +++ b/compendium_v2/routes/funding.py @@ -1,22 +1,15 @@ import logging -from flask import 
Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db from compendium_v2.routes import common -from compendium_v2.db import model +from compendium_v2.db import db +from compendium_v2.db.model import FundingSource from typing import Any -routes = Blueprint('funding', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('funding', __name__) logger = logging.getLogger(__name__) FUNDING_RESPONSE_SCHEMA = { @@ -59,7 +52,7 @@ def funding_source_view() -> Any: :return: """ - def _extract_data(entry: model.FundingSource): + def _extract_data(entry: FundingSource): return { 'NREN': entry.nren.name, 'YEAR': entry.year, @@ -70,8 +63,8 @@ def funding_source_view() -> Any: 'OTHER': float(entry.other) } - with db.session_scope() as session: - entries = sorted([_extract_data(entry) - for entry in session.query(model.FundingSource)], - key=lambda d: (d['NREN'], d['YEAR'])) + entries = sorted( + [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))], + key=lambda d: (d['NREN'], d['YEAR']) + ) return jsonify(entries) diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py index 61a43354cbfe53dfc0c9bb7723417a46b134e26e..8e8ebc8dbe34473be5082df0ce0ec9e495ca0608 100644 --- a/compendium_v2/routes/organization.py +++ b/compendium_v2/routes/organization.py @@ -1,22 +1,15 @@ import logging - -from flask import Blueprint, jsonify, current_app - -from compendium_v2 import db -from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('organization', __name__) +from flask import Blueprint, jsonify +from sqlalchemy import select - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) +from compendium_v2.db import db +from compendium_v2.db.model import ParentOrganization, SubOrganization +from compendium_v2.routes import common +routes = Blueprint('organization', __name__) logger = logging.getLogger(__name__) ORGANIZATION_RESPONSE_SCHEMA = { @@ -71,15 +64,14 @@ def parent_organization_view() -> Any: :return: """ - def _extract_parent(entry: model.ParentOrganization): + def _extract_parent(entry: ParentOrganization): return { 'nren': entry.nren.name, 'year': entry.year, 'name': entry.organization } - with db.session_scope() as session: - result = [_extract_parent(org) for org in session.query(model.ParentOrganization)] + result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))] return jsonify(result) @@ -98,7 +90,7 @@ def sub_organization_view() -> Any: :return: """ - def _extract_sub(entry: model.SubOrganization): + def _extract_sub(entry: SubOrganization): return { 'nren': entry.nren.name, 'year': entry.year, @@ -106,6 +98,5 @@ def sub_organization_view() -> Any: 'role': entry.role } - with db.session_scope() as session: - result = [_extract_sub(org) for org in session.query(model.SubOrganization)] + result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))] return jsonify(result) diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py index 73e79c4836bb42e2959ae41a22c1b094946e0f29..b8478266f3ec653d2d79375e6e7baadd512a6b2e 100644 --- a/compendium_v2/routes/staff.py +++ 
b/compendium_v2/routes/staff.py @@ -1,22 +1,15 @@ import logging -from flask import Blueprint, jsonify, current_app +from flask import Blueprint, jsonify +from sqlalchemy import select -from compendium_v2 import db +from compendium_v2.db import db +from compendium_v2.db.model import NREN, NrenStaff from compendium_v2.routes import common -from compendium_v2.db import model from typing import Any -routes = Blueprint('staff', __name__) - - -@routes.before_request -def before_request(): - config = current_app.config['CONFIG_PARAMS'] - dsn_prn = config['SQLALCHEMY_DATABASE_URI'] - db.init_db_model(dsn_prn) - +routes = Blueprint('staff', __name__) logger = logging.getLogger(__name__) STAFF_RESPONSE_SCHEMA = { @@ -57,7 +50,7 @@ def staff_view() -> Any: :return: """ - def _extract_data(entry: model.NrenStaff): + def _extract_data(entry: NrenStaff): return { 'nren': entry.nren.name, 'year': entry.year, @@ -67,7 +60,6 @@ def staff_view() -> Any: 'non_technical_fte': float(entry.non_technical_fte) } - with db.session_scope() as session: - entries = [_extract_data(entry) for entry in session.query( - model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())] + entries = [_extract_data(entry) for entry in db.session.scalars( + select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))] return jsonify(entries) diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py index ce48aaa07a269ae2c20235b31c15f58cf8ff0956..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/compendium_v2/survey_db/__init__.py +++ b/compendium_v2/survey_db/__init__.py @@ -1,45 +0,0 @@ -import contextlib -import logging -from typing import Optional, Union, Callable, Iterator - -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker, Session - -logger = logging.getLogger(__name__) -_SESSION_MAKER: Union[None, sessionmaker] = None - - -@contextlib.contextmanager -def session_scope( - callback_before_close: Optional[Callable] = None) -> Iterator[Session]: - # best practice is to keep session scope separate from data processing - # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html - - assert _SESSION_MAKER - session = _SESSION_MAKER() - try: - yield session - session.commit() - if callback_before_close: - callback_before_close() - except SQLAlchemyError: - logger.error('caught sql layer exception, rolling back') - session.rollback() - raise # re-raise, will be handled by main consumer - finally: - session.close() - - -def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432): - return (f'postgresql://{db_username}:{db_password}' - f'@{db_hostname}:{port}/{db_name}') - - -def init_db_model(dsn): - global _SESSION_MAKER - - # cf. 
-    # cf. https://docs.sqlalchemy.org/en
-    # /latest/orm/extensions/automap.html
-    engine = create_engine(dsn, pool_size=10)
-    _SESSION_MAKER = sessionmaker(bind=engine)
diff --git a/compendium_v2/survey_db/model.py b/compendium_v2/survey_db/model.py
index a908aa201aaae8c194614ece4ab067aa1664a43b..605cf2d5d6fa8ec4b45106838d002cbed08919aa 100644
--- a/compendium_v2/survey_db/model.py
+++ b/compendium_v2/survey_db/model.py
@@ -1,17 +1,23 @@
 import logging
 from typing import List, Optional
 
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 from sqlalchemy.schema import ForeignKey
 
+from compendium_v2.db import db
+
+
 logger = logging.getLogger(__name__)
 
+SURVEY_DB_BIND = 'survey'
 
-class Base(DeclarativeBase):
-    pass
+# Unfortunately flask-sqlalchemy doesn't fully support DeclarativeBase yet.
+# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
+# mypy: disable-error-code="name-defined"
 
 
-class Budgets(Base):
+class Budgets(db.Model):
+    __bind_key__ = SURVEY_DB_BIND
     __tablename__ = 'budgets'
     id: Mapped[int] = mapped_column(primary_key=True)
     budget: Mapped[Optional[str]]
@@ -20,7 +26,8 @@ class Budgets(Base):
     nren: Mapped[Optional['Nrens']] = relationship(back_populates='budgets')
 
 
-class Nrens(Base):
+class Nrens(db.Model):
+    __bind_key__ = SURVEY_DB_BIND
     __tablename__ = 'nrens'
     id: Mapped[int] = mapped_column(primary_key=True)
     abbreviation: Mapped[Optional[str]]
diff --git a/requirements.txt b/requirements.txt
index 12ec2076b01b867c7e1c8c08ba7f8e07b9a1ad40..f49a79607d14fd2eed998d4429390cc95acb18b5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,8 @@ click~=8.1
 jsonschema~=4.17
 flask~=2.2
 flask-cors~=3.0
+flask-migrate~=4.0
+flask-sqlalchemy~=3.0
 openpyxl~=3.1
 psycopg2-binary~=2.9
 SQLAlchemy~=2.0
diff --git a/setup.py b/setup.py
index 8425dec4eb77c07b7de44c25df2ab27a846719aa..52acb8cb8ef2714e7f7b93352b260b4ce5b35974 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,8 @@ setup(
         'jsonschema~=4.17',
         'flask~=2.2',
         'flask-cors~=3.0',
+        'flask-migrate~=4.0',
+        'flask-sqlalchemy~=3.0',
         'openpyxl~=3.1',
         'psycopg2-binary~=2.9',
         'SQLAlchemy~=2.0',
diff --git a/test/conftest.py b/test/conftest.py
index 79186212deb7f5bf06fdd61df781349e5574ddff..be9c16e45a59ffa6ac3a593154d3e3dd03a61ac7 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,21 +1,12 @@
-import json
+import csv
 import os
-import tempfile
+import pytest
 import random
 
-import pytest
 import compendium_v2
-from compendium_v2 import db
-from compendium_v2.db import model
-from compendium_v2 import survey_db
+from compendium_v2.db import db, model
 from compendium_v2.survey_db import model as survey_model
 
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import StaticPool
-
-import csv
-
 
 def _test_data_csv(filename):
     data_filename = os.path.join(os.path.dirname(__file__), 'data', filename)
@@ -31,73 +22,29 @@ def dummy_config():
 
 
 @pytest.fixture
-def mocked_survey_db(mocker):
-    engine = create_engine(
-        'sqlite://',
-        connect_args={'check_same_thread': False},
-        poolclass=StaticPool,
-        echo=False)
-    survey_model.Base.metadata.create_all(engine)
-    mocker.patch('compendium_v2.survey_db._SESSION_MAKER', sessionmaker(bind=engine))
-    mocker.patch('compendium_v2.survey_db.init_db_model', lambda dsn: None)
-
-
-@pytest.fixture
-def mocked_db(mocker):
-    # cf. https://stackoverflow.com/a/33057675
-    engine = create_engine(
-        'sqlite://',
-        connect_args={'check_same_thread': False},
-        poolclass=StaticPool,
-        echo=False)
-    model.Base.metadata.create_all(engine)
-    mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine))
-    mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None)
-    mocker.patch('compendium_v2.migrate_database', lambda config: None)
-
-
-@pytest.fixture
-def test_budget_data():
-    with db.session_scope() as session:
+def test_budget_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("BudgetTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
             budget = row["budget"]
             year = row["year"]
-            session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
-
-    with survey_db.session_scope() as session:
-        data = _test_data_csv("BudgetTestData.csv")
-        nrens = set()
-        budgets_data = []
-        for row in data:
-            nren = row["nren"]
-            budget = row["budget"]
-            year = row["year"]
-            country_code = row["nren"]
-
-            nrens.add(nren)
-
-            budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code))
-
-        for nren in nrens:
-            session.add(survey_model.Nrens(abbreviation=nren, country_code=nren))
-
-        session.add_all(budgets_data)
+            db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
+        db.session.commit()
 
 
 @pytest.fixture
-def test_funding_source_data():
-    with db.session_scope() as session:
+def test_funding_source_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("FundingSourceTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -108,7 +55,7 @@ def test_funding_source_data():
             commercial = row["commercial"]
             other = row["other"]
 
-            session.add(
+            db.session.add(
                 model.FundingSource(
                     nren=nren, year=year,
                     client_institutions=client,
@@ -117,10 +64,11 @@ def test_funding_source_data():
                     commercial=commercial,
                     other=other)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_staff_data():
+def test_staff_data(app):
 
     # generator of random test data for 5 years and 100 nrens
    def _generate_rows():
@@ -135,12 +83,12 @@ def test_staff_data():
                 "non_technical_fte": random.randint(0, 100)
             }
 
-    with db.session_scope() as session:
+    with app.app_context():
         data = list(_generate_rows())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in [d['nren'] for d in data]}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -150,7 +98,7 @@ def test_staff_data():
             technical_fte = row["technical_fte"]
             non_technical_fte = row["non_technical_fte"]
 
-            session.add(
+            db.session.add(
                 model.NrenStaff(
                     nren=nren,
                     year=year,
@@ -160,30 +108,38 @@ def test_staff_data():
                     non_technical_fte=non_technical_fte
                 )
             )
+        db.session.commit()
+
+
+@pytest.fixture
+def app(dummy_config):
+    app = compendium_v2._create_app_with_db(dummy_config)
+    with app.app_context():
+        db.create_all(bind_key=None)
+        yield app
 
 
 @pytest.fixture
-def data_config_filename(dummy_config):
-    with tempfile.NamedTemporaryFile() as f:
-        f.write(json.dumps(dummy_config).encode('utf-8'))
-        f.flush()
-        yield f.name
+def app_with_survey_db(dummy_config):
+    dummy_config['SQLALCHEMY_BINDS'] = {survey_model.SURVEY_DB_BIND: dummy_config['SURVEY_DATABASE_URI']}
+    app = compendium_v2._create_app_with_db(dummy_config)
+    with app.app_context():
+        db.create_all()
+        yield app
 
 
 @pytest.fixture
-def client(data_config_filename, mocked_db, mocked_survey_db):
-    os.environ['SETTINGS_FILENAME'] = data_config_filename
-    with compendium_v2.create_app().test_client() as c:
-        yield c
+def client(app):
+    return app.test_client()
 
 
 @pytest.fixture
-def test_charging_structure_data():
-    with db.session_scope() as session:
+def test_charging_structure_data(app):
+    with app.app_context():
         data = [row for row in _test_data_csv("ChargingStructureTestData.csv")]
         nren_names = set([row["nren"] for row in data])
         nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for row in data:
             nren = nren_dict[row["nren"]]
@@ -192,15 +148,16 @@ def test_charging_structure_data():
             if fee_type == "null":
                 fee_type = None
 
-            session.add(
+            db.session.add(
                 model.ChargingStructure(
                     nren=nren,
                     year=year,
                     fee_type=fee_type)
             )
+        db.session.commit()
 
 
 @pytest.fixture
-def test_organization_data():
+def test_organization_data(app):
 
     def _generate_sub_org_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -220,21 +177,21 @@ def test_organization_data():
                     'name': 'org' + str(year)
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
         org_data = list(_generate_org_data())
         sub_org_data = list(_generate_sub_org_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for org in org_data:
             nren = nren_dict[org["nren"]]
             year = org["year"]
             name = org["name"]
 
-            session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
+            db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
 
         for sub_org in sub_org_data:
             nren = nren_dict[sub_org["nren"]]
@@ -242,13 +199,13 @@ def test_organization_data():
             name = sub_org["name"]
             role = sub_org["role"]
 
-            session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
+            db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
 
-        session.commit()
+        db.session.commit()
 
 
 @pytest.fixture
-def test_ec_project_data():
+def test_ec_project_data(app):
 
     def _generate_ec_project_data():
         for nren in ["nren" + str(i) for i in range(1, 50)]:
             for year in range(2016, 2021):
@@ -264,19 +221,19 @@ def test_ec_project_data():
                     'project': 'ec_project2',
                 }
 
-    with db.session_scope() as session:
+    with app.app_context():
         ec_project_data = list(_generate_ec_project_data())
 
         nren_dict = {nren_name: model.NREN(name=nren_name)
                      for nren_name in set(d['nren'] for d in ec_project_data)}
-        session.add_all(nren_dict.values())
+        db.session.add_all(nren_dict.values())
 
         for ec_project in ec_project_data:
             nren = nren_dict[ec_project["nren"]]
             year = ec_project["year"]
             project = ec_project["project"]
 
-            session.add(model.ECProject(nren=nren, year=year, project=project))
+            db.session.add(model.ECProject(nren=nren, year=year, project=project))
 
-        session.commit()
+        db.session.commit()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index ec8802edec3d0a3f16d7e62cb01277dcce554ed8..72267691c3e2e5568c77cb9ea3c38b401642962b 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -1,5 +1,6 @@
-from compendium_v2 import db
-from compendium_v2.db import model
+from sqlalchemy import select
+
+from compendium_v2.db import db, model
 from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
     StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
 
@@ -109,7 +110,7 @@ org_data
 KTU,"NOC, administrative authority"
 ]
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(app_with_survey_db, mocker, dummy_config):
     global org_data
 
     def get_rows_as_tuples(*args, **kwargs):
@@ -186,19 +187,24 @@ def test_publisher(client, mocker, dummy_config):
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data)
     mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
 
-    with db.session_scope() as session:
-        nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+    nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
+    with app_with_survey_db.app_context():
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app_with_survey_db)
 
-    with db.session_scope() as session:
-        budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
+    with app_with_survey_db.app_context():
+        budgets = db.session.scalars(
+            select(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc())
+        ).all()
         assert len(budgets) == 3
         assert budgets[0].nren.name.lower() == 'nren1'
         assert budgets[0].budget == 100
 
-        funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
+        funding_sources = db.session.scalars(
+            select(model.FundingSource).order_by(model.FundingSource.nren_id.asc())
+        ).all()
         assert len(funding_sources) == 3
         assert funding_sources[0].nren.name.lower() == 'nren1'
         assert funding_sources[0].client_institutions == 10
@@ -215,7 +221,9 @@ def test_publisher(client, mocker, dummy_config):
         assert funding_sources[2].european_funding == 30
         assert funding_sources[2].other == 30
 
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
+        staff_data = db.session.scalars(
+            select(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc())
+        ).all()
 
         assert len(staff_data) == 3
         assert staff_data[0].nren.name.lower() == 'nren1'
@@ -236,7 +244,9 @@ def test_publisher(client, mocker, dummy_config):
         assert staff_data[2].permanent_fte == 30
         assert staff_data[2].subcontracted_fte == 0
 
-        _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
+        _org_data = db.session.scalars(
+            select(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc())
+        ).all()
 
         assert len(_org_data) == 2
         assert _org_data[0].nren.name.lower() == 'nren1'
@@ -245,8 +255,9 @@ def test_publisher(client, mocker, dummy_config):
         assert _org_data[1].nren.name.lower() == 'nren3'
         assert _org_data[1].organization == 'Org3'
 
-        charging_structures = session.query(model.ChargingStructure).order_by(
-            model.ChargingStructure.nren_id.asc()).all()
+        charging_structures = db.session.scalars(
+            select(model.ChargingStructure).order_by(model.ChargingStructure.nren_id.asc())
+        ).all()
         assert len(charging_structures) == 3
         assert charging_structures[0].nren.name.lower() == 'nren1'
         assert charging_structures[0].fee_type == model.FeeType.no_charge
@@ -255,7 +266,9 @@ def test_publisher(client, mocker, dummy_config):
         assert charging_structures[2].nren.name.lower() == 'nren3'
         assert charging_structures[2].fee_type == model.FeeType.other
 
-        _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
+        _ec_data = db.session.scalars(
+            select(model.ECProject).order_by(model.ECProject.nren_id.asc())
+        ).all()
 
         assert len(_ec_data) == 3
         assert _ec_data[0].nren.name.lower() == 'nren2'
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index c7ebc2d909e15f3be872d01d52a30ebba94af024..496cf71be61a4e3a683c8ae9b1d8cb9e1980881a 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -1,5 +1,7 @@
 import os
 
+from sqlalchemy import select, func
+
 from compendium_v2 import db
 from compendium_v2.db import model
 from compendium_v2.publishers.survey_publisher_v1 import _cli
@@ -7,23 +9,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
 EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
 
 
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(app_with_survey_db, mocker, dummy_config):
     mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
 
-    with db.session_scope() as session:
+    with app_with_survey_db.app_context():
         nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
-        session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+        db.session.commit()
 
-    _cli(dummy_config)
+    _cli(dummy_config, app_with_survey_db)
 
-    with db.session_scope() as session:
-        budget_count = session.query(model.BudgetEntry.year).count()
+    with app_with_survey_db.app_context():
+        budget_count = db.session.scalar(select(func.count(model.BudgetEntry.year)))
         assert budget_count
 
-        funding_source_count = session.query(model.FundingSource.year).count()
+        funding_source_count = db.session.scalar(select(func.count(model.FundingSource.year)))
         assert funding_source_count
 
-        charging_structure_count = session.query(model.ChargingStructure.year).count()
+        charging_structure_count = db.session.scalar(select(func.count(model.ChargingStructure.year)))
         assert charging_structure_count
 
-        staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
+        staff_data = db.session.scalars(select(model.NrenStaff).order_by(model.NrenStaff.year.asc())).all()
 
         # data should only be saved for the NRENs we have saved in the database
         staff_data_nrens = set([staff.nren.name for staff in staff_data])
@@ -69,7 +72,7 @@ def test_publisher(client, mocker, dummy_config):
         assert kifu_data[5].technical_fte == 133
         assert kifu_data[5].non_technical_fte == 45
 
-        ecproject_data = session.query(model.ECProject).all()
+        ecproject_data = db.session.scalars(select(model.ECProject)).all()
         # test a couple of random entries
         surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
         assert len(surf2017) == 1
@@ -83,7 +86,7 @@ def test_publisher(client, mocker, dummy_config):
         assert len(kifu2019) == 4
         assert kifu2019[3].project == 'SuperHeroes for Science'
 
-        parent_data = session.query(model.ParentOrganization).all()
+        parent_data = db.session.scalars(select(model.ParentOrganization)).all()
         # test a random entry
         asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
         assert len(asnet2021) == 1