import os
from io import BytesIO, TextIOWrapper
import csv
import json
from enum import IntEnum
from datetime import datetime, timedelta
from typing import IO, Any, Dict, Iterable, List, Optional, Set, Tuple

from flask import current_app, render_template
import sqlalchemy as sa
from sqlalchemy import func, select, delete
from chardet.universaldetector import UniversalDetector
from sqlalchemy.sql.expression import select, insert
import pandas as pd
import numpy as np
from sqlalchemy.dialects.postgresql import insert as pg_insert
from werkzeug.exceptions import BadRequest
from geonature.utils.env import db
from weasyprint import HTML

from geonature.utils.sentry import start_sentry_child
from geonature.core.imports.models import Entity, ImportUserError, BibFields, TImports

class ImportStep(IntEnum):
    """
    Ordered steps of the import process.

    Members are integers (``IntEnum``), so "this step or any earlier one" can be
    expressed with ``<=`` — ``clean_import`` relies on this to decide what to reset.
    """

    UPLOAD = 1
    DECODE = 2
    LOAD = 3
    PREPARE = 4
    IMPORT = 5
# Maps "datetime_min"/"datetime_max" to "date_min"/"date_max".
# NOTE(review): not referenced in the visible part of this module; presumably a
# mapping between generated field names — confirm against its users.
generated_fields = dict(
    datetime_min="date_min",
    datetime_max="date_max",
)
def clean_import(imprt: TImports, step: ImportStep) -> None:
    """
    Reset an import so that it can be replayed from ``step``.

    All the state produced by ``step`` and by every later step is discarded.
    Since ``ImportStep`` is an ``IntEnum``, "``step`` or an earlier one" is
    tested with ``<=``.

    Parameters
    ----------
    imprt : TImports
        The import to reset.
    step : ImportStep
        The earliest step whose output must be discarded.
    """
    imprt.task_id = None

    # UPLOAD: nothing to reset — source_file and source_count will necessarily
    # be overwritten by the next upload.

    if step <= ImportStep.DECODE:
        imprt.columns = None

    if step <= ImportStep.LOAD:
        transient = imprt.destination.get_transient_table()
        purge_transient_rows = delete(transient).where(
            transient.c.id_import == imprt.id_import
        )
        with start_sentry_child(op="task", description="clean transient data"):
            db.session.execute(purge_transient_rows)
        imprt.source_count = None
        imprt.loaded = False

    if step <= ImportStep.PREPARE:
        with start_sentry_child(op="task", description="clean errors"):
            ImportUserError.query.filter(ImportUserError.imprt == imprt).delete()
        imprt.erroneous_rows = None
        imprt.processed = False

    if step <= ImportStep.IMPORT:
        imprt.date_end_import = None
        imprt.statistics = {"import_count": None}
        imprt.destination.actions.remove_data_from_destination(imprt)
def get_file_size(file_: IO) -> int:
    """
    Get the size of a file in bytes.

    The stream is seeked to its end to measure it, then put back at the
    position it had on entry, so callers can keep reading where they were.

    Parameters
    ----------
    file_ : IO
        A seekable file object.

    Returns
    -------
    int
        The size of the file in bytes.
    """
    current_position = file_.tell()
    try:
        file_.seek(0, os.SEEK_END)
        size = file_.tell()
    finally:
        # Always restore the caller's position, even if seek/tell fails.
        file_.seek(current_position)
    return size
def detect_encoding(file_: IO) -> str:
    """
    Detect the encoding of a file.

    The file is read line by line from its beginning and fed to a chardet
    ``UniversalDetector`` until the detector is done or until
    ``IMPORT.MAX_ENCODING_DETECTION_DURATION`` seconds have elapsed. The stream
    position is restored afterwards.

    Parameters
    ----------
    file_ : IO
        The file to detect the encoding of. Assumed to be opened in binary
        mode, since ``UniversalDetector.feed`` consumes bytes — TODO confirm.

    Returns
    -------
    str
        The detected encoding. If no encoding is detected, then "UTF-8" is
        returned.
    """
    begin = datetime.now()
    max_duration = timedelta(
        seconds=current_app.config["IMPORT"]["MAX_ENCODING_DETECTION_DURATION"]
    )
    position = file_.tell()
    # Detect on the whole file, whatever the caller's current position is.
    file_.seek(0)
    detector = UniversalDetector()
    try:
        for row in file_:
            detector.feed(row)
            # Stop as soon as chardet is confident or the time budget is spent.
            if detector.done or (datetime.now() - begin) > max_duration:
                break
    finally:
        file_.seek(position)
    detector.close()
    return detector.result["encoding"] or "UTF-8"
