Skip to content
Snippets Groups Projects
Commit 6e62575f authored by Remco Tukker's avatar Remco Tukker
Browse files

use flask-sqlalchemy for the main db

parent 7b3587a5
Branches
Tags
1 merge request!18use flask-sqlalchemy for the main db
Showing
with 545 additions and 689 deletions
...@@ -8,7 +8,7 @@ from flask import Flask ...@@ -8,7 +8,7 @@ from flask import Flask
from flask_cors import CORS # for debugging from flask_cors import CORS # for debugging
from compendium_v2 import config, environment from compendium_v2 import config, environment
from compendium_v2.db import db
from compendium_v2.migrations import migration_utils from compendium_v2.migrations import migration_utils
...@@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask: ...@@ -33,6 +33,14 @@ def _create_app(app_config) -> Flask:
return app return app
def _create_app_with_db(app_config) -> Flask:
# used by the tests and the publishers
app = _create_app(app_config)
app.config['SQLALCHEMY_DATABASE_URI'] = app.config['CONFIG_PARAMS']['SQLALCHEMY_DATABASE_URI']
db.init_app(app)
return app
def create_app() -> Flask: def create_app() -> Flask:
""" """
overrides default settings with those found overrides default settings with those found
...@@ -46,7 +54,7 @@ def create_app() -> Flask: ...@@ -46,7 +54,7 @@ def create_app() -> Flask:
with open(os.environ['SETTINGS_FILENAME']) as f: with open(os.environ['SETTINGS_FILENAME']) as f:
app_config = config.load(f) app_config = config.load(f)
app = _create_app(app_config) app = _create_app_with_db(app_config)
logging.info('Flask app initialized') logging.info('Flask app initialized')
......
import contextlib
import logging import logging
from typing import Optional, Union, Callable, Iterator
from sqlalchemy import create_engine from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import MetaData
from sqlalchemy.orm import sessionmaker, Session
logger = logging.getLogger(__name__)
_SESSION_MAKER: Union[None, sessionmaker] = None
@contextlib.contextmanager
def session_scope(
callback_before_close: Optional[Callable] = None) -> Iterator[Session]:
# best practice is to keep session scope separate from data processing
# cf. https://docs.sqlalchemy.org/en/13/orm/session_basics.html
assert _SESSION_MAKER
session = _SESSION_MAKER()
try:
yield session
session.commit()
if callback_before_close:
callback_before_close()
except SQLAlchemyError:
logger.error('caught sql layer exception, rolling back')
session.rollback()
raise # re-raise, will be handled by main consumer
finally:
session.close()
def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
return (f'postgresql://{db_username}:{db_password}'
f'@{db_hostname}:{port}/{db_name}')
logger = logging.getLogger(__name__)
def init_db_model(dsn): metadata_obj = MetaData(naming_convention={
global _SESSION_MAKER "ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
})
# cf. https://docs.sqlalchemy.org/en db = SQLAlchemy(metadata=metadata_obj)
# /latest/orm/extensions/automap.html
engine = create_engine(dsn, pool_size=10)
_SESSION_MAKER = sessionmaker(bind=engine)
...@@ -4,42 +4,34 @@ from enum import Enum ...@@ -4,42 +4,34 @@ from enum import Enum
from typing import Optional from typing import Optional
from typing_extensions import Annotated from typing_extensions import Annotated
from sqlalchemy import MetaData, String from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import ForeignKey
from compendium_v2.db import db
logger = logging.getLogger(__name__)
convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
metadata_obj = MetaData(naming_convention=convention) logger = logging.getLogger(__name__)
str128 = Annotated[str, 128] str128 = Annotated[str, 128]
str128_pk = Annotated[str, mapped_column(String(128), primary_key=True)]
str256_pk = Annotated[str, mapped_column(String(256), primary_key=True)]
int_pk = Annotated[int, mapped_column(primary_key=True)] int_pk = Annotated[int, mapped_column(primary_key=True)]
int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)] int_pk_fkNREN = Annotated[int, mapped_column(ForeignKey("nren.id"), primary_key=True)]
class Base(DeclarativeBase): # Unfortunately flask-sqlalchemy doesnt fully support DeclarativeBase yet.
metadata = metadata_obj # See https://github.com/pallets-eco/flask-sqlalchemy/issues/1140
type_annotation_map = { # mypy: disable-error-code="name-defined"
str128: String(128),
}
class NREN(Base): class NREN(db.Model):
__tablename__ = 'nren' __tablename__ = 'nren'
id: Mapped[int_pk] id: Mapped[int_pk]
name: Mapped[str128] name: Mapped[str128]
class BudgetEntry(Base): class BudgetEntry(db.Model):
__tablename__ = 'budgets' __tablename__ = 'budgets'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
...@@ -47,7 +39,7 @@ class BudgetEntry(Base): ...@@ -47,7 +39,7 @@ class BudgetEntry(Base):
budget: Mapped[Decimal] budget: Mapped[Decimal]
class FundingSource(Base): class FundingSource(db.Model):
__tablename__ = 'funding_source' __tablename__ = 'funding_source'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
...@@ -67,7 +59,7 @@ class FeeType(Enum): ...@@ -67,7 +59,7 @@ class FeeType(Enum):
other = "other" other = "other"
class ChargingStructure(Base): class ChargingStructure(db.Model):
__tablename__ = 'charging_structure' __tablename__ = 'charging_structure'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
...@@ -75,7 +67,7 @@ class ChargingStructure(Base): ...@@ -75,7 +67,7 @@ class ChargingStructure(Base):
fee_type: Mapped[Optional[FeeType]] fee_type: Mapped[Optional[FeeType]]
class NrenStaff(Base): class NrenStaff(db.Model):
__tablename__ = 'nren_staff' __tablename__ = 'nren_staff'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
...@@ -86,7 +78,7 @@ class NrenStaff(Base): ...@@ -86,7 +78,7 @@ class NrenStaff(Base):
non_technical_fte: Mapped[Decimal] non_technical_fte: Mapped[Decimal]
class ParentOrganization(Base): class ParentOrganization(db.Model):
__tablename__ = 'parent_organization' __tablename__ = 'parent_organization'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
...@@ -94,18 +86,18 @@ class ParentOrganization(Base): ...@@ -94,18 +86,18 @@ class ParentOrganization(Base):
organization: Mapped[str128] organization: Mapped[str128]
class SubOrganization(Base): class SubOrganization(db.Model):
__tablename__ = 'sub_organization' __tablename__ = 'sub_organization'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
year: Mapped[int_pk] year: Mapped[int_pk]
organization: Mapped[str128] = mapped_column(primary_key=True) organization: Mapped[str128_pk]
role: Mapped[str128] role: Mapped[str128]
class ECProject(Base): class ECProject(db.Model):
__tablename__ = 'ec_project' __tablename__ = 'ec_project'
nren_id: Mapped[int_pk_fkNREN] nren_id: Mapped[int_pk_fkNREN]
nren: Mapped[NREN] = relationship(NREN, lazy='joined') nren: Mapped[NREN] = relationship(lazy='joined')
year: Mapped[int_pk] year: Mapped[int_pk]
project: Mapped[str] = mapped_column(String(256), primary_key=True) project: Mapped[str256_pk]
...@@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config ...@@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool from sqlalchemy import pool
from alembic import context from alembic import context
from compendium_v2.db.model import metadata_obj from compendium_v2.db import metadata_obj
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
......
import logging import logging
import os import os
from compendium_v2 import db
from alembic.config import Config from alembic.config import Config
from alembic import command from alembic import command
...@@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY): ...@@ -27,9 +26,14 @@ def upgrade(dsn, migrations_directory=DEFAULT_MIGRATIONS_DIRECTORY):
command.upgrade(alembic_config, 'head') command.upgrade(alembic_config, 'head')
def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
return (f'postgresql://{db_username}:{db_password}'
f'@{db_hostname}:{port}/{db_name}')
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
upgrade(db.postgresql_dsn( upgrade(postgresql_dsn(
db_username='compendium', db_username='compendium',
db_password='compendium321', db_password='compendium321',
db_hostname='localhost', db_hostname='localhost',
......
from compendium_v2 import db, survey_db from sqlalchemy import select
from compendium_v2.db import model
from compendium_v2 import survey_db
from compendium_v2.db import db, model
def init_db(config): def init_db(config):
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
dsn_survey = config['SURVEY_DATABASE_URI'] dsn_survey = config['SURVEY_DATABASE_URI']
survey_db.init_db_model(dsn_survey) survey_db.init_db_model(dsn_survey)
def get_uppercase_nren_dict(session): def get_uppercase_nren_dict():
""" """
:param session: db session that is used to query the known NRENs
:return: a dictionary of all known NRENs db entities keyed on the :return: a dictionary of all known NRENs db entities keyed on the
uppercased name uppercased name
""" """
current_nrens = session.query(model.NREN).all() current_nrens = db.session.scalars(select(model.NREN))
nren_dict = {nren.name.upper(): nren for nren in current_nrens} nren_dict = {nren.name.upper(): nren for nren in current_nrens}
# add aliases that are used in the source data: # add aliases that are used in the source data:
nren_dict['ASNET'] = nren_dict['ASNET-AM'] nren_dict['ASNET'] = nren_dict['ASNET-AM']
......
This diff is collapsed.
...@@ -11,11 +11,12 @@ import logging ...@@ -11,11 +11,12 @@ import logging
import math import math
import click import click
import compendium_v2
from compendium_v2.environment import setup_logging from compendium_v2.environment import setup_logging
from compendium_v2 import db, survey_db from compendium_v2 import survey_db
from compendium_v2.background_task import parse_excel_data from compendium_v2.background_task import parse_excel_data
from compendium_v2.config import load from compendium_v2.config import load
from compendium_v2.db import model from compendium_v2.db import db, model
from compendium_v2.survey_db import model as survey_model from compendium_v2.survey_db import model as survey_model
from compendium_v2.publishers import helpers from compendium_v2.publishers import helpers
...@@ -24,11 +25,8 @@ setup_logging() ...@@ -24,11 +25,8 @@ setup_logging()
logger = logging.getLogger('survey-publisher-v1') logger = logging.getLogger('survey-publisher-v1')
def db_budget_migration(): def db_budget_migration(nren_dict):
with survey_db.session_scope() as survey_session, \ with survey_db.session_scope() as survey_session:
db.session_scope() as session:
nren_dict = helpers.get_uppercase_nren_dict(session)
# move data from Survey DB budget table # move data from Survey DB budget table
data = survey_session.query(survey_model.Nrens) data = survey_session.query(survey_model.Nrens)
...@@ -49,7 +47,7 @@ def db_budget_migration(): ...@@ -49,7 +47,7 @@ def db_budget_migration():
budget=float(budget.budget), budget=float(budget.budget),
year=year year=year
) )
session.merge(budget_entry) db.session.merge(budget_entry)
# Import the data from excel sheet to database # Import the data from excel sheet to database
exceldata = parse_excel_data.fetch_budget_excel_data() exceldata = parse_excel_data.fetch_budget_excel_data()
...@@ -63,165 +61,153 @@ def db_budget_migration(): ...@@ -63,165 +61,153 @@ def db_budget_migration():
logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})') logger.warning(f'{nren} has budget set to >200M EUR for {year}. ({budget})')
budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year) budget_entry = model.BudgetEntry(nren=nren_dict[abbrev], budget=budget, year=year)
session.merge(budget_entry) db.session.merge(budget_entry)
session.commit() db.session.commit()
def db_funding_migration(): def db_funding_migration(nren_dict):
with db.session_scope() as session: # Import the data to database
nren_dict = helpers.get_uppercase_nren_dict(session) data = parse_excel_data.fetch_funding_excel_data()
# Import the data to database for (abbrev, year, client_institution,
data = parse_excel_data.fetch_funding_excel_data() european_funding,
gov_public_bodies,
for (abbrev, year, client_institution, commercial, other) in data:
european_funding,
gov_public_bodies, _data = [client_institution, european_funding, gov_public_bodies, commercial, other]
commercial, other) in data: total = sum(_data)
if not math.isclose(total, 100, abs_tol=0.01) and total != 0:
_data = [client_institution, european_funding, gov_public_bodies, commercial, other] logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})')
total = sum(_data)
if not math.isclose(total, 100, abs_tol=0.01) and total != 0: if abbrev not in nren_dict:
logger.warning(f'{abbrev} funding sources for {year} do not sum to 100% ({total})') logger.warning(f'{abbrev} unknown. Skipping.')
continue
if abbrev not in nren_dict:
logger.warning(f'{abbrev} unknown. Skipping.') budget_entry = model.FundingSource(
continue nren=nren_dict[abbrev],
year=year,
budget_entry = model.FundingSource( client_institutions=client_institution,
nren=nren_dict[abbrev], european_funding=european_funding,
year=year, gov_public_bodies=gov_public_bodies,
client_institutions=client_institution, commercial=commercial,
european_funding=european_funding, other=other)
gov_public_bodies=gov_public_bodies, db.session.merge(budget_entry)
commercial=commercial, db.session.commit()
other=other)
session.merge(budget_entry)
session.commit() def db_charging_structure_migration(nren_dict):
# Import the data to database
data = parse_excel_data.fetch_charging_structure_excel_data()
def db_charging_structure_migration():
with db.session_scope() as session: for (abbrev, year, charging_structure) in data:
nren_dict = helpers.get_uppercase_nren_dict(session) if abbrev not in nren_dict:
logger.warning(f'{abbrev} unknown. Skipping.')
# Import the data to database continue
data = parse_excel_data.fetch_charging_structure_excel_data()
charging_structure_entry = model.ChargingStructure(
for (abbrev, year, charging_structure) in data: nren=nren_dict[abbrev], year=year, fee_type=charging_structure)
if abbrev not in nren_dict: db.session.merge(charging_structure_entry)
logger.warning(f'{abbrev} unknown. Skipping.') db.session.commit()
continue
charging_structure_entry = model.ChargingStructure( def db_staffing_migration(nren_dict):
nren=nren_dict[abbrev], year=year, fee_type=charging_structure) staff_data = parse_excel_data.fetch_staffing_excel_data()
session.merge(charging_structure_entry)
session.commit() nren_staff_map = {}
for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data:
if abbrev not in nren_dict:
def db_staffing_migration(): logger.warning(f'{abbrev} unknown. Skipping staff data.')
with db.session_scope() as session: continue
nren_dict = helpers.get_uppercase_nren_dict(session)
nren = nren_dict[abbrev]
staff_data = parse_excel_data.fetch_staffing_excel_data() nren_staff_map[(nren.id, year)] = model.NrenStaff(
nren=nren,
nren_staff_map = {} nren_id=nren.id,
for (abbrev, year, permanent_fte, subcontracted_fte) in staff_data: year=year,
if abbrev not in nren_dict: permanent_fte=permanent_fte,
logger.warning(f'{abbrev} unknown. Skipping staff data.') subcontracted_fte=subcontracted_fte,
continue technical_fte=0,
non_technical_fte=0
nren = nren_dict[abbrev] )
function_data = parse_excel_data.fetch_staff_function_excel_data()
for (abbrev, year, technical_fte, non_technical_fte) in function_data:
if abbrev not in nren_dict:
logger.warning(f'{abbrev} unknown. Skipping staff function data.')
continue
nren = nren_dict[abbrev]
if (nren.id, year) in nren_staff_map:
nren_staff_map[(nren.id, year)].technical_fte = technical_fte
nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
else:
nren_staff_map[(nren.id, year)] = model.NrenStaff( nren_staff_map[(nren.id, year)] = model.NrenStaff(
nren=nren, nren=nren,
nren_id=nren.id, nren_id=nren.id,
year=year, year=year,
permanent_fte=permanent_fte, permanent_fte=0,
subcontracted_fte=subcontracted_fte, subcontracted_fte=0,
technical_fte=0, technical_fte=technical_fte,
non_technical_fte=0 non_technical_fte=non_technical_fte
) )
function_data = parse_excel_data.fetch_staff_function_excel_data() for nren_staff_model in nren_staff_map.values():
for (abbrev, year, technical_fte, non_technical_fte) in function_data: employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
if abbrev not in nren_dict: technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
logger.warning(f'{abbrev} unknown. Skipping staff function data.') if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
continue logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
f' FTE do not equal across employed/technical categories ({employed} != {technical})')
nren = nren_dict[abbrev]
if (nren.id, year) in nren_staff_map:
nren_staff_map[(nren.id, year)].technical_fte = technical_fte
nren_staff_map[(nren.id, year)].non_technical_fte = non_technical_fte
else:
nren_staff_map[(nren.id, year)] = model.NrenStaff(
nren=nren,
nren_id=nren.id,
year=year,
permanent_fte=0,
subcontracted_fte=0,
technical_fte=technical_fte,
non_technical_fte=non_technical_fte
)
for nren_staff_model in nren_staff_map.values():
employed = nren_staff_model.permanent_fte + nren_staff_model.subcontracted_fte
technical = nren_staff_model.technical_fte + nren_staff_model.non_technical_fte
if not math.isclose(employed, technical, abs_tol=0.01) and employed != 0 and technical != 0:
logger.warning(f'{nren_staff_model.nren.name} in {nren_staff_model.year}:'
f' FTE do not equal across employed/technical categories ({employed} != {technical})')
session.merge(nren_staff_model)
session.commit() db.session.merge(nren_staff_model)
db.session.commit()
def db_ecprojects_migration():
with db.session_scope() as session:
nren_dict = helpers.get_uppercase_nren_dict(session)
ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
for (abbrev, year, project) in ecproject_data:
if abbrev not in nren_dict:
logger.warning(f'{abbrev} unknown. Skipping.')
continue
nren = nren_dict[abbrev] def db_ecprojects_migration(nren_dict):
ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project) ecproject_data = parse_excel_data.fetch_ecproject_excel_data()
session.merge(ecproject_entry) for (abbrev, year, project) in ecproject_data:
session.commit() if abbrev not in nren_dict:
logger.warning(f'{abbrev} unknown. Skipping.')
continue
nren = nren_dict[abbrev]
ecproject_entry = model.ECProject(nren=nren, nren_id=nren.id, year=year, project=project)
db.session.merge(ecproject_entry)
db.session.commit()
def db_organizations_migration():
with db.session_scope() as session:
nren_dict = helpers.get_uppercase_nren_dict(session)
organization_data = parse_excel_data.fetch_organization_excel_data() def db_organizations_migration(nren_dict):
for (abbrev, year, org) in organization_data: organization_data = parse_excel_data.fetch_organization_excel_data()
if abbrev not in nren_dict: for (abbrev, year, org) in organization_data:
logger.warning(f'{abbrev} unknown. Skipping.') if abbrev not in nren_dict:
continue logger.warning(f'{abbrev} unknown. Skipping.')
continue
nren = nren_dict[abbrev] nren = nren_dict[abbrev]
org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org) org_entry = model.ParentOrganization(nren=nren, nren_id=nren.id, year=year, organization=org)
session.merge(org_entry) db.session.merge(org_entry)
session.commit() db.session.commit()
def _cli(config): def _cli(config, app):
helpers.init_db(config) helpers.init_db(config)
db_budget_migration() with app.app_context():
db_funding_migration() nren_dict = helpers.get_uppercase_nren_dict()
db_charging_structure_migration() db_budget_migration(nren_dict)
db_staffing_migration() db_funding_migration(nren_dict)
db_ecprojects_migration() db_charging_structure_migration(nren_dict)
db_organizations_migration() db_staffing_migration(nren_dict)
db_ecprojects_migration(nren_dict)
db_organizations_migration(nren_dict)
@click.command() @click.command()
@click.option('--config', type=click.STRING, default='config.json') @click.option('--config', type=click.STRING, default='config.json')
def cli(config): def cli(config):
app_config = load(open(config, 'r')) app_config = load(open(config, 'r'))
app = compendium_v2._create_app_with_db(app_config)
print("survey-publisher-v1 starting") print("survey-publisher-v1 starting")
_cli(app_config) _cli(app_config, app)
if __name__ == "__main__": if __name__ == "__main__":
......
import logging import logging
from typing import Any from typing import Any
from flask import Blueprint, jsonify, current_app from flask import Blueprint, jsonify
from sqlalchemy import select
from compendium_v2 import db from compendium_v2.db import db
from compendium_v2.db import model from compendium_v2.db.model import BudgetEntry
from compendium_v2.routes import common from compendium_v2.routes import common
routes = Blueprint('budget', __name__)
@routes.before_request
def before_request():
config = current_app.config['CONFIG_PARAMS']
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('budget', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
col_pal = ['#fd7f6f', '#7eb0d5', '#b2e061',
'#bd7ebe', '#ffb55a', '#ffee65',
'#beb9db', '#fdcce5', '#8bd3c7']
BUDGET_RESPONSE_SCHEMA = { BUDGET_RESPONSE_SCHEMA = {
'$schema': 'http://json-schema.org/draft-07/schema#', '$schema': 'http://json-schema.org/draft-07/schema#',
...@@ -58,15 +47,15 @@ def budget_view() -> Any: ...@@ -58,15 +47,15 @@ def budget_view() -> Any:
:return: :return:
""" """
def _extract_data(entry: model.BudgetEntry): def _extract_data(entry: BudgetEntry):
return { return {
'NREN': entry.nren.name, 'NREN': entry.nren.name,
'BUDGET': float(entry.budget), 'BUDGET': float(entry.budget),
'BUDGET_YEAR': entry.year, 'BUDGET_YEAR': entry.year,
} }
with db.session_scope() as session: entries = sorted(
entries = sorted([_extract_data(entry) [_extract_data(entry) for entry in db.session.scalars(select(BudgetEntry))],
for entry in session.query(model.BudgetEntry)], key=lambda d: (d['BUDGET_YEAR'], d['NREN'])
key=lambda d: (d['BUDGET_YEAR'], d['NREN'])) )
return jsonify(entries) return jsonify(entries)
import logging import logging
from flask import Blueprint, jsonify, current_app
from compendium_v2 import db
from compendium_v2.routes import common
from compendium_v2.db import model
from typing import Any from typing import Any
routes = Blueprint('charging', __name__) from flask import Blueprint, jsonify
from sqlalchemy import select
@routes.before_request from compendium_v2.db import db
def before_request(): from compendium_v2.db.model import ChargingStructure
config = current_app.config['CONFIG_PARAMS'] from compendium_v2.routes import common
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('charging', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CHARGING_STRUCTURE_RESPONSE_SCHEMA = { CHARGING_STRUCTURE_RESPONSE_SCHEMA = {
...@@ -53,16 +47,15 @@ def charging_structure_view() -> Any: ...@@ -53,16 +47,15 @@ def charging_structure_view() -> Any:
:return: :return:
""" """
def _extract_data(entry: model.ChargingStructure): def _extract_data(entry: ChargingStructure):
return { return {
'NREN': entry.nren.name, 'NREN': entry.nren.name,
'YEAR': int(entry.year), 'YEAR': int(entry.year),
'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None, 'FEE_TYPE': entry.fee_type.value if entry.fee_type is not None else None,
} }
with db.session_scope() as session: entries = sorted(
entries = sorted([_extract_data(entry) [_extract_data(entry) for entry in db.session.scalars(select(ChargingStructure))],
for entry in session.query(model.ChargingStructure) key=lambda d: (d['NREN'], d['YEAR'])
.all()], )
key=lambda d: (d['NREN'], d['YEAR']))
return jsonify(entries) return jsonify(entries)
import logging import logging
from flask import Blueprint, jsonify, current_app
from compendium_v2 import db
from compendium_v2.routes import common
from compendium_v2.db import model
from typing import Any from typing import Any
routes = Blueprint('ec-projects', __name__) from flask import Blueprint, jsonify
from sqlalchemy import select
from compendium_v2.db import db
@routes.before_request from compendium_v2.db.model import ECProject
def before_request(): from compendium_v2.routes import common
config = current_app.config['CONFIG_PARAMS']
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('ec-projects', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EC_PROJECTS_RESPONSE_SCHEMA = { EC_PROJECTS_RESPONSE_SCHEMA = {
...@@ -55,13 +48,12 @@ def ec_projects_view() -> Any: ...@@ -55,13 +48,12 @@ def ec_projects_view() -> Any:
:return: :return:
""" """
def _extract_project(entry: model.ECProject): def _extract_project(entry: ECProject):
return { return {
'nren': entry.nren.name, 'nren': entry.nren.name,
'year': entry.year, 'year': entry.year,
'project': entry.project 'project': entry.project
} }
with db.session_scope() as session: result = [_extract_project(project) for project in db.session.scalars(select(ECProject))]
result = [_extract_project(project) for project in session.query(model.ECProject)]
return jsonify(result) return jsonify(result)
import logging import logging
from flask import Blueprint, jsonify, current_app from flask import Blueprint, jsonify
from sqlalchemy import select
from compendium_v2 import db
from compendium_v2.routes import common from compendium_v2.routes import common
from compendium_v2.db import model from compendium_v2.db import db
from compendium_v2.db.model import FundingSource
from typing import Any from typing import Any
routes = Blueprint('funding', __name__)
@routes.before_request
def before_request():
config = current_app.config['CONFIG_PARAMS']
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('funding', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
FUNDING_RESPONSE_SCHEMA = { FUNDING_RESPONSE_SCHEMA = {
...@@ -59,7 +52,7 @@ def funding_source_view() -> Any: ...@@ -59,7 +52,7 @@ def funding_source_view() -> Any:
:return: :return:
""" """
def _extract_data(entry: model.FundingSource): def _extract_data(entry: FundingSource):
return { return {
'NREN': entry.nren.name, 'NREN': entry.nren.name,
'YEAR': entry.year, 'YEAR': entry.year,
...@@ -70,8 +63,8 @@ def funding_source_view() -> Any: ...@@ -70,8 +63,8 @@ def funding_source_view() -> Any:
'OTHER': float(entry.other) 'OTHER': float(entry.other)
} }
with db.session_scope() as session: entries = sorted(
entries = sorted([_extract_data(entry) [_extract_data(entry) for entry in db.session.scalars(select(FundingSource))],
for entry in session.query(model.FundingSource)], key=lambda d: (d['NREN'], d['YEAR'])
key=lambda d: (d['NREN'], d['YEAR'])) )
return jsonify(entries) return jsonify(entries)
import logging import logging
from flask import Blueprint, jsonify, current_app
from compendium_v2 import db
from compendium_v2.routes import common
from compendium_v2.db import model
from typing import Any from typing import Any
routes = Blueprint('organization', __name__) from flask import Blueprint, jsonify
from sqlalchemy import select
from compendium_v2.db import db
@routes.before_request from compendium_v2.db.model import ParentOrganization, SubOrganization
def before_request(): from compendium_v2.routes import common
config = current_app.config['CONFIG_PARAMS']
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('organization', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ORGANIZATION_RESPONSE_SCHEMA = { ORGANIZATION_RESPONSE_SCHEMA = {
...@@ -71,15 +64,14 @@ def parent_organization_view() -> Any: ...@@ -71,15 +64,14 @@ def parent_organization_view() -> Any:
:return: :return:
""" """
def _extract_parent(entry: model.ParentOrganization): def _extract_parent(entry: ParentOrganization):
return { return {
'nren': entry.nren.name, 'nren': entry.nren.name,
'year': entry.year, 'year': entry.year,
'name': entry.organization 'name': entry.organization
} }
with db.session_scope() as session: result = [_extract_parent(org) for org in db.session.scalars(select(ParentOrganization))]
result = [_extract_parent(org) for org in session.query(model.ParentOrganization)]
return jsonify(result) return jsonify(result)
...@@ -98,7 +90,7 @@ def sub_organization_view() -> Any: ...@@ -98,7 +90,7 @@ def sub_organization_view() -> Any:
:return: :return:
""" """
def _extract_sub(entry: model.SubOrganization): def _extract_sub(entry: SubOrganization):
return { return {
'nren': entry.nren.name, 'nren': entry.nren.name,
'year': entry.year, 'year': entry.year,
...@@ -106,6 +98,5 @@ def sub_organization_view() -> Any: ...@@ -106,6 +98,5 @@ def sub_organization_view() -> Any:
'role': entry.role 'role': entry.role
} }
with db.session_scope() as session: result = [_extract_sub(org) for org in db.session.scalars(select(SubOrganization))]
result = [_extract_sub(org) for org in session.query(model.SubOrganization)]
return jsonify(result) return jsonify(result)
import logging import logging
from flask import Blueprint, jsonify, current_app from flask import Blueprint, jsonify
from sqlalchemy import select
from compendium_v2 import db from compendium_v2.db import db
from compendium_v2.db.model import NREN, NrenStaff
from compendium_v2.routes import common from compendium_v2.routes import common
from compendium_v2.db import model
from typing import Any from typing import Any
routes = Blueprint('staff', __name__)
@routes.before_request
def before_request():
config = current_app.config['CONFIG_PARAMS']
dsn_prn = config['SQLALCHEMY_DATABASE_URI']
db.init_db_model(dsn_prn)
routes = Blueprint('staff', __name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
STAFF_RESPONSE_SCHEMA = { STAFF_RESPONSE_SCHEMA = {
...@@ -57,7 +50,7 @@ def staff_view() -> Any: ...@@ -57,7 +50,7 @@ def staff_view() -> Any:
:return: :return:
""" """
def _extract_data(entry: model.NrenStaff): def _extract_data(entry: NrenStaff):
return { return {
'nren': entry.nren.name, 'nren': entry.nren.name,
'year': entry.year, 'year': entry.year,
...@@ -67,7 +60,6 @@ def staff_view() -> Any: ...@@ -67,7 +60,6 @@ def staff_view() -> Any:
'non_technical_fte': float(entry.non_technical_fte) 'non_technical_fte': float(entry.non_technical_fte)
} }
with db.session_scope() as session: entries = [_extract_data(entry) for entry in db.session.scalars(
entries = [_extract_data(entry) for entry in session.query( select(NrenStaff).join(NREN).order_by(NREN.name.asc(), NrenStaff.year.desc()))]
model.NrenStaff).join(model.NREN).order_by(model.NREN.name.asc(), model.NrenStaff.year.desc())]
return jsonify(entries) return jsonify(entries)
...@@ -31,11 +31,6 @@ def session_scope( ...@@ -31,11 +31,6 @@ def session_scope(
session.close() session.close()
def postgresql_dsn(db_username, db_password, db_hostname, db_name, port=5432):
return (f'postgresql://{db_username}:{db_password}'
f'@{db_hostname}:{port}/{db_name}')
def init_db_model(dsn): def init_db_model(dsn):
global _SESSION_MAKER global _SESSION_MAKER
......
...@@ -3,6 +3,7 @@ click~=8.1 ...@@ -3,6 +3,7 @@ click~=8.1
jsonschema~=4.17 jsonschema~=4.17
flask~=2.2 flask~=2.2
flask-cors~=3.0 flask-cors~=3.0
flask-sqlalchemy~=3.0
openpyxl~=3.1 openpyxl~=3.1
psycopg2-binary~=2.9 psycopg2-binary~=2.9
SQLAlchemy~=2.0 SQLAlchemy~=2.0
......
...@@ -15,6 +15,7 @@ setup( ...@@ -15,6 +15,7 @@ setup(
'jsonschema~=4.17', 'jsonschema~=4.17',
'flask~=2.2', 'flask~=2.2',
'flask-cors~=3.0', 'flask-cors~=3.0',
'flask-sqlalchemy~=3.0',
'openpyxl~=3.1', 'openpyxl~=3.1',
'psycopg2-binary~=2.9', 'psycopg2-binary~=2.9',
'SQLAlchemy~=2.0', 'SQLAlchemy~=2.0',
......
import json import csv
import os import os
import tempfile
import random
import pytest import pytest
import compendium_v2 import random
from compendium_v2 import db
from compendium_v2.db import model
from compendium_v2 import survey_db
from compendium_v2.survey_db import model as survey_model
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
import csv import compendium_v2
from compendium_v2.db import db, model
from compendium_v2.survey_db import model as survey_model
def _test_data_csv(filename): def _test_data_csv(filename):
...@@ -43,61 +38,29 @@ def mocked_survey_db(mocker): ...@@ -43,61 +38,29 @@ def mocked_survey_db(mocker):
@pytest.fixture @pytest.fixture
def mocked_db(mocker): def test_budget_data(app):
# cf. https://stackoverflow.com/a/33057675 with app.app_context():
engine = create_engine(
'sqlite://',
connect_args={'check_same_thread': False},
poolclass=StaticPool,
echo=False)
model.Base.metadata.create_all(engine)
mocker.patch('compendium_v2.db._SESSION_MAKER', sessionmaker(bind=engine))
mocker.patch('compendium_v2.db.init_db_model', lambda dsn: None)
mocker.patch('compendium_v2.migrate_database', lambda config: None)
@pytest.fixture
def test_budget_data():
with db.session_scope() as session:
data = [row for row in _test_data_csv("BudgetTestData.csv")] data = [row for row in _test_data_csv("BudgetTestData.csv")]
nren_names = set([row["nren"] for row in data]) nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for row in data: for row in data:
nren = nren_dict[row["nren"]] nren = nren_dict[row["nren"]]
budget = row["budget"] budget = row["budget"]
year = row["year"] year = row["year"]
session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year))) db.session.add(model.BudgetEntry(nren=nren, budget=float(budget), year=int(year)))
db.session.commit()
with survey_db.session_scope() as session:
data = _test_data_csv("BudgetTestData.csv")
nrens = set()
budgets_data = []
for row in data:
nren = row["nren"]
budget = row["budget"]
year = row["year"]
country_code = row["nren"]
nrens.add(nren)
budgets_data.append(survey_model.Budgets(budget=budget, year=year, country_code=country_code))
for nren in nrens:
session.add(survey_model.Nrens(abbreviation=nren, country_code=nren))
session.add_all(budgets_data)
@pytest.fixture @pytest.fixture
def test_funding_source_data(): def test_funding_source_data(app):
with db.session_scope() as session: with app.app_context():
data = [row for row in _test_data_csv("FundingSourceTestData.csv")] data = [row for row in _test_data_csv("FundingSourceTestData.csv")]
nren_names = set([row["nren"] for row in data]) nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for row in data: for row in data:
nren = nren_dict[row["nren"]] nren = nren_dict[row["nren"]]
...@@ -108,7 +71,7 @@ def test_funding_source_data(): ...@@ -108,7 +71,7 @@ def test_funding_source_data():
commercial = row["commercial"] commercial = row["commercial"]
other = row["other"] other = row["other"]
session.add( db.session.add(
model.FundingSource( model.FundingSource(
nren=nren, year=year, nren=nren, year=year,
client_institutions=client, client_institutions=client,
...@@ -117,10 +80,11 @@ def test_funding_source_data(): ...@@ -117,10 +80,11 @@ def test_funding_source_data():
commercial=commercial, commercial=commercial,
other=other) other=other)
) )
db.session.commit()
@pytest.fixture @pytest.fixture
def test_staff_data(): def test_staff_data(app):
# generator of random test data for 5 years and 100 nrens # generator of random test data for 5 years and 100 nrens
def _generate_rows(): def _generate_rows():
...@@ -135,12 +99,12 @@ def test_staff_data(): ...@@ -135,12 +99,12 @@ def test_staff_data():
"non_technical_fte": random.randint(0, 100) "non_technical_fte": random.randint(0, 100)
} }
with db.session_scope() as session: with app.app_context():
data = list(_generate_rows()) data = list(_generate_rows())
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]} nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in [d['nren'] for d in data]}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for row in data: for row in data:
nren = nren_dict[row["nren"]] nren = nren_dict[row["nren"]]
...@@ -150,7 +114,7 @@ def test_staff_data(): ...@@ -150,7 +114,7 @@ def test_staff_data():
technical_fte = row["technical_fte"] technical_fte = row["technical_fte"]
non_technical_fte = row["non_technical_fte"] non_technical_fte = row["non_technical_fte"]
session.add( db.session.add(
model.NrenStaff( model.NrenStaff(
nren=nren, nren=nren,
year=year, year=year,
...@@ -160,30 +124,29 @@ def test_staff_data(): ...@@ -160,30 +124,29 @@ def test_staff_data():
non_technical_fte=non_technical_fte non_technical_fte=non_technical_fte
) )
) )
db.session.commit()
@pytest.fixture @pytest.fixture
def data_config_filename(dummy_config): def app(dummy_config):
with tempfile.NamedTemporaryFile() as f: app = compendium_v2._create_app_with_db(dummy_config)
f.write(json.dumps(dummy_config).encode('utf-8')) with app.app_context():
f.flush() db.create_all()
yield f.name yield app
@pytest.fixture @pytest.fixture
def client(data_config_filename, mocked_db, mocked_survey_db): def client(app):
os.environ['SETTINGS_FILENAME'] = data_config_filename return app.test_client()
with compendium_v2.create_app().test_client() as c:
yield c
@pytest.fixture @pytest.fixture
def test_charging_structure_data(): def test_charging_structure_data(app):
with db.session_scope() as session: with app.app_context():
data = [row for row in _test_data_csv("ChargingStructureTestData.csv")] data = [row for row in _test_data_csv("ChargingStructureTestData.csv")]
nren_names = set([row["nren"] for row in data]) nren_names = set([row["nren"] for row in data])
nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names} nren_dict = {nren_name: model.NREN(name=nren_name) for nren_name in nren_names}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for row in data: for row in data:
nren = nren_dict[row["nren"]] nren = nren_dict[row["nren"]]
...@@ -192,15 +155,16 @@ def test_charging_structure_data(): ...@@ -192,15 +155,16 @@ def test_charging_structure_data():
if fee_type == "null": if fee_type == "null":
fee_type = None fee_type = None
session.add( db.session.add(
model.ChargingStructure( model.ChargingStructure(
nren=nren, year=year, nren=nren, year=year,
fee_type=fee_type) fee_type=fee_type)
) )
db.session.commit()
@pytest.fixture @pytest.fixture
def test_organization_data(): def test_organization_data(app):
def _generate_sub_org_data(): def _generate_sub_org_data():
for nren in ["nren" + str(i) for i in range(1, 50)]: for nren in ["nren" + str(i) for i in range(1, 50)]:
for year in range(2016, 2021): for year in range(2016, 2021):
...@@ -220,21 +184,21 @@ def test_organization_data(): ...@@ -220,21 +184,21 @@ def test_organization_data():
'name': 'org' + str(year) 'name': 'org' + str(year)
} }
with db.session_scope() as session: with app.app_context():
org_data = list(_generate_org_data()) org_data = list(_generate_org_data())
sub_org_data = list(_generate_sub_org_data()) sub_org_data = list(_generate_sub_org_data())
nren_dict = {nren_name: model.NREN(name=nren_name) nren_dict = {nren_name: model.NREN(name=nren_name)
for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])} for nren_name in set(d['nren'] for d in [*org_data, *sub_org_data])}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for org in org_data: for org in org_data:
nren = nren_dict[org["nren"]] nren = nren_dict[org["nren"]]
year = org["year"] year = org["year"]
name = org["name"] name = org["name"]
session.add(model.ParentOrganization(nren=nren, year=year, organization=name)) db.session.add(model.ParentOrganization(nren=nren, year=year, organization=name))
for sub_org in sub_org_data: for sub_org in sub_org_data:
nren = nren_dict[sub_org["nren"]] nren = nren_dict[sub_org["nren"]]
...@@ -242,13 +206,13 @@ def test_organization_data(): ...@@ -242,13 +206,13 @@ def test_organization_data():
name = sub_org["name"] name = sub_org["name"]
role = sub_org["role"] role = sub_org["role"]
session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role)) db.session.add(model.SubOrganization(nren=nren, year=year, organization=name, role=role))
session.commit() db.session.commit()
@pytest.fixture @pytest.fixture
def test_ec_project_data(): def test_ec_project_data(app):
def _generate_ec_project_data(): def _generate_ec_project_data():
for nren in ["nren" + str(i) for i in range(1, 50)]: for nren in ["nren" + str(i) for i in range(1, 50)]:
for year in range(2016, 2021): for year in range(2016, 2021):
...@@ -264,19 +228,19 @@ def test_ec_project_data(): ...@@ -264,19 +228,19 @@ def test_ec_project_data():
'project': 'ec_project2', 'project': 'ec_project2',
} }
with db.session_scope() as session: with app.app_context():
ec_project_data = list(_generate_ec_project_data()) ec_project_data = list(_generate_ec_project_data())
nren_dict = {nren_name: model.NREN(name=nren_name) nren_dict = {nren_name: model.NREN(name=nren_name)
for nren_name in set(d['nren'] for d in ec_project_data)} for nren_name in set(d['nren'] for d in ec_project_data)}
session.add_all(nren_dict.values()) db.session.add_all(nren_dict.values())
for ec_project in ec_project_data: for ec_project in ec_project_data:
nren = nren_dict[ec_project["nren"]] nren = nren_dict[ec_project["nren"]]
year = ec_project["year"] year = ec_project["year"]
project = ec_project["project"] project = ec_project["project"]
session.add(model.ECProject(nren=nren, year=year, project=project)) db.session.add(model.ECProject(nren=nren, year=year, project=project))
session.commit() db.session.commit()
from compendium_v2 import db from compendium_v2.db import db, model
from compendium_v2.db import model
from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \ from compendium_v2.publishers.survey_publisher_2022 import _cli, FundingSource, \
StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion StaffQuestion, OrgQuestion, ChargingStructure, ECQuestion
...@@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority" ...@@ -109,7 +108,7 @@ org_dataKTU,"NOC, administrative authority"
] ]
def test_publisher(client, mocker, dummy_config): def test_publisher(app, mocker, dummy_config):
global org_data global org_data
def get_rows_as_tuples(*args, **kwargs): def get_rows_as_tuples(*args, **kwargs):
...@@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config): ...@@ -186,19 +185,20 @@ def test_publisher(client, mocker, dummy_config):
mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data) mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_funding_sources', funding_source_data)
mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data) mocker.patch('compendium_v2.publishers.survey_publisher_2022.query_question', question_data)
with db.session_scope() as session: nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
nren_names = ['Nren1', 'Nren2', 'Nren3', 'Nren4', 'SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] with app.app_context():
session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
db.session.commit()
_cli(dummy_config) _cli(dummy_config, app)
with db.session_scope() as session: with app.app_context():
budgets = session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all() budgets = db.session.query(model.BudgetEntry).order_by(model.BudgetEntry.nren_id.asc()).all()
assert len(budgets) == 3 assert len(budgets) == 3
assert budgets[0].nren.name.lower() == 'nren1' assert budgets[0].nren.name.lower() == 'nren1'
assert budgets[0].budget == 100 assert budgets[0].budget == 100
funding_sources = session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all() funding_sources = db.session.query(model.FundingSource).order_by(model.FundingSource.nren_id.asc()).all()
assert len(funding_sources) == 3 assert len(funding_sources) == 3
assert funding_sources[0].nren.name.lower() == 'nren1' assert funding_sources[0].nren.name.lower() == 'nren1'
assert funding_sources[0].client_institutions == 10 assert funding_sources[0].client_institutions == 10
...@@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -215,7 +215,7 @@ def test_publisher(client, mocker, dummy_config):
assert funding_sources[2].european_funding == 30 assert funding_sources[2].european_funding == 30
assert funding_sources[2].other == 30 assert funding_sources[2].other == 30
staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all() staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.nren_id.asc()).all()
assert len(staff_data) == 3 assert len(staff_data) == 3
assert staff_data[0].nren.name.lower() == 'nren1' assert staff_data[0].nren.name.lower() == 'nren1'
...@@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -236,7 +236,7 @@ def test_publisher(client, mocker, dummy_config):
assert staff_data[2].permanent_fte == 30 assert staff_data[2].permanent_fte == 30
assert staff_data[2].subcontracted_fte == 0 assert staff_data[2].subcontracted_fte == 0
_org_data = session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all() _org_data = db.session.query(model.ParentOrganization).order_by(model.ParentOrganization.nren_id.asc()).all()
assert len(_org_data) == 2 assert len(_org_data) == 2
assert _org_data[0].nren.name.lower() == 'nren1' assert _org_data[0].nren.name.lower() == 'nren1'
...@@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -245,7 +245,7 @@ def test_publisher(client, mocker, dummy_config):
assert _org_data[1].nren.name.lower() == 'nren3' assert _org_data[1].nren.name.lower() == 'nren3'
assert _org_data[1].organization == 'Org3' assert _org_data[1].organization == 'Org3'
charging_structures = session.query(model.ChargingStructure).order_by( charging_structures = db.session.query(model.ChargingStructure).order_by(
model.ChargingStructure.nren_id.asc()).all() model.ChargingStructure.nren_id.asc()).all()
assert len(charging_structures) == 3 assert len(charging_structures) == 3
assert charging_structures[0].nren.name.lower() == 'nren1' assert charging_structures[0].nren.name.lower() == 'nren1'
...@@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -255,7 +255,7 @@ def test_publisher(client, mocker, dummy_config):
assert charging_structures[2].nren.name.lower() == 'nren3' assert charging_structures[2].nren.name.lower() == 'nren3'
assert charging_structures[2].fee_type == model.FeeType.other assert charging_structures[2].fee_type == model.FeeType.other
_ec_data = session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all() _ec_data = db.session.query(model.ECProject).order_by(model.ECProject.nren_id.asc()).all()
assert len(_ec_data) == 3 assert len(_ec_data) == 3
assert _ec_data[0].nren.name.lower() == 'nren2' assert _ec_data[0].nren.name.lower() == 'nren2'
......
...@@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli ...@@ -7,23 +7,24 @@ from compendium_v2.publishers.survey_publisher_v1 import _cli
EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx") EXCEL_FILE = os.path.join(os.path.dirname(__file__), "data", "2021_Organisation_DataSeries.xlsx")
def test_publisher(client, mocker, dummy_config): def test_publisher(mocked_survey_db, app, mocker, dummy_config):
mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE) mocker.patch('compendium_v2.background_task.parse_excel_data.EXCEL_FILE', EXCEL_FILE)
with db.session_scope() as session: with app.app_context():
nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH'] nren_names = ['SURF', 'KIFU', 'UoM', 'ASNET-AM', 'SIKT', 'LAT', 'RASH']
session.add_all([model.NREN(name=nren_name) for nren_name in nren_names]) db.session.add_all([model.NREN(name=nren_name) for nren_name in nren_names])
db.session.commit()
_cli(dummy_config) _cli(dummy_config, app)
with db.session_scope() as session: with app.app_context():
budget_count = session.query(model.BudgetEntry.year).count() budget_count = db.session.query(model.BudgetEntry.year).count()
assert budget_count assert budget_count
funding_source_count = session.query(model.FundingSource.year).count() funding_source_count = db.session.query(model.FundingSource.year).count()
assert funding_source_count assert funding_source_count
charging_structure_count = session.query(model.ChargingStructure.year).count() charging_structure_count = db.session.query(model.ChargingStructure.year).count()
assert charging_structure_count assert charging_structure_count
staff_data = session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all() staff_data = db.session.query(model.NrenStaff).order_by(model.NrenStaff.year.asc()).all()
# data should only be saved for the NRENs we have saved in the database # data should only be saved for the NRENs we have saved in the database
staff_data_nrens = set([staff.nren.name for staff in staff_data]) staff_data_nrens = set([staff.nren.name for staff in staff_data])
...@@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -69,7 +70,7 @@ def test_publisher(client, mocker, dummy_config):
assert kifu_data[5].technical_fte == 133 assert kifu_data[5].technical_fte == 133
assert kifu_data[5].non_technical_fte == 45 assert kifu_data[5].non_technical_fte == 45
ecproject_data = session.query(model.ECProject).all() ecproject_data = db.session.query(model.ECProject).all()
# test a couple of random entries # test a couple of random entries
surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017] surf2017 = [x for x in ecproject_data if x.nren.name == 'SURF' and x.year == 2017]
assert len(surf2017) == 1 assert len(surf2017) == 1
...@@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config): ...@@ -83,7 +84,7 @@ def test_publisher(client, mocker, dummy_config):
assert len(kifu2019) == 4 assert len(kifu2019) == 4
assert kifu2019[3].project == 'SuperHeroes for Science' assert kifu2019[3].project == 'SuperHeroes for Science'
parent_data = session.query(model.ParentOrganization).all() parent_data = db.session.query(model.ParentOrganization).all()
# test a random entry # test a random entry
asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021] asnet2021 = [x for x in parent_data if x.nren.name == 'ASNET-AM' and x.year == 2021]
assert len(asnet2021) == 1 assert len(asnet2021) == 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment