Code source de geonature.core.imports.checks.sql.utils

from sqlalchemy import func
from sqlalchemy.sql.expression import select, update, insert, literal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import array_agg, aggregate_order_by

from geonature.utils.env import db

from geonature.core.imports.models import (
    ImportUserError,
    ImportUserErrorType,
)
from geonature.core.imports.utils import generated_fields
import pandas as pd


# Names exported by ``from ... import *``; the debugging helper
# ``print_transient_table`` is not listed here.
__all__ = [
    "get_duplicates_query",
    "report_erroneous_rows",
]


def get_duplicates_query(imprt, dest_field, whereclause=sa.true()):
    """Build a subquery listing the transient-table lines that share a value.

    Rows of the import's transient table are partitioned by ``dest_field``;
    every ``line_no`` belonging to a partition with more than one row is
    returned exactly once.

    :param imprt: the import whose transient table is inspected; only rows
        with ``id_import == imprt.id_import`` are considered.
    :param dest_field: column expression used as the window ``PARTITION BY``
        key. NULL keys end up in the same partition, so callers that do not
        want NULL values reported as duplicates should exclude them through
        ``whereclause``.
    :param whereclause: extra SQL condition applied before partitioning
        (defaults to no additional filtering).
    :return: an aliased subquery named ``duplicates`` exposing a single
        ``lines`` column containing the distinct duplicated line numbers.
    """
    transient_table = imprt.destination.get_transient_table()
    whereclause = sa.and_(
        transient_table.c.id_import == imprt.id_import,
        whereclause,
    )
    # One output row per transient row, carrying the line numbers of every
    # row sharing its ``dest_field`` value.
    partitions = (
        select(
            array_agg(transient_table.c.line_no)
            .over(partition_by=dest_field)
            .label("duplicate_lines")
        )
        .where(whereclause)
        .alias("partitions")
    )
    # Keep partitions with at least two members, flatten them back into
    # individual line numbers and de-duplicate them: every member row of a
    # partition carries the same array. ``distinct("lines")`` renders a
    # PostgreSQL ``DISTINCT ON``. Columns are passed positionally (1.4+/2.0
    # style); the legacy ``select([...])`` list form is removed in 2.0.
    duplicates = (
        select(func.unnest(partitions.c.duplicate_lines).label("lines"))
        .where(func.array_length(partitions.c.duplicate_lines, 1) > 1)
        .distinct("lines")
        .alias("duplicates")
    )
    return duplicates
def report_erroneous_rows(
    imprt,
    entity,
    error_type,
    error_column,
    whereclause,
    level_validity_mapping=None,
):
    """Record one import error covering every transient row matching ``whereclause``.

    A single ``ImportUserError`` row is inserted. Its ``rows`` column holds
    the line numbers (in ascending order) of the matching rows of the import.
    Nothing is inserted when no row matches.

    Depending on the error level, the validity column of ``entity`` may also
    be updated on the matching rows. ``level_validity_mapping`` controls this
    and defaults to ``{"ERROR": False}``:

    - level not in the mapping: row validity is left untouched;
    - level in the mapping: the validity column is set to the mapped value:
        - ``False``: the row is marked as erroneous;
        - ``None``: the row is marked as "should not be imported".

    :param imprt: the import being checked; only its own transient rows are
        considered.
    :param entity: entity whose ``validity_column`` is updated. It may be
        ``None`` only when the error level is absent from
        ``level_validity_mapping``. When given, its ``id_entity`` is stored
        on the error.
    :param error_type: name of the ``ImportUserErrorType`` to report.
    :param error_column: field name the error relates to. It is translated
        through ``generated_fields`` and then the import's field mapping
        before being stored.
    :param whereclause: SQL condition selecting the erroneous rows.
    :param level_validity_mapping: mapping from error level to the validity
        value to write; ``None`` means ``{"ERROR": False}``.
    :raises ValueError: if the error level requires a validity update but
        ``entity`` is ``None``.
    """
    # Avoid a shared mutable default; keeps the historical default mapping.
    if level_validity_mapping is None:
        level_validity_mapping = {"ERROR": False}
    transient_table = imprt.destination.get_transient_table()
    # ``.one()`` raises if the error type name does not resolve to exactly
    # one row.
    error_type = ImportUserErrorType.query.filter_by(name=error_type).one()
    error_column = generated_fields.get(error_column, error_column)
    error_column = imprt.fieldmapping.get(error_column, error_column)
    if error_type.level in level_validity_mapping:
        # Explicit check (an ``assert`` would be stripped under ``python -O``).
        if entity is None:
            raise ValueError(
                f"an entity is required to report errors of level {error_type.level!r}"
            )
        # Data-modifying CTE: update the validity of the matching rows and
        # return their line numbers.
        cte = (
            update(transient_table)
            .values(
                {
                    transient_table.c[entity.validity_column]: level_validity_mapping[
                        error_type.level
                    ],
                }
            )
            .where(transient_table.c.id_import == imprt.id_import)
            .where(whereclause)
            .returning(transient_table.c.line_no)
            .cte("cte")
        )
    else:
        # Validity untouched: only collect the matching line numbers.
        cte = (
            select(transient_table.c.line_no)
            .where(transient_table.c.id_import == imprt.id_import)
            .where(whereclause)
            .cte("cte")
        )
    # Destination column -> value expression. Dict order keeps the insert
    # column names aligned with the selected values.
    insert_args = {
        ImportUserError.id_import: literal(imprt.id_import).label("id_import"),
        ImportUserError.id_type: literal(error_type.pk).label("id_type"),
        ImportUserError.rows: array_agg(aggregate_order_by(cte.c.line_no, cte.c.line_no)).label(
            "rows"
        ),
        ImportUserError.column: literal(error_column).label("error_column"),
    }
    if entity is not None:
        insert_args[ImportUserError.id_entity] = literal(entity.id_entity).label("id_entity")
    # Build the final INSERT ... SELECT. array_agg over an empty CTE yields a
    # single row whose ``rows`` is NULL; filtering it out avoids inserting an
    # error that concerns no line.
    error_select = select(*insert_args.values()).alias("error")
    stmt = insert(ImportUserError).from_select(
        names=list(insert_args.keys()),
        select=select(error_select).where(error_select.c.rows.is_not(None)),
    )
    db.session.execute(stmt)
def print_transient_table(imprt, columns=None):
    """Print the transient-table rows of ``imprt`` as a text table (debug helper).

    Rows are filtered on ``imprt.id_import`` and ordered by ``line_no``.

    :param imprt: the import whose transient rows are printed.
    :param columns: optional iterable of transient-table column names to
        display; when ``None`` or empty, every column is shown.
    """
    trans_table = imprt.destination.get_transient_table()
    selected = [trans_table.c[col] for col in columns] if columns else [trans_table]
    result = db.session.execute(
        sa.select(*selected)
        .where(trans_table.c.id_import == imprt.id_import)
        .order_by(trans_table.c.line_no)
    )
    # Take headers from the result itself. The whole-table (and empty-list)
    # case then gets real column names instead of positional indices, or
    # instead of a pandas column-count mismatch error.
    headers = list(result.keys())
    rows = result.all()
    print(pd.DataFrame(rows, columns=headers).to_string())