from sqlalchemy import func
from sqlalchemy.sql.expression import select, update, insert, literal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import array_agg, aggregate_order_by
from geonature.utils.env import db
from geonature.core.imports.models import (
ImportUserError,
ImportUserErrorType,
TImports,
)
from geonature.core.imports.utils import generated_fields
import pandas as pd
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "get_duplicates_query",
    "report_erroneous_rows",
    "print_transient_table",
    "transient_table_to_dataframe",
]
[docs]
def get_duplicates_query(imprt, dest_field, whereclause=sa.true()):
    """
    Build a subquery listing the transient-table lines that share a value of ``dest_field``.

    Only rows of the transient table belonging to ``imprt`` and matching ``whereclause``
    are considered. Those rows are partitioned by ``dest_field``. Every line number that
    falls in a partition holding more than one row is returned once.

    Parameters
    ----------
    imprt : TImports
        The current import; only rows whose ``id_import`` matches are considered.
    dest_field : sqlalchemy column expression
        Expression passed as the window's ``partition_by``; rows with equal values
        are reported as duplicates of each other.
    whereclause : sqlalchemy.sql.elements.ClauseElement, optional
        Extra filter restricting the rows taken into account. Defaults to ``TRUE``.

    Returns
    -------
    sqlalchemy.sql.Alias
        An aliased ``SELECT`` named ``duplicates`` exposing a single ``lines`` column
        holding the ``line_no`` of each duplicated row.
    """
    transient_table = imprt.destination.get_transient_table()
    whereclause = sa.and_(
        transient_table.c.id_import == imprt.id_import,
        whereclause,
    )
    # Window function: every row gets the array of line numbers of all rows that
    # share its ``dest_field`` value.
    partitions = (
        select(
            array_agg(transient_table.c.line_no)
            .over(partition_by=dest_field)
            .label("duplicate_lines")
        )
        .where(whereclause)
        .alias("partitions")
    )
    # Keep only partitions holding more than one line, expand the arrays back into
    # individual line numbers and deduplicate them (every row of a duplicate group
    # carries the same array). The column is passed positionally: the legacy
    # ``select([...])`` list form is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    duplicates = (
        select(func.unnest(partitions.c.duplicate_lines).label("lines"))
        .where(func.array_length(partitions.c.duplicate_lines, 1) > 1)
        .distinct("lines")
        .alias("duplicates")
    )
    return duplicates
