From 6e62575f75375af493a76a394264a293cb80c7b8 Mon Sep 17 00:00:00 2001
From: Remco Tukker <remco.tukker@geant.org>
Date: Thu, 4 May 2023 10:00:14 +0200
Subject: [PATCH] use flask-sqlalchemy for the main db
---
compendium_v2/__init__.py | 12 +-
compendium_v2/db/__init__.py | 50 +-
compendium_v2/db/model.py | 48 +-
compendium_v2/migrations/env.py | 2 +-
compendium_v2/migrations/migration_utils.py | 8 +-
compendium_v2/publishers/helpers.py | 13 +-
.../publishers/survey_publisher_2022.py | 485 +++++++++---------
.../publishers/survey_publisher_v1.py | 272 +++++-----
compendium_v2/routes/budget.py | 31 +-
compendium_v2/routes/charging.py | 29 +-
compendium_v2/routes/ec_projects.py | 24 +-
compendium_v2/routes/funding.py | 27 +-
compendium_v2/routes/organization.py | 29 +-
compendium_v2/routes/staff.py | 24 +-
compendium_v2/survey_db/__init__.py | 5 -
requirements.txt | 1 +
setup.py | 1 +
test/conftest.py | 122 ++---
test/test_survey_publisher_2022.py | 28 +-
test/test_survey_publisher_v1.py | 23 +-
20 files changed, 545 insertions(+), 689 deletions(-)
diff --git a/compendium_v2/__init__.py b/compendium_v2/__init__.py
index 622565f2..7a69bf9e 100644
--- a/compendium_v2/__init__.py
+++ b/compendium_v2/__init__.py
@@ -8,7 +8,7 @@ from flask import Flask
from flask_cors import CORS # for debugging
from compendium_v2 import config, environment
-
+from compendium_v2.db import db
from compendium_v2.migrations import migration_utils
@@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask:
return app
+def _create_app_with_db(app_config) -> Flask:
+ # used by the tests and the publishers
+ app = _create_app(app_config)
+ app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI']
+ db.init_app(app)
+ return app
+
+
def create_app() -> Flask:
"""
overrides default settings with those found
@@ -46,7 +54,7 @@ def create_app() -> Flask:
with open(os.environ['SETTINGS_FILENAME']) as f:
app_config = config.load(f)
- app = _create_app(app_config)
+ app = _create_app_with_db(app_config)
logging.info('Flask app initialized')
diff --git a/compendium_v2/db/__init__.py b/compendium_v2/db/__init__.py
index ce48aaa0..13719e0a 100644
--- a/compendium_v2/db/__init__.py
+++ b/compendium_v2/db/__init__.py
@@ -1,45 +1,17 @@
-import contextlib
import logging
-from typing import Optional, Union, Callable, Iterator
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker, Session
+from flask_sqlalchemy import SQLAlchemy
+from sqlalchemy import MetaData
-logger = logging.getLogger(__name__)
-_SESSION_MAKER: Union[None, sessionmaker] = None
-
-
-@contextlib.contextmanager
-def session_scope(
- callback_before_close: Optional[Callable] = None) -> Iterator[Session]:
- # best practice is to keep session scope separate from data processing
- # cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html
-
- assert _SESSION_MAKER
- session = _SESSION_MAKER()
- try:
- yield session
- session.commit()
- if callback_before_close:
- callback_before_close()
- except SQLAlchemyError:
- logger.error('caught sql layer exception, rolling back')
- session.rollback()
- raise # re-raise, will be handled by main consumer
- finally:
- session.close()
-
-
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
- return (f'postgresql://{db_username}:{db_password}'
- f'@{db_hostname}:{port}/{db_name}')
+logger = logging.getLogger(__name__)
-def init_db_model(dsn):
- global _SESSION_MAKER
+metadata_obj = MetaData(naming_convention={
+ "ix": "ix_%(column_0_label)s",
+ "uq": "uq_%(table_name)s_%(column_0_name)s",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(table_name)s",
+})
- # cf. https://docs.sqlalchemy.org/en
- # /latest/orm/extensions/automap.html
- engine = create_engine(dsn, pool_size=10)
- _SESSION_MAKER = sessionmaker(bind=engine)
+db = SQLAlchemy(metadata=metadata_obj)
diff --git a/compendium_v2/db/model.py b/compendium_v2/db/model.py
index 67672743..2199f759 100644
--- a/compendium_v2/db/model.py
+++ b/compendium_v2/db/model.py
@@ -4,42 +4,34 @@ from enum import Enum
from typing import Optional
from typing_extensions import Annotated
-from sqlalchemy import MetaData, String
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy import String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.schema import ForeignKey
+from compendium_v2.db import db
-logger = logging.getLogger(__name__)
-
-convention = {
- "ix": "ix_%(column_0_label)s",
- "uq": "uq_%(table_name)s_%(column_0_name)s",
- "ck": "ck_%(table_name)s_%(constraint_name)s",
- "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
- "pk": "pk_%(table_name)s",
-}
-metadata_obj = MetaData(naming_convention=convention)
+logger = logging.getLogger(__name__)
str128 = Annotated[str, 128]
+str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)]
+str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)]
int_pk = Annotated[int, mapped_column(primary_key=True)]
int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)]
-class Base(DeclarativeBase):
- metadata = metadata_obj
- type_annotation_map = {
- str128: String(128),
- }
+# Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet.
+# See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
+# mypy: disable-error-code="name-defined"
-class NREN(Base):
+class NREN(db.Model):
__tablename__ = 'nren'
id: Mapped[int_pk]
name: Mapped[str128]
-class BudgetEntry(Base):
+class BudgetEntry(db.Model):
__tablename__ = 'budgets'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
@@ -47,7 +39,7 @@ class BudgetEntry(Base):
budget: Mapped[Decimal]
-class FundingSource(Base):
+class FundingSource(db.Model):
__tablename__ = 'funding_source'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
@@ -67,7 +59,7 @@ class FeeType(Enum):
other = "other"
-class ChargingStructure(Base):
+class ChargingStructure(db.Model):
__tablename__ = 'charging_structure'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
@@ -75,7 +67,7 @@ class ChargingStructure(Base):
fee_type: Mapped[Optional[FeeType]]
-class NrenStaff(Base):
+class NrenStaff(db.Model):
__tablename__ = 'nren_staff'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
@@ -86,7 +78,7 @@ class NrenStaff(Base):
non_technical_fte: Mapped[Decimal]
-class ParentOrganization(Base):
+class ParentOrganization(db.Model):
__tablename__ = 'parent_organization'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
@@ -94,18 +86,18 @@ class ParentOrganization(Base):
organization: Mapped[str128]
-class SubOrganization(Base):
+class SubOrganization(db.Model):
__tablename__ = 'sub_organization'
nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined')
year: Mapped[int_pk]
- organization: Mapped[str128] = mapped_column(primary_key=True)
+ organization: Mapped[str128_pk]
role: Mapped[str128]
-class ECProject(Base):
+class ECProject(db.Model):
__tablename__ = 'ec_project'
nren_id: Mapped[int_pk_fkNREN]
- nren: Mapped[NREN] = relationship(NREN, lazy='joined')
+ nren: Mapped[NREN] = relationship(lazy='joined')
year: Mapped[int_pk]
- project: Mapped[str] = mapped_column(String(256), primary_key=True)
+ project: Mapped[str256_pk]
diff --git a/compendium_v2/migrations/env.py b/compendium_v2/migrations/env.py
index 0307be33..5ea9c8d3 100644
--- a/compendium_v2/migrations/env.py
+++ b/compendium_v2/migrations/env.py
@@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
-from compendium_v2.db.model import metadata_obj
+from compendium_v2.db import metadata_obj
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
diff --git a/compendium_v2/migrations/migration_utils.py b/compendium_v2/migrations/migration_utils.py
index 3b29b540..f25f03c9 100644
--- a/compendium_v2/migrations/migration_utils.py
+++ b/compendium_v2/migrations/migration_utils.py
@@ -1,7 +1,6 @@
import logging
import os
-from compendium_v2 import db
from alembic.config import Config
from alembic import command
@@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY):
command.upgrade(alembic_config, 'head')
+def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
+ return (f'postgresql://{db_username}:{db_password}'
+ f'@{db_hostname}:{port}/{db_name}')
+
+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
- upgrade(db.postgresql_dsn(
+ upgrade(postgresql_dsn(
db_username='compendium',
db_password='compendium321',
db_hostname='localhost',
diff --git a/compendium_v2/publishers/helpers.py b/compendium_v2/publishers/helpers.py
index d95ad4e1..43d1bf02 100644
--- a/compendium_v2/publishers/helpers.py
+++ b/compendium_v2/publishers/helpers.py
@@ -1,21 +1,20 @@
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from sqlalchemy import select
+
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
def init_db(config):
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
dsn_survey = config['SURVEY_DATABASE_URI']
survey_db.init_db_model(dsn_survey)
-def get_uppercase_nren_dict(session):
+def get_uppercase_nren_dict():
"""
- :param session: db session that is used to query the known NRENs
:return: a dictionary of all known NRENs db entities keyed on the
uppercased name
"""
- current_nrens = session.query(model.NREN).all()
+ current_nrens = db.session.scalars(select(model.NREN))
nren_dict = {nren.name.upper(): nren for nren in current_nrens}
# add aliases that are used in the source data:
nren_dict['ASNET'] = nren_dict['ASNET-AM']
diff --git a/compendium_v2/publishers/survey_publisher_2022.py b/compendium_v2/publishers/survey_publisher_2022.py
index 710c899d..ea88ee77 100644
--- a/compendium_v2/publishers/survey_publisher_2022.py
+++ b/compendium_v2/publishers/survey_publisher_2022.py
@@ -16,11 +16,12 @@ import html
from sqlalchemy import text
from collections import defaultdict
+import compendium_v2
from compendium_v2.db.model import FeeType
from compendium_v2.environment import setup_logging
from compendium_v2.config import load
-from compendium_v2 import db, survey_db
-from compendium_v2.db import model
+from compendium_v2 import survey_db
+from compendium_v2.db import db, model
from compendium_v2.publishers import helpers
setup_logging()
@@ -133,163 +134,151 @@ def query_question(question: enum.Enum):
return survey.execute(text(query))
-def transfer_budget():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- rows = query_budget()
- for row in rows:
+def transfer_budget(nren_dict):
+ rows = query_budget()
+ for row in rows:
+ nren_name = row[0].upper()
+ _budget = row[1]
+ try:
+ budget = float(_budget.replace('"', '').replace(',', ''))
+ except ValueError:
+ logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
+ continue
+
+ if budget > 200:
+ logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
+
+ if nren_name not in nren_dict:
+ logger.info(f'{nren_name} unknown. Skipping.')
+ continue
+
+ budget_entry = model.BudgetEntry(
+ nren=nren_dict[nren_name],
+ budget=budget,
+ year=2022,
+ )
+ db.session.merge(budget_entry)
+ db.session.commit()
+
+
+def transfer_funding_sources(nren_dict):
+ sourcedata = {}
+ for source, data in query_funding_sources():
+ for row in data:
nren_name = row[0].upper()
- _budget = row[1]
+ _value = row[1]
try:
- budget = float(_budget.replace('"', '').replace(',', ''))
+ value = float(_value.replace('"', '').replace(',', ''))
except ValueError:
- logger.info(f'{nren_name} has no budget for 2022. Skipping. ({_budget}))')
- continue
-
- if budget > 200:
- logger.info(f'{nren_name} has budget set to >200M EUR for 2022. ({budget})')
-
- if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping.')
- continue
+ name = source.name
+ logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
+ value = 0
- budget_entry = model.BudgetEntry(
- nren=nren_dict[nren_name],
- budget=budget,
- year=2022,
+ nren_info = sourcedata.setdefault(
+ nren_name,
+ {source_type: 0 for source_type in FundingSource}
)
- session.merge(budget_entry)
- session.commit()
-
-
-def transfer_funding_sources():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- sourcedata = {}
- for source, data in query_funding_sources():
- for row in data:
- nren_name = row[0].upper()
- _value = row[1]
- try:
- value = float(_value.replace('"', '').replace(',', ''))
- except ValueError:
- name = source.name
- logger.info(f'{nren_name} has invalid value for {name}. ({_value}))')
- value = 0
-
- nren_info = sourcedata.setdefault(
- nren_name,
- {source_type: 0 for source_type in FundingSource}
- )
- nren_info[source] = value
-
- for nren_name, nren_info in sourcedata.items():
- total = sum(nren_info.values())
-
- if not math.isclose(total, 100, abs_tol=0.01):
- logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+ nren_info[source] = value
+
+ for nren_name, nren_info in sourcedata.items():
+ total = sum(nren_info.values())
+
+ if not math.isclose(total, 100, abs_tol=0.01):
+ logger.info(f'{nren_name} funding sources do not sum to 100%. ({total})')
+
+ if nren_name not in nren_dict:
+ logger.info(f'{nren_name} unknown. Skipping.')
+ continue
+
+ funding_source = model.FundingSource(
+ nren=nren_dict[nren_name],
+ year=2022,
+ client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
+ european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
+ gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
+ commercial=nren_info[FundingSource.COMMERCIAL],
+ other=nren_info[FundingSource.OTHER],
+ )
+ db.session.merge(funding_source)
+ db.session.commit()
+
+
+def transfer_staff_data(nren_dict):
+ data = {}
+ for question in StaffQuestion:
+ rows = query_question(question)
+ for row in rows:
+ nren_name = row[0].upper()
+ _value = row[1]
+ try:
+ value = float(_value.replace('"', '').replace(',', ''))
+ except ValueError:
+ value = 0
if nren_name not in nren_dict:
logger.info(f'{nren_name} unknown. Skipping.')
continue
- funding_source = model.FundingSource(
- nren=nren_dict[nren_name],
- year=2022,
- client_institutions=nren_info[FundingSource.CLIENT_INSTITUTIONS],
- european_funding=nren_info[FundingSource.EUROPEAN_FUNDING],
- gov_public_bodies=nren_info[FundingSource.GOV_PUBLIC_BODIES],
- commercial=nren_info[FundingSource.COMMERCIAL],
- other=nren_info[FundingSource.OTHER],
- )
- session.merge(funding_source)
- session.commit()
-
-
-def transfer_staff_data():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- data = {}
- for question in StaffQuestion:
- rows = query_question(question)
- for row in rows:
- nren_name = row[0].upper()
- _value = row[1]
- try:
- value = float(_value.replace('"', '').replace(',', ''))
- except ValueError:
- value = 0
-
- if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping.')
- continue
-
- # initialize on first use, so we don't add data for nrens with no answers
- data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
- question] = value
-
- for nren_name, nren_info in data.items():
- if sum([nren_info[question] for question in StaffQuestion]) == 0:
- logger.info(f'{nren_name} has no staff data. Deleting if exists.')
- session.query(model.NrenStaff).filter(
- model.NrenStaff.nren_id == nren_dict[nren_name].id,
- model.NrenStaff.year == 2022,
- ).delete()
- continue
-
- employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
- technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
-
- if not math.isclose(employed, technical, abs_tol=0.01):
- logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
- f' ({employed} != {technical})')
-
- staff_data = model.NrenStaff(
- nren_id=nren_dict[nren_name].id,
- year=2022,
- permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
- subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
- technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
- non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
- )
- session.merge(staff_data)
- session.commit()
-
-
-def transfer_nren_parent_org():
+ # initialize on first use, so we don't add data for nrens with no answers
+ data.setdefault(nren_name, {question: 0 for question in StaffQuestion})[
+ question] = value
+
+ for nren_name, nren_info in data.items():
+ if sum([nren_info[question] for question in StaffQuestion]) == 0:
+ logger.info(f'{nren_name} has no staff data. Deleting if exists.')
+ db.session.query(model.NrenStaff).filter(
+ model.NrenStaff.nren_id == nren_dict[nren_name].id,
+ model.NrenStaff.year == 2022,
+ ).delete()
+ continue
+
+ employed = nren_info[StaffQuestion.PERMANENT_FTE] + nren_info[StaffQuestion.SUBCONTRACTED_FTE]
+ technical = nren_info[StaffQuestion.TECHNICAL_FTE] + nren_info[StaffQuestion.NON_TECHNICAL_FTE]
+
+ if not math.isclose(employed, technical, abs_tol=0.01):
+ logger.info(f'{nren_name} FTE do not equal across employed/technical categories.'
+ f' ({employed} != {technical})')
+
+ staff_data = model.NrenStaff(
+ nren_id=nren_dict[nren_name].id,
+ year=2022,
+ permanent_fte=nren_info[StaffQuestion.PERMANENT_FTE],
+ subcontracted_fte=nren_info[StaffQuestion.SUBCONTRACTED_FTE],
+ technical_fte=nren_info[StaffQuestion.TECHNICAL_FTE],
+ non_technical_fte=nren_info[StaffQuestion.NON_TECHNICAL_FTE],
+ )
+ db.session.merge(staff_data)
+ db.session.commit()
+
+
+def transfer_nren_parent_org(nren_dict):
# clean up the data a bit by removing some strings
strings_to_replace = [
'We are affiliated to '
]
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- rows = query_question(OrgQuestion.PARENT_ORG_NAME)
- for row in rows:
- nren_name = row[0].upper()
- value = str(row[1]).replace('"', '')
+ rows = query_question(OrgQuestion.PARENT_ORG_NAME)
+ for row in rows:
+ nren_name = row[0].upper()
+ value = str(row[1]).replace('"', '')
- for string in strings_to_replace:
- value = value.replace(string, '')
+ for string in strings_to_replace:
+ value = value.replace(string, '')
- if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping.')
- continue
+ if nren_name not in nren_dict:
+ logger.info(f'{nren_name} unknown. Skipping.')
+ continue
- parent_org = model.ParentOrganization(
- nren_id=nren_dict[nren_name].id,
- year=2022,
- organization=value,
- )
- session.merge(parent_org)
- session.commit()
+ parent_org = model.ParentOrganization(
+ nren_id=nren_dict[nren_name].id,
+ year=2022,
+ organization=value,
+ )
+ db.session.merge(parent_org)
+ db.session.commit()
-def transfer_nren_sub_org():
+def transfer_nren_sub_org(nren_dict):
suborg_questions = [
(OrgQuestion.SUB_ORGS_1_NAME, OrgQuestion.SUB_ORGS_1_CHOICE, OrgQuestion.SUB_ORGS_1_ROLE),
(OrgQuestion.SUB_ORGS_2_NAME, OrgQuestion.SUB_ORGS_2_CHOICE, OrgQuestion.SUB_ORGS_2_ROLE),
@@ -298,140 +287,134 @@ def transfer_nren_sub_org():
(OrgQuestion.SUB_ORGS_5_NAME, OrgQuestion.SUB_ORGS_5_CHOICE, OrgQuestion.SUB_ORGS_5_ROLE)
]
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- lookup = defaultdict(list)
-
- for name, choice, role in suborg_questions:
- _name_rows = query_question(name)
- _choice_rows = query_question(choice)
- _role_rows = list(query_question(role))
- for _name, _choice in zip(_name_rows, _choice_rows):
- nren_name = _name[0].upper()
- suborg_name = _name[1].replace('"', '').strip()
- role_choice = _choice[1].replace('"', '').strip()
-
- if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping.')
- continue
-
- if role_choice.lower() == 'other':
- for _role in _role_rows:
- if _role[0] == _name[0]:
- role = _role[1].replace('"', '').strip()
- break
- else:
- role = role_choice
-
- lookup[nren_name].append((suborg_name, role))
-
- for nren_name, suborgs in lookup.items():
- for suborg_name, role in suborgs:
- suborg = model.SubOrganization(
- nren_id=nren_dict[nren_name].id,
- year=2022,
- organization=suborg_name,
- role=role,
- )
- session.merge(suborg)
- session.commit()
-
-
-def transfer_charging_structure():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- rows = query_question(ChargingStructure.charging_structure)
- for row in rows:
- nren_name = row[0].upper()
- value = row[1].replace('"', '').strip()
+ lookup = defaultdict(list)
+
+ for name, choice, role in suborg_questions:
+ _name_rows = query_question(name)
+ _choice_rows = query_question(choice)
+ _role_rows = list(query_question(role))
+ for _name, _choice in zip(_name_rows, _choice_rows):
+ nren_name = _name[0].upper()
+ suborg_name = _name[1].replace('"', '').strip()
+ role_choice = _choice[1].replace('"', '').strip()
if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+ logger.info(f'{nren_name} unknown. Skipping.')
continue
- if "do not charge" in value:
- charging_structure = FeeType.no_charge
- elif "combination" in value:
- charging_structure = FeeType.combination
- elif "flat" in value:
- charging_structure = FeeType.flat_fee
- elif "usage-based" in value:
- charging_structure = FeeType.usage_based_fee
- elif "Other" in value:
- charging_structure = FeeType.other
+ if role_choice.lower() == 'other':
+ for _role in _role_rows:
+ if _role[0] == _name[0]:
+ role = _role[1].replace('"', '').strip()
+ break
else:
- charging_structure = None
+ role = role_choice
+
+ lookup[nren_name].append((suborg_name, role))
- charging_structure = model.ChargingStructure(
+ for nren_name, suborgs in lookup.items():
+ for suborg_name, role in suborgs:
+ suborg = model.SubOrganization(
nren_id=nren_dict[nren_name].id,
year=2022,
- fee_type=charging_structure,
+ organization=suborg_name,
+ role=role,
)
- session.merge(charging_structure)
- session.commit()
-
-
-def transfer_ec_projects():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- # delete all existing EC projects, in case something changed
- session.query(model.ECProject).filter(
- model.ECProject.year == 2022,
- ).delete()
-
- rows = query_question(ECQuestion.EC_PROJECT)
- for row in rows:
- nren_name = row[0].upper()
-
- if nren_name not in nren_dict:
- logger.info(f'{nren_name} unknown. Skipping.')
+ db.session.merge(suborg)
+ db.session.commit()
+
+
+def transfer_charging_structure(nren_dict):
+ rows = query_question(ChargingStructure.charging_structure)
+ for row in rows:
+ nren_name = row[0].upper()
+ value = row[1].replace('"', '').strip()
+
+ if nren_name not in nren_dict:
+ logger.info(f'{nren_name} unknown. Skipping from charging structure.')
+ continue
+
+ if "do not charge" in value:
+ charging_structure = FeeType.no_charge
+ elif "combination" in value:
+ charging_structure = FeeType.combination
+ elif "flat" in value:
+ charging_structure = FeeType.flat_fee
+ elif "usage-based" in value:
+ charging_structure = FeeType.usage_based_fee
+ elif "Other" in value:
+ charging_structure = FeeType.other
+ else:
+ charging_structure = None
+
+ charging_structure = model.ChargingStructure(
+ nren_id=nren_dict[nren_name].id,
+ year=2022,
+ fee_type=charging_structure,
+ )
+ db.session.merge(charging_structure)
+ db.session.commit()
+
+
+def transfer_ec_projects(nren_dict):
+ # delete all existing EC projects, in case something changed
+ db.session.query(model.ECProject).filter(
+ model.ECProject.year == 2022,
+ ).delete()
+
+ rows = query_question(ECQuestion.EC_PROJECT)
+ for row in rows:
+ nren_name = row[0].upper()
+
+ if nren_name not in nren_dict:
+ logger.info(f'{nren_name} unknown. Skipping.')
+ continue
+
+ try:
+ value = json.loads(row[1])
+ except json.decoder.JSONDecodeError:
+ logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
+ continue
+
+ for val in value:
+ if not val:
+ logger.info(f'Invalid EC project value for {nren_name}: {val}.')
continue
- try:
- value = json.loads(row[1])
- except json.decoder.JSONDecodeError:
- logger.info(f'JSON decode error for EC project data for {nren_name}. Skipping.')
- continue
-
- for val in value:
- if not val:
- logger.info(f'Invalid EC project value for {nren_name}: {val}.')
- continue
-
- # strip html entities/NBSP from val
- val = html.unescape(val).replace('\xa0', ' ')
+ # strip html entities/NBSP from val
+ val = html.unescape(val).replace('\xa0', ' ')
- # some answers include contract numbers, which we don't want here
- val = val.split('(contract n')[0]
+ # some answers include contract numbers, which we don't want here
+ val = val.split('(contract n')[0]
- ec_project = model.ECProject(
- nren_id=nren_dict[nren_name].id,
- year=2022,
- project=str(val).strip()
- )
- session.add(ec_project)
- session.commit()
+ ec_project = model.ECProject(
+ nren_id=nren_dict[nren_name].id,
+ year=2022,
+ project=str(val).strip()
+ )
+ db.session.add(ec_project)
+ db.session.commit()
-def _cli(config):
+def _cli(config, app):
helpers.init_db(config)
- transfer_budget()
- transfer_funding_sources()
- transfer_staff_data()
- transfer_nren_parent_org()
- transfer_nren_sub_org()
- transfer_charging_structure()
- transfer_ec_projects()
+ with app.app_context():
+ nren_dict = helpers.get_uppercase_nren_dict()
+ transfer_budget(nren_dict)
+ transfer_funding_sources(nren_dict)
+ transfer_staff_data(nren_dict)
+ transfer_nren_parent_org(nren_dict)
+ transfer_nren_sub_org(nren_dict)
+ transfer_charging_structure(nren_dict)
+ transfer_ec_projects(nren_dict)
@click.command()
@click.option('--config', type=click.STRING, default='config.json')
def cli(config):
app_config = load(open(config, 'r'))
- _cli(app_config)
+ app = compendium_v2._create_app_with_db(app_config)
+ _cli(app_config, app)
if __name__ == "__main__":
diff --git a/compendium_v2/publishers/survey_publisher_v1.py b/compendium_v2/publishers/survey_publisher_v1.py
index ba0e4e2a..5f4cd8b2 100644
--- a/compendium_v2/publishers/survey_publisher_v1.py
+++ b/compendium_v2/publishers/survey_publisher_v1.py
@@ -11,11 +11,12 @@ import logging
import math
import click
+import compendium_v2
from compendium_v2.environment import setup_logging
-from compendium_v2 import db, survey_db
+from compendium_v2 import survey_db
from compendium_v2.background_task import parse_excel_data
from compendium_v2.config import load
-from compendium_v2.db import model
+from compendium_v2.db import db, model
from compendium_v2.survey_db import model as survey_model
from compendium_v2.publishers import helpers
@@ -24,11 +25,8 @@ setup_logging()
logger = logging.getLogger('survey-publisher-v1')
-def db_budget_migration():
- with survey_db.session_scope() as survey_session, \
- db.session_scope() as session:
-
- nren_dict = helpers.get_uppercase_nren_dict(session)
+def db_budget_migration(nren_dict):
+ with survey_db.session_scope() as survey_session:
# move data from Survey DB budget table
data = survey_session.query(survey_model.Nrens)
@@ -49,7 +47,7 @@ def db_budget_migration():
budget=float(budget.budget),
year=year
)
- session.merge(budget_entry)
+ db.session.merge(budget_entry)
# Import the data from excel sheet to database
exceldata = parse_excel_data.fetch_budget_excel_data()
@@ -63,165 +61,153 @@ def db_budget_migration():
logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year)
- session.merge(budget_entry)
- session.commit()
-
-
-def db_funding_migration():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- # Import the data to database
- data = parse_excel_data.fetch_funding_excel_data()
-
- for (abbrev, year, client_institution,
- european_funding,
- gov_public_bodies,
- commercial, other) in data:
-
- _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
- total = sum(_data)
- if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
- logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
-
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping.')
- continue
-
- budget_entry = model.FundingSource(
- nren=nren_dict[abbrev],
- year=year,
- client_institutions=client_institution,
- european_funding=european_funding,
- gov_public_bodies=gov_public_bodies,
- commercial=commercial,
- other=other)
- session.merge(budget_entry)
- session.commit()
-
-
-def db_charging_structure_migration():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- # Import the data to database
- data = parse_excel_data.fetch_charging_structure_excel_data()
-
- for (abbrev, year, charging_structure) in data:
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping.')
- continue
-
- charging_structure_entry = model.ChargingStructure(
- nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
- session.merge(charging_structure_entry)
- session.commit()
-
-
-def db_staffing_migration():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- staff_data = parse_excel_data.fetch_staffing_excel_data()
-
- nren_staff_map = {}
- for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping staff data.')
- continue
-
- nren = nren_dict[abbrev]
+ db.session.merge(budget_entry)
+ db.session.commit()
+
+
+def db_funding_migration(nren_dict):
+ # Import the data to database
+ data = parse_excel_data.fetch_funding_excel_data()
+
+ for (abbrev, year, client_institution,
+ european_funding,
+ gov_public_bodies,
+ commercial, other) in data:
+
+ _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
+ total = sum(_data)
+ if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
+ logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
+
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping.')
+ continue
+
+ budget_entry = model.FundingSource(
+ nren=nren_dict[abbrev],
+ year=year,
+ client_institutions=client_institution,
+ european_funding=european_funding,
+ gov_public_bodies=gov_public_bodies,
+ commercial=commercial,
+ other=other)
+ db.session.merge(budget_entry)
+ db.session.commit()
+
+
+def db_charging_structure_migration(nren_dict):
+ # Import the data to database
+ data = parse_excel_data.fetch_charging_structure_excel_data()
+
+ for (abbrev, year, charging_structure) in data:
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping.')
+ continue
+
+ charging_structure_entry = model.ChargingStructure(
+ nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
+ db.session.merge(charging_structure_entry)
+ db.session.commit()
+
+
+def db_staffing_migration(nren_dict):
+ staff_data = parse_excel_data.fetch_staffing_excel_data()
+
+ nren_staff_map = {}
+ for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping staff data.')
+ continue
+
+ nren = nren_dict[abbrev]
+ nren_staff_map[(nren.id, year)] = model.NrenStaff(
+ nren=nren,
+ nren_id=nren.id,
+ year=year,
+ permanent_fte=permanent_fte,
+ subcontracted_fte=subcontracted_fte,
+ technical_fte=0,
+ non_technical_fte=0
+ )
+
+ function_data = parse_excel_data.fetch_staff_function_excel_data()
+ for (abbrev, year, technical_fte, non_technical_fte) in function_data:
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping staff function data.')
+ continue
+
+ nren = nren_dict[abbrev]
+ if (nren.id, year) in nren_staff_map:
+ nren_staff_map[(nren.id, year)].technical_fte = technical_fte
+ nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
+ else:
nren_staff_map[(nren.id, year)] = model.NrenStaff(
nren=nren,
nren_id=nren.id,
year=year,
- permanent_fte=permanent_fte,
- subcontracted_fte=subcontracted_fte,
- technical_fte=0,
- non_technical_fte=0
+ permanent_fte=0,
+ subcontracted_fte=0,
+ technical_fte=technical_fte,
+ non_technical_fte=non_technical_fte
)
- function_data = parse_excel_data.fetch_staff_function_excel_data()
- for (abbrev, year, technical_fte, non_technical_fte) in function_data:
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping staff function data.')
- continue
-
- nren = nren_dict[abbrev]
- if (nren.id, year) in nren_staff_map:
- nren_staff_map[(nren.id, year)].technical_fte = technical_fte
- nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
- else:
- nren_staff_map[(nren.id, year)] = model.NrenStaff(
- nren=nren,
- nren_id=nren.id,
- year=year,
- permanent_fte=0,
- subcontracted_fte=0,
- technical_fte=technical_fte,
- non_technical_fte=non_technical_fte
- )
-
- for nren_staff_model in nren_staff_map.values():
- employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
- technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
- if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
- logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
- f' FTE do not equal across employed/technical categories ({employed} != {technical})')
-
- session.merge(nren_staff_model)
+ for nren_staff_model in nren_staff_map.values():
+ employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
+ technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
+ if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
+ logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
+ f' FTE do not equal across employed/technical categories ({employed} != {technical})')
- session.commit()
+ db.session.merge(nren_staff_model)
+ db.session.commit()
-def db_ecprojects_migration():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
-
- ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
- for (abbrev, year, project) in ecproject_data:
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping.')
- continue
- nren = nren_dict[abbrev]
- ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
- session.merge(ecproject_entry)
- session.commit()
+def db_ecprojects_migration(nren_dict):
+ ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
+ for (abbrev, year, project) in ecproject_data:
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping.')
+ continue
+ nren = nren_dict[abbrev]
+ ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
+ db.session.merge(ecproject_entry)
+ db.session.commit()
-def db_organizations_migration():
- with db.session_scope() as session:
- nren_dict = helpers.get_uppercase_nren_dict(session)
- organization_data = parse_excel_data.fetch_organization_excel_data()
- for (abbrev, year, org) in organization_data:
- if abbrev not in nren_dict:
- logger.warning(f'{abbrev} unknown. Skipping.')
- continue
+def db_organizations_migration(nren_dict):
+ organization_data = parse_excel_data.fetch_organization_excel_data()
+ for (abbrev, year, org) in organization_data:
+ if abbrev not in nren_dict:
+ logger.warning(f'{abbrev} unknown. Skipping.')
+ continue
- nren = nren_dict[abbrev]
- org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
- session.merge(org_entry)
- session.commit()
+ nren = nren_dict[abbrev]
+ org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
+ db.session.merge(org_entry)
+ db.session.commit()
-def _cli(config):
+def _cli(config, app):
helpers.init_db(config)
- db_budget_migration()
- db_funding_migration()
- db_charging_structure_migration()
- db_staffing_migration()
- db_ecprojects_migration()
- db_organizations_migration()
+ with app.app_context():
+ nren_dict = helpers.get_uppercase_nren_dict()
+ db_budget_migration(nren_dict)
+ db_funding_migration(nren_dict)
+ db_charging_structure_migration(nren_dict)
+ db_staffing_migration(nren_dict)
+ db_ecprojects_migration(nren_dict)
+ db_organizations_migration(nren_dict)
@click.command()
@click.option('--config', type=click.STRING, default='config.json')
def cli(config):
app_config = load(open(config, 'r'))
+ app = compendium_v2._create_app_with_db(app_config)
print("survey-publisher-v1 starting")
- _cli(app_config)
+ _cli(app_config, app)
if __name__ == "__main__":
diff --git a/compendium_v2/routes/budget.py b/compendium_v2/routes/budget.py
index 763b3af3..1f4ff850 100644
--- a/compendium_v2/routes/budget.py
+++ b/compendium_v2/routes/budget.py
@@ -1,28 +1,17 @@
import logging
from typing import Any
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import BudgetEntry
from compendium_v2.routes import common
-routes = Blueprint('budget', __name__)
-
-
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
-
+routes = Blueprint('budget', __name__)
logger = logging.getLogger(__name__)
-col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061',
- '#bd7ebe', '#ffb55a', '#ffee65',
- '#beb9db', '#fdcce5', '#8bd3c7']
-
BUDGET_RESPONSE_SCHEMA = {
'$schema': 'http://json-schema.org/draft-07/schema#',
@@ -58,15 +47,15 @@ def budget_view() -> Any:
:return:
"""
- def _extract_data(entry: model.BudgetEntry):
+ def _extract_data(entry: BudgetEntry):
return {
'NREN': entry.nren.name,
'BUDGET': float(entry.budget),
'BUDGET_YEAR': entry.year,
}
- with db.session_scope() as session:
- entries = sorted([_extract_data(entry)
- for entry in session.query(model.BudgetEntry)],
- key=lambda d: (d['BUDGET_YEAR'], d['NREN']))
+ entries = sorted(
+ [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))],
+ key=lambda d: (d['BUDGET_YEAR'], d['NREN'])
+ )
return jsonify(entries)
diff --git a/compendium_v2/routes/charging.py b/compendium_v2/routes/charging.py
index 0b2c2384..57e8114f 100644
--- a/compendium_v2/routes/charging.py
+++ b/compendium_v2/routes/charging.py
@@ -1,21 +1,15 @@
import logging
-
-from flask import Blueprint, jsonify, current_app
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
from typing import Any
-routes = Blueprint('charging', __name__)
-
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ChargingStructure
+from compendium_v2.routes import common
+routes = Blueprint('charging', __name__)
logger = logging.getLogger(__name__)
CHARGING_STRUCTURE_RESPONSE_SCHEMA = {
@@ -53,16 +47,15 @@ def charging_structure_view() -> Any:
:return:
"""
- def _extract_data(entry: model.ChargingStructure):
+ def _extract_data(entry: ChargingStructure):
return {
'NREN': entry.nren.name,
'YEAR': int(entry.year),
'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None,
}
- with db.session_scope() as session:
- entries = sorted([_extract_data(entry)
- for entry in session.query(model.ChargingStructure)
- .all()],
- key=lambda d: (d['NREN'], d['YEAR']))
+ entries = sorted(
+ [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))],
+ key=lambda d: (d['NREN'], d['YEAR'])
+ )
return jsonify(entries)
diff --git a/compendium_v2/routes/ec_projects.py b/compendium_v2/routes/ec_projects.py
index 7114718d..b58d8931 100644
--- a/compendium_v2/routes/ec_projects.py
+++ b/compendium_v2/routes/ec_projects.py
@@ -1,22 +1,15 @@
import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
from typing import Any
-routes = Blueprint('ec-projects', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ECProject
+from compendium_v2.routes import common
+routes = Blueprint('ec-projects', __name__)
logger = logging.getLogger(__name__)
EC_PROJECTS_RESPONSE_SCHEMA = {
@@ -55,13 +48,12 @@ def ec_projects_view() -> Any:
:return:
"""
- def _extract_project(entry: model.ECProject):
+ def _extract_project(entry: ECProject):
return {
'nren': entry.nren.name,
'year': entry.year,
'project': entry.project
}
- with db.session_scope() as session:
- result = [_extract_project(project) for project in session.query(model.ECProject)]
+ result = [_extract_project(project) for project in db.session.scalars(select(ECProject))]
return jsonify(result)
diff --git a/compendium_v2/routes/funding.py b/compendium_v2/routes/funding.py
index ed0e26c8..c02bf136 100644
--- a/compendium_v2/routes/funding.py
+++ b/compendium_v2/routes/funding.py
@@ -1,22 +1,15 @@
import logging
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-from compendium_v2 import db
from compendium_v2.routes import common
-from compendium_v2.db import model
+from compendium_v2.db import db
+from compendium_v2.db.model import FundingSource
from typing import Any
-routes = Blueprint('funding', __name__)
-
-
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
-
+routes = Blueprint('funding', __name__)
logger = logging.getLogger(__name__)
FUNDING_RESPONSE_SCHEMA = {
@@ -59,7 +52,7 @@ def funding_source_view() -> Any:
:return:
"""
- def _extract_data(entry: model.FundingSource):
+ def _extract_data(entry: FundingSource):
return {
'NREN': entry.nren.name,
'YEAR': entry.year,
@@ -70,8 +63,8 @@ def funding_source_view() -> Any:
'OTHER': float(entry.other)
}
- with db.session_scope() as session:
- entries = sorted([_extract_data(entry)
- for entry in session.query(model.FundingSource)],
- key=lambda d: (d['NREN'], d['YEAR']))
+ entries = sorted(
+ [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))],
+ key=lambda d: (d['NREN'], d['YEAR'])
+ )
return jsonify(entries)
diff --git a/compendium_v2/routes/organization.py b/compendium_v2/routes/organization.py
index 61a43354..8e8ebc8d 100644
--- a/compendium_v2/routes/organization.py
+++ b/compendium_v2/routes/organization.py
@@ -1,22 +1,15 @@
import logging
-
-from flask import Blueprint, jsonify, current_app
-
-from compendium_v2 import db
-from compendium_v2.routes import common
-from compendium_v2.db import model
from typing import Any
-routes = Blueprint('organization', __name__)
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
+from compendium_v2.db import db
+from compendium_v2.db.model import ParentOrganization, SubOrganization
+from compendium_v2.routes import common
+routes = Blueprint('organization', __name__)
logger = logging.getLogger(__name__)
ORGANIZATION_RESPONSE_SCHEMA = {
@@ -71,15 +64,14 @@ def parent_organization_view() -> Any:
:return:
"""
- def _extract_parent(entry: model.ParentOrganization):
+ def _extract_parent(entry: ParentOrganization):
return {
'nren': entry.nren.name,
'year': entry.year,
'name': entry.organization
}
- with db.session_scope() as session:
- result = [_extract_parent(org) for org in session.query(model.ParentOrganization)]
+ result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))]
return jsonify(result)
@@ -98,7 +90,7 @@ def sub_organization_view() -> Any:
:return:
"""
- def _extract_sub(entry: model.SubOrganization):
+ def _extract_sub(entry: SubOrganization):
return {
'nren': entry.nren.name,
'year': entry.year,
@@ -106,6 +98,5 @@ def sub_organization_view() -> Any:
'role': entry.role
}
- with db.session_scope() as session:
- result = [_extract_sub(org) for org in session.query(model.SubOrganization)]
+ result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))]
return jsonify(result)
diff --git a/compendium_v2/routes/staff.py b/compendium_v2/routes/staff.py
index 73e79c48..b8478266 100644
--- a/compendium_v2/routes/staff.py
+++ b/compendium_v2/routes/staff.py
@@ -1,22 +1,15 @@
import logging
-from flask import Blueprint, jsonify, current_app
+from flask import Blueprint, jsonify
+from sqlalchemy import select
-from compendium_v2 import db
+from compendium_v2.db import db
+from compendium_v2.db.model import NREN, NrenStaff
from compendium_v2.routes import common
-from compendium_v2.db import model
from typing import Any
-routes = Blueprint('staff', __name__)
-
-
-@routes.before_request
-def before_request():
- config = current_app.config['CONFIG_PARAMS']
- dsn_prn = config['SQLALCHEMY_DATABASE_URI']
- db.init_db_model(dsn_prn)
-
+routes = Blueprint('staff', __name__)
logger = logging.getLogger(__name__)
STAFF_RESPONSE_SCHEMA = {
@@ -57,7 +50,7 @@ def staff_view() -> Any:
:return:
"""
- def _extract_data(entry: model.NrenStaff):
+ def _extract_data(entry: NrenStaff):
return {
'nren': entry.nren.name,
'year': entry.year,
@@ -67,7 +60,6 @@ def staff_view() -> Any:
'non_technical_fte': float(entry.non_technical_fte)
}
- with db.session_scope() as session:
- entries = [_extract_data(entry) for entry in session.query(
- model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())]
+ entries = [_extract_data(entry) for entry in db.session.scalars(
+ select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))]
return jsonify(entries)
diff --git a/compendium_v2/survey_db/__init__.py b/compendium_v2/survey_db/__init__.py
index ce48aaa0..1550ddcb 100644
--- a/compendium_v2/survey_db/__init__.py
+++ b/compendium_v2/survey_db/__init__.py
@@ -31,11 +31,6 @@ def session_scope(
session.close()
-def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
- return (f'postgresql://{db_username}:{db_password}'
- f'@{db_hostname}:{port}/{db_name}')
-
-
def init_db_model(dsn):
global _SESSION_MAKER
diff --git a/requirements.txt b/requirements.txt
index 12ec2076..98804e02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ click~=8.1
jsonschema~=4.17
flask~=2.2
flask-cors~=3.0
+flask-sqlalchemy~=3.0
openpyxl~=3.1
psycopg2-binary~=2.9
SQLAlchemy~=2.0
diff --git a/setup.py b/setup.py
index 213288e5..d051ddd7 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ setup(
'jsonschema~=4.17',
'flask~=2.2',
'flask-cors~=3.0',
+ 'flask-sqlalchemy~=3.0',
'openpyxl~=3.1',
'psycopg2-binary~=2.9',
'SQLAlchemy~=2.0',
diff --git a/test/conftest.py b/test/conftest.py
index 79186212..ba6b9a43 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,20 +1,15 @@
-import json
+import csv
import os
-import tempfile
-import random
-
import pytest
-import compendium_v2
-from compendium_v2 import db
-from compendium_v2.db import model
-from compendium_v2 import survey_db
-from compendium_v2.survey_db import model as survey_model
+import random
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
-import csv
+import compendium_v2
+from compendium_v2.db import db, model
+from compendium_v2.survey_db import model as survey_model
def _test_data_csv(filename):
@@ -43,61 +38,29 @@ def mocked_survey_db(mocker):
@pytest.fixture
-def mocked_db(mocker):
- # cf. https://stackoverflow.com/a/33057675
- engine = create_engine(
- 'sqlite://',
- connect_args={'check_same_thread': False},
- poolclass=StaticPool,
- echo=False)
- model.Base.metadata.create_all(engine)
- mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine))
- mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None)
- mocker.patch('compendium_v2.migrate_database', lambda config: None)
-
-
-@pytest.fixture
-def test_budget_data():
- with db.session_scope() as session:
+def test_budget_data(app):
+ with app.app_context():
data = [row for row in _test_data_csv("BudgetTestData.csv")]
nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for row in data:
nren = nren_dict[row["nren"]]
budget = row["budget"]
year = row["year"]
- session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
-
- with survey_db.session_scope() as session:
- data = _test_data_csv("BudgetTestData.csv")
- nrens = set()
- budgets_data = []
- for row in data:
- nren = row["nren"]
- budget = row["budget"]
- year = row["year"]
- country_code = row["nren"]
-
- nrens.add(nren)
-
- budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code))
-
- for nren in nrens:
- session.add(survey_model.Nrens(abbreviation=nren, country_code=nren))
-
- session.add_all(budgets_data)
+ db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
+ db.session.commit()
@pytest.fixture
-def test_funding_source_data():
- with db.session_scope() as session:
+def test_funding_source_data(app):
+ with app.app_context():
data = [row for row in _test_data_csv("FundingSourceTestData.csv")]
nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for row in data:
nren = nren_dict[row["nren"]]
@@ -108,7 +71,7 @@ def test_funding_source_data():
commercial = row["commercial"]
other = row["other"]
- session.add(
+ db.session.add(
model.FundingSource(
nren=nren, year=year,
client_institutions=client,
@@ -117,10 +80,11 @@ def test_funding_source_data():
commercial=commercial,
other=other)
)
+ db.session.commit()
@pytest.fixture
-def test_staff_data():
+def test_staff_data(app):
# generator of random test data for 5 years and 100 nrens
def _generate_rows():
@@ -135,12 +99,12 @@ def test_staff_data():
"non_technical_fte": random.randint(0, 100)
}
- with db.session_scope() as session:
+ with app.app_context():
data = list(_generate_rows())
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for row in data:
nren = nren_dict[row["nren"]]
@@ -150,7 +114,7 @@ def test_staff_data():
technical_fte = row["technical_fte"]
non_technical_fte = row["non_technical_fte"]
- session.add(
+ db.session.add(
model.NrenStaff(
nren=nren,
year=year,
@@ -160,30 +124,29 @@ def test_staff_data():
non_technical_fte=non_technical_fte
)
)
+ db.session.commit()
@pytest.fixture
-def data_config_filename(dummy_config):
- with tempfile.NamedTemporaryFile() as f:
- f.write(json.dumps(dummy_config).encode('utf-8'))
- f.flush()
- yield f.name
+def app(dummy_config):
+ app = compendium_v2._create_app_with_db(dummy_config)
+ with app.app_context():
+ db.create_all()
+ yield app
@pytest.fixture
-def client(data_config_filename, mocked_db, mocked_survey_db):
- os.environ['SETTINGS_FILENAME'] = data_config_filename
- with compendium_v2.create_app().test_client() as c:
- yield c
+def client(app):
+ return app.test_client()
@pytest.fixture
-def test_charging_structure_data():
- with db.session_scope() as session:
+def test_charging_structure_data(app):
+ with app.app_context():
data = [row for row in _test_data_csv("ChargingStructureTestData.csv")]
nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for row in data:
nren = nren_dict[row["nren"]]
@@ -192,15 +155,16 @@ def test_charging_structure_data():
if fee_type == "null":
fee_type = None
- session.add(
+ db.session.add(
model.ChargingStructure(
nren=nren, year=year,
fee_type=fee_type)
)
+ db.session.commit()
@pytest.fixture
-def test_organization_data():
+def test_organization_data(app):
def _generate_sub_org_data():
for nren in ["nren" + str(i) for i in range(1, 50)]:
for year in range(2016, 2021):
@@ -220,21 +184,21 @@ def test_organization_data():
'name': 'org' + str(year)
}
- with db.session_scope() as session:
+ with app.app_context():
org_data = list(_generate_org_data())
sub_org_data = list(_generate_sub_org_data())
nren_dict = {nren_name: model.NREN(name=nren_name)
for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for org in org_data:
nren = nren_dict[org["nren"]]
year = org["year"]
name = org["name"]
- session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
+ db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
for sub_org in sub_org_data:
nren = nren_dict[sub_org["nren"]]
@@ -242,13 +206,13 @@ def test_organization_data():
name = sub_org["name"]
role = sub_org["role"]
- session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
+ db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
- session.commit()
+ db.session.commit()
@pytest.fixture
-def test_ec_project_data():
+def test_ec_project_data(app):
def _generate_ec_project_data():
for nren in ["nren" + str(i) for i in range(1, 50)]:
for year in range(2016, 2021):
@@ -264,19 +228,19 @@ def test_ec_project_data():
'project': 'ec_project2',
}
- with db.session_scope() as session:
+ with app.app_context():
ec_project_data = list(_generate_ec_project_data())
nren_dict = {nren_name: model.NREN(name=nren_name)
for nren_name in set(d['nren'] for d in ec_project_data)}
- session.add_all(nren_dict.values())
+ db.session.add_all(nren_dict.values())
for ec_project in ec_project_data:
nren = nren_dict[ec_project["nren"]]
year = ec_project["year"]
project = ec_project["project"]
- session.add(model.ECProject(nren=nren, year=year, project=project))
+ db.session.add(model.ECProject(nren=nren, year=year, project=project))
- session.commit()
+ db.session.commit()
diff --git a/test/test_survey_publisher_2022.py b/test/test_survey_publisher_2022.py
index ec8802ed..433def2c 100644
--- a/test/test_survey_publisher_2022.py
+++ b/test/test_survey_publisher_2022.py
@@ -1,5 +1,4 @@
-from compendium_v2 import db
-from compendium_v2.db import model
+from compendium_v2.db import db, model
from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
@@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority"
]
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(app, mocker, dummy_config):
global org_data
def get_rows_as_tuples(*args, **kwargs):
@@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config):
mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data)
mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
- with db.session_scope() as session:
- nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
- session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+ nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
+ with app.app_context():
+ db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+ db.session.commit()
- _cli(dummy_config)
+ _cli(dummy_config, app)
- with db.session_scope() as session:
- budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
+ with app.app_context():
+ budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
assert len(budgets) == 3
assert budgets[0].nren.name.lower() == 'nren1'
assert budgets[0].budget == 100
- funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
+ funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
assert len(funding_sources) == 3
assert funding_sources[0].nren.name.lower() == 'nren1'
assert funding_sources[0].client_institutions == 10
@@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config):
assert funding_sources[2].european_funding == 30
assert funding_sources[2].other == 30
- staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
+ staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
assert len(staff_data) == 3
assert staff_data[0].nren.name.lower() == 'nren1'
@@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config):
assert staff_data[2].permanent_fte == 30
assert staff_data[2].subcontracted_fte == 0
- _org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
+ _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
assert len(_org_data) == 2
assert _org_data[0].nren.name.lower() == 'nren1'
@@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config):
assert _org_data[1].nren.name.lower() == 'nren3'
assert _org_data[1].organization == 'Org3'
- charging_structures = session.query(model.ChargingStructure).order_by(
+ charging_structures = db.session.query(model.ChargingStructure).order_by(
model.ChargingStructure.nren_id.asc()).all()
assert len(charging_structures) == 3
assert charging_structures[0].nren.name.lower() == 'nren1'
@@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config):
assert charging_structures[2].nren.name.lower() == 'nren3'
assert charging_structures[2].fee_type == model.FeeType.other
- _ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
+ _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
assert len(_ec_data) == 3
assert _ec_data[0].nren.name.lower() == 'nren2'
diff --git a/test/test_survey_publisher_v1.py b/test/test_survey_publisher_v1.py
index c7ebc2d9..a6dbbaa1 100644
--- a/test/test_survey_publisher_v1.py
+++ b/test/test_survey_publisher_v1.py
@@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
-def test_publisher(client, mocker, dummy_config):
+def test_publisher(mocked_survey_db, app, mocker, dummy_config):
mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
- with db.session_scope() as session:
+ with app.app_context():
nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
- session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+ db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
+ db.session.commit()
- _cli(dummy_config)
+ _cli(dummy_config, app)
- with db.session_scope() as session:
- budget_count = session.query(model.BudgetEntry.year).count()
+ with app.app_context():
+ budget_count = db.session.query(model.BudgetEntry.year).count()
assert budget_count
- funding_source_count = session.query(model.FundingSource.year).count()
+ funding_source_count = db.session.query(model.FundingSource.year).count()
assert funding_source_count
- charging_structure_count = session.query(model.ChargingStructure.year).count()
+ charging_structure_count = db.session.query(model.ChargingStructure.year).count()
assert charging_structure_count
- staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
+ staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
# data should only be saved for the NRENs we have saved in the database
staff_data_nrens = set([staff.nren.name for staff in staff_data])
@@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config):
assert kifu_data[5].technical_fte == 133
assert kifu_data[5].non_technical_fte == 45
- ecproject_data = session.query(model.ECProject).all()
+ ecproject_data = db.session.query(model.ECProject).all()
# test a couple of random entries
surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
assert len(surf2017) == 1
@@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config):
assert len(kifu2019) == 4
assert kifu2019[3].project == 'SuperHeroes for Science'
- parent_data = session.query(model.ParentOrganization).all()
+ parent_data = db.session.query(model.ParentOrganization).all()
# test a random entry
asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
assert len(asnet2021) == 1
--
GitLab