def detect_separator(file_: IO, encoding: str) -> Optional[str]:
    """
    Detect the delimiter used in a CSV file.

    Only the first line (the header) is sniffed. The stream position is
    restored before returning, whatever the outcome.

    Parameters
    ----------
    file_ : IO
        The binary file object to detect the delimiter of.
    encoding : str
        The encoding of the file.

    Returns
    -------
    Optional[str]
        The delimiter used in the file, or None if the header cannot be
        decoded with ``encoding`` or no delimiter can be determined.

    Raises
    ------
    BadRequest
        If the file starts with an empty line instead of column names.
    """
    position = file_.tell()
    file_.seek(0)
    try:
        try:
            sample = file_.readline().decode(encoding)
        except UnicodeDecodeError:
            # The encoding is probably the auto-detected one and may be wrong:
            # return None and let the later decoding step report the errors.
            return None
        if sample in ("\n", "\r\n"):  # files that do not start with column names
            raise BadRequest("File must start with columns")
        try:
            dialect = csv.Sniffer().sniff(sample)
        except csv.Error:
            # sniff() raises instead of returning when it cannot decide
            # (e.g. empty file); honour the documented "None" contract.
            return None
        return dialect.delimiter
    finally:
        file_.seek(position)
def preprocess_value(dataframe: pd.DataFrame, field: BibFields, source_col: str) -> pd.Series:
    """
    Preprocess values in a DataFrame depending on whether the field contains
    multiple values (e.g. additional_data) or not.

    For a multi field, every row's non-null values are merged into one dict:
    a value that decodes to a JSON object is merged key by key; any other value
    is stored under its column name (as decoded JSON if it is valid JSON,
    otherwise as the raw value). Columns are merged in ``source_col`` order,
    so later columns win on key collisions.

    Parameters
    ----------
    dataframe : pd.DataFrame
        The DataFrame to preprocess the value of.
    field : BibFields
        The field to preprocess.
    source_col : str or list of str
        The column to preprocess; must be a list of columns when
        ``field.multi`` is true.

    Returns
    -------
    pd.Series
        The preprocessed value.

    Raises
    ------
    TypeError
        If ``field.multi`` is true and ``source_col`` is not a list.
    """

    def build_additional_data(columns) -> dict:
        """Merge the non-null values of one row (a Series) into a single dict."""
        result = {}
        for key, value in columns.items():
            if value is None:
                continue
            try:
                decoded = json.loads(value)
            except (TypeError, ValueError, RecursionError):
                # Not JSON (plain text, ...) or pathologically nested: keep the raw value.
                decoded = value
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if isinstance(decoded, dict):
                result.update(decoded)
            else:
                result[key] = decoded
        return result

    if field.multi:
        if not isinstance(source_col, list):
            raise TypeError("source_col must be a list of column names when field.multi is set")
        return dataframe[source_col].apply(build_additional_data, axis=1)
    return dataframe[source_col]
def insert_import_data_in_transient_table(imprt: TImports) -> int:
    """
    Insert the data from the import file into the transient table.

    The source file is read in chunks of 10 000 rows. Each chunk is mapped to
    the transient table columns with ``build_fieldmapping``/``preprocess_value``,
    handed to the destination's ``preprocess_transient_data`` hook, then
    bulk-inserted.

    Parameters
    ----------
    imprt : TImports
        Current import.

    Returns
    -------
    int
        The number of data rows read from the import file (the header line is
        not counted); 0 when the file has no data row.
    """
    transient_table = imprt.destination.get_transient_table()
    columns = imprt.columns
    fieldmapping, used_columns = build_fieldmapping(imprt, columns)
    # File columns that are not mapped to any field.
    extra_columns = list(set(columns) - set(used_columns))
    csvfile = TextIOWrapper(BytesIO(imprt.source_file), encoding=imprt.encoding)
    reader = pd.read_csv(
        csvfile,
        delimiter=imprt.separator,
        header=0,
        names=imprt.columns,
        index_col=False,
        dtype="str",
        na_filter=False,
        iterator=True,
        chunksize=10000,
    )
    row_count = 0
    for chunk in reader:
        # A header-only file yields an empty chunk: inserting an empty VALUES
        # list and reading chunk.index[-1] would both fail.
        if chunk.empty:
            continue
        chunk.replace({"": None}, inplace=True)
        data = {
            "id_import": np.full(len(chunk), imprt.id_import),
            # chunk.index is a 0-based row counter continuing across chunks:
            # +1 for the header line, +1 to start line numbers at 1.
            "line_no": 1 + 1 + chunk.index,
        }
        data.update(
            {
                transient_col: preprocess_value(chunk, mapping["field"], mapping["value"])
                for transient_col, mapping in fieldmapping.items()
            }
        )
        # XXX keep extra_fields in t_imports_synthese? or add config argument?
        if extra_columns and "extra_fields" in transient_table.c:
            data["extra_fields"] = chunk[extra_columns].apply(
                lambda cols: dict(cols.items()), axis=1
            )
        df = pd.DataFrame(data)
        imprt.destination.actions.preprocess_transient_data(imprt, df)
        records = df.to_dict(orient="records")
        db.session.execute(insert(transient_table).values(records))
        row_count = 1 + int(chunk.index[-1])  # +1 because chunk.index starts at 0
    return row_count
def build_fieldmapping(
    imprt: TImports, columns: Iterable[Any]
) -> Tuple[Dict[str, Dict[str, Any]], List[str]]:
    """
    Build the mapping between transient-table columns and file columns.

    Only non-autogenerated fields of the import's destination that appear in
    ``imprt.fieldmapping`` are considered, and a field is kept only if its
    mapped file column (or, for a multi field, at least one of its mapped
    columns) is present in ``columns``.

    Parameters
    ----------
    imprt : TImports
        The import to check.
    columns : Iterable[Any]
        The columns of the import file.

    Returns
    -------
    tuple
        ``(fieldmapping, used_columns)`` where ``fieldmapping`` is keyed by
        ``field.source_column`` and valued by ``{"value": <file column, or list
        of file columns for a multi field>, "field": <BibFields>}``, and
        ``used_columns`` lists the file columns referenced by it.
    """
    fields = BibFields.query.filter_by(destination=imprt.destination, autogenerated=False).all()
    columns = list(columns)  # may be iterated several times below
    available_columns = set(columns)
    fieldmapping = {}
    used_columns = []
    for field in fields:
        if field.name_field not in imprt.fieldmapping:
            continue
        mapped = imprt.fieldmapping[field.name_field]
        if field.multi:
            # Keep the mapping order (deduplicated) rather than a set
            # intersection: the merge in preprocess_value is order-sensitive
            # on key collisions, so it must not depend on hash randomisation.
            correct = [col for col in dict.fromkeys(mapped) if col in available_columns]
            if correct:
                fieldmapping[field.source_column] = {
                    "value": correct,
                    "field": field,
                }
                used_columns.extend(correct)
        elif mapped in columns:
            fieldmapping[field.source_column] = {
                "value": mapped,
                "field": field,
            }
            used_columns.append(mapped)
    return fieldmapping, used_columns
def load_transient_data_in_dataframe(
    imprt: TImports, entity: Entity, source_cols: list, offset: int = None, limit: int = None
):
    """
    Load the rows of an import concerning one entity from the transient table.

    Only rows whose ``entity.validity_column`` is not NULL are returned,
    ordered by ``line_no`` and optionally paginated.

    Parameters
    ----------
    imprt : TImports
        The import to load.
    entity : Entity
        The entity to load.
    source_cols : list
        Extra transient-table columns to load; ``id_import``, ``line_no`` and
        the entity validity column are always loaded first.
    offset : int, optional
        The number of rows to skip.
    limit : int, optional
        The maximum number of rows to load.

    Returns
    -------
    pandas.DataFrame
        The loaded rows, all columns cast to ``object`` dtype.
    """
    transient_table = imprt.destination.get_transient_table()
    wanted_columns = ["id_import", "line_no", entity.validity_column] + source_cols

    query = select(*[transient_table.c[name] for name in wanted_columns])
    query = query.where(
        transient_table.c.id_import == imprt.id_import,
        transient_table.c[entity.validity_column].isnot(None),
    )
    query = query.order_by(transient_table.c.line_no)
    if offset is not None:
        query = query.offset(offset)
    if limit is not None:
        query = query.limit(limit)

    rows = db.session.execute(query).fetchall()
    return pd.DataFrame.from_records(rows, columns=wanted_columns).astype("object")
def update_transient_data_from_dataframe(
    imprt: TImports, entity: Entity, updated_cols: Set[str], dataframe: pd.DataFrame
):
    """
    Update the transient table with the data from the dataframe.

    Rows are upserted (``INSERT ... ON CONFLICT (id_import, line_no) DO UPDATE``)
    so only ``updated_cols`` are overwritten. NaN values are sent as SQL NULL.
    The caller's dataframe is left untouched.

    Parameters
    ----------
    imprt : TImports
        The import to update.
    entity : Entity
        The entity to update (not used by this function).
    updated_cols : set of str
        The columns to update. Nothing is done when empty.
    dataframe : pandas.DataFrame
        The dataframe to use for the update.

    Notes
    -----
    The dataframe must have the columns 'id_import' and 'line_no'.
    """
    if not updated_cols:
        return
    transient_table = imprt.destination.get_transient_table()
    key_cols = ["id_import", "line_no"]
    updated_cols = key_cols + list(updated_cols)
    # Convert NaN -> None on a copy of the needed columns only, instead of
    # mutating the caller's whole dataframe in place.
    records = dataframe[updated_cols].replace({np.nan: None}).to_dict(orient="records")
    if not records:
        return
    insert_stmt = pg_insert(transient_table)
    insert_stmt = insert_stmt.values(records).on_conflict_do_update(
        index_elements=key_cols,
        set_={col: insert_stmt.excluded[col] for col in updated_cols[len(key_cols):]},
    )
    db.session.execute(insert_stmt)
def generate_pdf_from_template(template: str, data: Any) -> bytes:
    """
    Render a Flask template and convert the resulting HTML to PDF.

    Relative URLs in the rendered HTML are resolved against the
    ``API_ENDPOINT`` configuration value.

    Parameters
    ----------
    template : str
        The name of the template file to use.
    data : Any
        The data to pass to the template (available as ``data`` in it).

    Returns
    -------
    bytes
        The PDF document as bytes.
    """
    document = HTML(
        string=render_template(template, data=data),
        base_url=current_app.config["API_ENDPOINT"],
        encoding="utf-8",
    )
    return document.write_pdf()
def get_mapping_data(import_: TImports, entity: Entity):
    """
    Collect the mapping information of an import for one entity.

    Parameters
    ----------
    import_ : TImports
        The import to get the mapping data for.
    entity : Entity
        The entity to get the mapping data for.

    Returns
    -------
    fields : dict
        Every field attached to the entity (see gn_imports.bib_fields), keyed by
        ``name_field`` and valued by the BibFields object.
    selected_fields : dict
        Same format as ``fields``, restricted to the fields whose mapped source
        column is one of the import columns.
    source_cols : list
        Columns to load in the dataframe: source column of selected
        non-nomenclature fields, of selected mandatory fields, and of selected
        fields referenced by a mandatory/optional condition.
    """
    fields = {}
    for entity_field in entity.fields:
        fields[entity_field.field.name_field] = entity_field.field

    selected_fields = {}
    for field_name, mapped_column in import_.fieldmapping.items():
        if mapped_column in import_.columns and field_name in fields:
            selected_fields[field_name] = fields[field_name]

    source_cols = set()
    for field in selected_fields.values():
        # non-nomenclature fields
        if field.mnemonique is None and field.source_field is not None:
            source_cols.add(field.source_field)
        # mandatory fields
        if field.mandatory:
            source_cols.add(field.source_field)
        # selected fields used in this field's conditions
        condition_names = set(field.mandatory_conditions or {}) | set(
            field.optional_conditions or {}
        )
        source_cols |= {
            selected_fields[name].source_field
            for name in condition_names
            if name in selected_fields
        }
    return fields, selected_fields, list(source_cols)
def get_required(import_: TImports, entity: Entity):
    """
    Compute which selected fields of ``entity`` are required for this import.

    Among the fields selected in the import's field mapping (see
    ``get_mapping_data``), a field is required when it is flagged ``mandatory``
    or when it has ``mandatory_conditions`` and every field they list is
    selected. It is then made not required when it has ``optional_conditions``
    and every field they list is selected (optional conditions win).

    Parameters
    ----------
    import_ : TImports
        The import to inspect.
    entity : Entity
        The entity whose fields are considered.

    Returns
    -------
    set of str
        The ``name_field`` of each required field.
    """
    _, selected_fields, _ = get_mapping_data(import_, entity)

    def all_selected(field_names) -> bool:
        # A missing/empty condition list must not trigger anything: all([]) is
        # True and would otherwise add (then remove) every selected field.
        return bool(field_names) and all(name in selected_fields for name in field_names)

    required_columns = set()
    for field_name, bib_field in selected_fields.items():
        # Same attribute as read by get_mapping_data (mandatory_conditions).
        if bib_field.mandatory or all_selected(bib_field.mandatory_conditions):
            required_columns.add(field_name)
    for field_name, bib_field in selected_fields.items():
        if all_selected(bib_field.optional_conditions):
            # discard, not remove: the field may not have been required at all
            required_columns.discard(field_name)
    return required_columns