[docs]
def report_erroneous_rows(
    imprt,
    entity,
    error_type,
    error_column,
    whereclause,
    error_comment=None,
    level_validity_mapping=None,
):
    """
    Report erroneous rows in a transient table and update validity of rows based on error level.

    This function records an ``ImportUserError`` for the rows of the current import that
    match ``whereclause``. If the error level is a key of ``level_validity_mapping``, it
    also updates the validity column of those rows. By default, errors with the level
    "ERROR" mark the rows as invalid.

    Parameters
    ----------
    imprt : TImports
        The current import object.
    entity : Entity or None
        The entity associated with the error. Required whenever the error level is a key
        of ``level_validity_mapping``, since its ``validity_column`` is the column updated.
        When given, its ``id_entity`` is also stored on the reported error.
    error_type : str
        Name of the error type to report. Must match exactly one ``ImportUserErrorType``.
    error_column : str
        Name of the column where the error is detected. It is first translated through
        ``generated_fields``, then to the ``column_src`` of ``imprt.fieldmapping`` when a
        mapping exists for it.
    whereclause : sqlalchemy.sql.elements.ClauseElement
        SQL clause defining the rows affected by the error.
    error_comment : str, optional
        Optional comment to include extra explanation to describe the error in the current
        import context.
    level_validity_mapping : dict, optional
        Dictionary mapping error levels to validity values:
          - If the level is not in the dictionary, the row validity remains unchanged.
          - If the level is present, validity is set according to the associated value:
            - ``False``: The row is marked as erroneous.
            - ``None``: The row is marked as should not be imported.
        Defaults to ``{"ERROR": False}``: only the level "ERROR" is mapped to ``False``.

    Raises
    ------
    AssertionError
        If ``entity`` is None while the error level is a key of ``level_validity_mapping``.
    NoResultFound, MultipleResultsFound
        Raised by ``Query.one()`` when ``error_type`` does not match exactly one
        ``ImportUserErrorType``.

    Notes
    -----
    No ``ImportUserError`` is inserted when no row matches. ``array_agg`` over zero rows
    yields NULL, and such a result is filtered out before the insert.

    Examples
    --------
    >>> # Example usage for reporting an "ERROR" level error
    >>> report_erroneous_rows(
    ...     imprt=my_import,
    ...     entity=my_entity,
    ...     error_type="MISSING_VALUE",
    ...     error_column="customer_name",
    ...     whereclause=(transient_table.c.customer_name == None),
    ...     error_comment="Customer name missing",
    ... )
    """
    # ``None`` sentinel instead of a shared mutable dict default.
    if level_validity_mapping is None:
        level_validity_mapping = {"ERROR": False}

    transient_table = imprt.destination.get_transient_table()
    user_error_type = ImportUserErrorType.query.filter_by(name=error_type).one()
    # Report the error against the column name the user knows: generated fields are
    # mapped back to their source field, then to the column of the imported file.
    error_column = generated_fields.get(error_column, error_column)
    error_column = imprt.fieldmapping.get(error_column, {}).get("column_src", error_column)

    if user_error_type.level in level_validity_mapping:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if entity is None:
            raise AssertionError(
                f"an entity is required to report errors of level {user_error_type.level!r}"
            )
        # Data-modifying CTE: set the validity of the matched rows and return their
        # line numbers so the same rows are used to build the error record.
        cte = (
            update(transient_table)
            .values(
                {
                    transient_table.c[entity.validity_column]: level_validity_mapping[
                        user_error_type.level
                    ],
                }
            )
            .where(transient_table.c.id_import == imprt.id_import)
            .where(whereclause)
            .returning(transient_table.c.line_no)
            .cte("cte")
        )
    else:
        # Validity is left untouched; only collect the matching line numbers.
        cte = (
            select(transient_table.c.line_no)
            .where(transient_table.c.id_import == imprt.id_import)
            .where(whereclause)
            .cte("cte")
        )

    # Target column -> value expression. Insertion order matters: the keys are the
    # INSERT column names and the values, in the same order, form the SELECT list.
    insert_args = {
        ImportUserError.id_import: literal(imprt.id_import).label("id_import"),
        ImportUserError.id_type: literal(user_error_type.pk).label("id_type"),
        ImportUserError.rows: array_agg(aggregate_order_by(cte.c.line_no, cte.c.line_no)).label(
            "rows"
        ),
        ImportUserError.column: literal(error_column).label("error_column"),
        ImportUserError.comment: literal(error_comment).label("error_comment"),
    }
    if entity is not None:
        insert_args[ImportUserError.id_entity] = literal(entity.id_entity).label("id_entity")

    # Create the final insert statement. Columns are unpacked positionally: passing an
    # iterable as a single argument is the legacy ``select([...])`` form removed in
    # SQLAlchemy 2.0.
    error_select = select(*insert_args.values()).alias("error")
    stmt = insert(ImportUserError).from_select(
        names=insert_args.keys(),
        # ``rows`` is NULL when no line matched: skip the insert in that case.
        select=select(error_select).where(error_select.c.rows.is_not(None)),
    )
    db.session.execute(stmt)
[docs]
def print_transient_table(imprt: TImports, columns=None):
    """
    Write the transient-table rows of an import to standard output.

    Parameters
    ----------
    imprt : TImports
        The import whose transient-table rows are printed.
    columns : list, optional
        Names of the columns to display. When omitted, every column is displayed.
    """
    frame = transient_table_to_dataframe(imprt, columns)
    rendered = frame.to_string()
    print(rendered)
[docs]
def transient_table_to_dataframe(imprt: TImports, columns=None) -> pd.DataFrame:
    """
    Get the content of the transient table for a given import as a pandas DataFrame.

    Rows are restricted to ``imprt`` and ordered by ``line_no``.

    Parameters
    ----------
    imprt : TImports
        The import to get.
    columns : list, optional
        Names of the columns to include in the DataFrame. If None or empty, all columns
        are included.

    Returns
    -------
    pd.DataFrame
        The content of the transient table as a DataFrame. Column labels are always set,
        even when the import has no rows.
    """
    # Materialize once so any iterable works and can be reused for the labels.
    column_names = list(columns) if columns else []
    trans_table = imprt.destination.get_transient_table()
    selected = [trans_table.c[name] for name in column_names] or [trans_table]
    result = db.session.execute(
        sa.select(*selected)
        .where(trans_table.c.id_import == imprt.id_import)
        .order_by(trans_table.c.line_no)
    )
    # Without explicit names, take them from the result: pandas can only infer them
    # from the row objects, so an empty result would otherwise yield no columns.
    labels = column_names or list(result.keys())
    rows = result.all()
    return pd.DataFrame(rows, columns=labels)