def compute_bounding_box(
    imprt: TImports,
    geom_entity_code: str,
    geom_4326_field_name: str,
    *,
    child_entity_code: str = None,
    transient_where_clause=None,
    destination_where_clause=None,
):
    """
    Compute the bounding box of an entity with a geometry in the given import,
    based on its entities tree (e.g. Station -> Habitat; Site -> Visite -> Observation).

    Parameters
    ----------
    imprt : TImports
        The import to get the bounding box of.
    geom_entity_code : str
        The code of the entity that contains the geometry.
    geom_4326_field_name : str
        The name of the column in the geom entity table (and in the transient
        table) that contains the geometry.
    child_entity_code : str, optional
        The code of the last child entity (of the geom entity) to consider when
        computing the bounding box. If not given, bounding-box will be computed
        only on the geom entity.
    transient_where_clause : sqlalchemy.sql.elements.BooleanClauseList, optional
        Extra filter applied to each SELECT of the union when computing the
        bounding box of a processed import. Should only reference tables
        present in both SELECTs (the transient table).
    destination_where_clause : sqlalchemy.sql.elements.BooleanClauseList, optional
        Extra filter applied to each SELECT of the union when computing the
        bounding box of a finished import. Should only reference tables
        present in both SELECTs (the geom entity destination table).

    Returns
    -------
    valid_bbox : dict or None
        The bounding box of all entities in the given import, in GeoJSON
        format; None if the import is neither processed nor finished, or if no
        geometry is found.
    """

    def get_entities_hierarchy(parent_entity, child_entity) -> Iterable[Entity]:
        """
        Walk ``Entity.parent`` links from ``child_entity`` up to
        ``parent_entity``.

        Yields
        ------
        Entity
            ``child_entity`` first, then each ancestor, stopping before
            ``parent_entity`` (which is not yielded).
        """
        current = child_entity
        while current != parent_entity and current:
            yield current
            current = current.parent

    parent_entity: Entity = db.session.execute(
        select(Entity).filter_by(destination=imprt.destination, code=geom_entity_code)
    ).scalar_one()
    parent_table = parent_entity.get_destination_table()
    transient_table = imprt.destination.get_transient_table()

    # If only one entity in an import destination
    if not child_entity_code:
        if imprt.date_end_import:
            table = parent_table
        elif imprt.processed:
            table = transient_table
        else:
            return None
        assert geom_4326_field_name in table.columns
        query = select(func.ST_AsGeojson(func.ST_Extent(table.c[geom_4326_field_name]))).where(
            table.c.id_import == imprt.id_import
        )
        (valid_bbox,) = db.session.execute(query).fetchone()
        return json.loads(valid_bbox) if valid_bbox else None

    child_entity: Entity = db.session.execute(
        select(Entity).filter_by(destination=imprt.destination, code=child_entity_code)
    ).scalar_one()

    if imprt.date_end_import:
        # Finished import: geometries live in the destination tables.
        entities = list(get_entities_hierarchy(parent_entity, child_entity))
        entities.reverse()  # from the geom entity's direct child down to child_entity

        # Geom entities created by the current import.
        created_by_import = select(parent_table.c[geom_4326_field_name].label("geom")).where(
            parent_table.c.id_import == imprt.id_import
        )
        # Geom entities linked to the current import through their children
        # (or grand-children, etc.).
        linked_by_children = select(parent_table.c[geom_4326_field_name].label("geom"))
        children_of_import = []
        for entity in entities:
            entity_table = entity.get_destination_table()
            linked_by_children = linked_by_children.join(entity_table)
            children_of_import.append(entity_table.c.id_import == imprt.id_import)
        # Select.where() is generative: the result must be kept, otherwise the
        # filter is silently dropped and every geom of the table is considered.
        linked_by_children = linked_by_children.where(sa.or_(*children_of_import))

        member_selects = [linked_by_children, created_by_import] if children_of_import else [created_by_import]
        if destination_where_clause is not None:
            member_selects = [stmt.where(destination_where_clause) for stmt in member_selects]

    elif imprt.processed:
        # Processed import (checked and prepared, not yet loaded in the destination).
        # Existing geom entities (already in the destination) referenced by the import.
        existing_geoms = select(parent_table.c[geom_4326_field_name].label("geom")).where(
            transient_table.c.id_import == imprt.id_import,
            # Check that no entity between child_entity and the geom entity is invalid
            sa.and_(
                sa.true(),
                *[
                    transient_table.c[entity.validity_column] != False  # noqa: E712
                    for entity in get_entities_hierarchy(parent_entity, child_entity)
                ],
            ),
            # Check that parent entity already exists in DB
            transient_table.c[parent_entity.validity_column] == None,  # noqa: E711
            # join condition between the transient table and the destination table
            transient_table.c[parent_entity.unique_column.dest_field]
            == parent_table.c[parent_entity.unique_column.dest_field],
        )
        # New, valid geom entities still in the transient table.
        new_geoms = select(transient_table.c[geom_4326_field_name].label("geom")).where(
            transient_table.c.id_import == imprt.id_import,
            transient_table.c[parent_entity.validity_column] == True,  # noqa: E712
        )
        member_selects = [existing_geoms, new_geoms]
        if transient_where_clause is not None:
            member_selects = [stmt.where(transient_where_clause) for stmt in member_selects]
    else:
        return None

    # A Subquery has no .where(): every filter is applied above, on each SELECT.
    if len(member_selects) > 1:
        geoms = sa.union(*member_selects).subquery()
    else:
        geoms = member_selects[0].subquery()

    # Compute the bounding box using geom entities returned by the query statement
    statement = select(func.ST_AsGeojson(func.ST_Extent(geoms.c.geom)))
    (valid_bbox,) = db.session.execute(statement).fetchone()
    return json.loads(valid_bbox) if valid_bbox else None