# Source code for src.utils_flask_sqla.generic

from itertools import chain
from warnings import warn

from dateutil import parser
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.types import Boolean, Date, DateTime, Integer, Numeric

from .errors import UtilsSqlaError


def testDataType(value, sqlType, paramName):
    """
    Check that ``value`` can be converted to the given SQL type.

    Used to validate the value of a filter before building a query.

    Params:
        - value: the raw filter value
        - sqlType: a SQLAlchemy type class (``Integer``, ``Numeric``,
          ``Date``, ``DateTime`` or one of their subclasses) or an instance
          of one of them (e.g. ``column.type``)
        - paramName: name inserted in the error message

    Returns:
        ``None`` when ``value`` is valid, or when ``sqlType`` is not one of
        the checked types; otherwise the error message (str).

    #TODO: antipattern: should raise an exception the caller can catch
    instead of returning the error message.
    """

    def _is_type(*types):
        # Accept a type class (Integer, BigInteger, Date, ...) as well as a
        # type instance (column.type).
        if isinstance(sqlType, type):
            return issubclass(sqlType, types)
        return isinstance(sqlType, types)

    # Any conversion failure means the value is invalid: this validator must
    # report an error message rather than raise.
    if _is_type(Integer):
        try:
            int(value)
        except Exception:
            return "{0} must be an integer".format(paramName)
    elif _is_type(Numeric):
        try:
            float(value)
        except Exception:
            return "{0} must be an float (decimal separator .)".format(paramName)
    elif _is_type(Date, DateTime):
        try:
            parser.parse(value)
        except Exception:
            return "{0} must be an date (yyyy-mm-dd)".format(paramName)
    return None
# Textual booleans accepted for Boolean columns (compared lower-cased and
# stripped); needed because bool("false") is True.
_BOOLEAN_STRINGS = {
    "true": True,
    "t": True,
    "yes": True,
    "y": True,
    "on": True,
    "1": True,
    "false": False,
    "f": False,
    "no": False,
    "n": False,
    "off": False,
    "0": False,
}


def _parse_boolean(value):
    """
    Convert a filter value to a bool.

    Strings are looked up (case-insensitively, surrounding spaces ignored) in
    ``_BOOLEAN_STRINGS``; any other value is converted with ``bool()``.

    Raises:
        ValueError: if ``value`` is a string that is not a known boolean.
    """
    if isinstance(value, str):
        try:
            return _BOOLEAN_STRINGS[value.strip().lower()]
        except KeyError:
            raise ValueError("{!r} is not a boolean".format(value)) from None
    return bool(value)


def test_type_and_generate_query(param_name, value, model, q):
    """
    Add to ``q`` an equality filter on the column ``param_name`` of ``model``.

    ``value`` is first converted to the Python type matching the column SQL
    type (int for Integer, float for Numeric, datetime for Date/DateTime,
    bool for Boolean). Columns of any other type are compared to ``value``
    as is.

    Params:
        - param_name (str): the name of the column
        - value (any): the value of the filter
        - model (SQLA model)
        - q (SQLA Query)

    Returns:
        the query with the filter added (``q.where(...)``).

    Raises:
        UtilsSqlaError: if the attribute does not exist on the model, is not
        a typed column, or ``value`` cannot be converted to the column type.
    """
    # check the attribute exists in the model
    try:
        col = getattr(model, param_name)
    except AttributeError as error:
        raise UtilsSqlaError(str(error)) from error

    sql_type = getattr(col, "type", None)
    if sql_type is None:
        raise UtilsSqlaError("{0} is not a column".format(param_name))

    # Only the conversions are guarded: an error raised while building the
    # query must not be reported as a wrong value type.
    if sql_type == Integer or isinstance(sql_type, Integer):
        try:
            typed_value = int(value)
        except (TypeError, ValueError, OverflowError) as error:
            raise UtilsSqlaError("{0} must be an integer".format(param_name)) from error
        return q.where(col == typed_value)

    if sql_type == Numeric or isinstance(sql_type, Numeric):
        try:
            typed_value = float(value)
        except (TypeError, ValueError, OverflowError) as error:
            raise UtilsSqlaError(
                "{0} must be an float (decimal separator .)".format(param_name)
            ) from error
        return q.where(col == typed_value)

    if sql_type == DateTime or isinstance(sql_type, (Date, DateTime)):
        try:
            typed_value = parser.parse(value)
        except (TypeError, ValueError, OverflowError) as error:
            raise UtilsSqlaError(
                "{0} must be an date (yyyy-mm-dd)".format(param_name)
            ) from error
        return q.where(col == typed_value)

    if sql_type == Boolean or isinstance(sql_type, Boolean):
        try:
            typed_value = _parse_boolean(value)
        except ValueError as error:
            raise UtilsSqlaError("{0} must be a boolean".format(param_name)) from error
        return q.where(col.is_(typed_value))

    # Any other column type (String, Text, UUID, ...): plain equality, instead
    # of implicitly returning None and losing the query.
    return q.where(col == value)
""" Liste des types de données sql qui nécessite une sérialisation particulière en @TODO MANQUE FLOAT """
[docs] SERIALIZERS = { "date": lambda x: str(x) if x else None, "datetime": lambda x: str(x) if x else None, "time": lambda x: str(x) if x else None, "timestamp": lambda x: str(x) if x else None, "uuid": lambda x: str(x) if x else None, "numeric": lambda x: str(x) if x else None, }
class GenericTable:
    """
    Build on the fly a mapping of a database table or view, by reflection.
    """

    def __init__(self, tableName, schemaName, engine):
        """
        params:
            - tableName
            - schemaName
            - engine: SQLAlchemy engine, for example ``DB.engine``
              if ``DB = SQLAlchemy()``

        Raises:
            KeyError: if ``schemaName.tableName`` is not found by reflection.
        """
        meta = MetaData(schema=schemaName)
        # views=True so that views can be mapped as well as tables
        meta.reflect(views=True, bind=engine)
        try:
            self.tableDef = meta.tables["{}.{}".format(schemaName, tableName)]
        except KeyError:
            raise KeyError(
                "table {}.{} doesn't exists".format(schemaName, tableName)
            ) from None

        # Column mapping used for serialization
        self.serialize_columns, self.db_cols = self.get_serialized_columns()

    def get_serialized_columns(self, serializers=SERIALIZERS):
        """
        Return a tuple ``(serialize_columns, db_cols)`` for the reflected table.

        - serialize_columns: list of ``(column_name, serializer)`` pairs for
          every column whose type class is not named ``Geometry``. The
          serializer is looked up in ``serializers`` by the lower-cased type
          class name and defaults to the identity function.
        - db_cols: list of all the Column objects of the table.
        """
        regular_serialize = []
        db_cols = []
        for name, db_col in self.tableDef.columns.items():
            type_name = db_col.type.__class__.__name__
            if type_name != "Geometry":
                regular_serialize.append(
                    (name, serializers.get(type_name.lower(), lambda x: x))
                )
            db_cols.append(db_col)
        return regular_serialize, db_cols

    def as_dict(self, data, columns=None, fields=None):
        """
        Serialize ``data`` into a dict ``{column_name: serialized_value}``.

        params:
            - data: object holding one attribute per column of the table
            - columns: deprecated, put the names in ``fields`` instead
            - fields: names of the columns to serialize; when neither
              ``fields`` nor ``columns`` is given, every non-Geometry column
              is serialized
        """
        if columns:
            warn(
                "'columns' argument is deprecated. Please add columns to serialize "
                "directly in 'fields' argument.",
                DeprecationWarning,
            )
        # Set for O(1) membership tests; order follows serialize_columns anyway.
        wanted = set(chain(fields or (), columns or ()))
        if wanted:
            fprops = [prop for prop in self.serialize_columns if prop[0] in wanted]
        else:
            fprops = self.serialize_columns
        return {name: serializer(getattr(data, name)) for name, serializer in fprops}
class GenericQuery:
    """
    Query a GenericTable with filters, ordering and pagination.

    params:
        - DB: SQLAlchemy instance (``DB`` if ``DB = SQLAlchemy()``)
        - tableName
        - schemaName
        - filters: mapping of query parameter name -> value, used both for
          filtering and for ordering (``orderby`` key)
        - limit: number of rows per page, ``-1`` for no limit
        - offset: page number (the SQL offset is ``offset * limit``)
    """

    def __init__(
        self,
        DB,
        tableName,
        schemaName,
        filters=None,
        limit=100,
        offset=0,
    ):
        self.DB = DB
        self.tableName = tableName
        self.schemaName = schemaName
        # A fresh mapping per instance (a mutable default would be shared by
        # every instance).
        self.filters = filters if filters is not None else {}
        self.limit = limit
        self.offset = offset
        self.view = GenericTable(tableName, schemaName, DB.engine)

    def build_query_filters(self, query, parameters):
        """
        Apply every filter of ``parameters`` (a mapping) to ``query``.
        """
        for name in parameters:
            query = self.build_query_filter(query, name, parameters.get(name))
        return query

    def build_query_filter(self, query, param_name, param_value):
        """
        Add to ``query`` the filter described by one query parameter.

        Supported parameter names:
            - ``<column>``: equality
            - ``ilike_<column>``: case-insensitive "contains" (TEXT columns only)
            - ``filter_d_up_<column>`` / ``filter_d_lo_<column>`` /
              ``filter_d_eq_<column>``: ``>=`` / ``<=`` / ``==`` on date,
              datetime or integer columns
            - ``filter_n_up_<column>`` / ``filter_n_lo_<column>``:
              ``>=`` / ``<=`` with a numeric value
        Other parameter names are ignored.

        Raises:
            UtilsSqlaError: if a ``filter_d_``/``filter_n_`` value has the
                wrong type.
            KeyError: if a prefixed parameter refers to an unknown column.
        """
        columns = self.view.tableDef.columns

        if param_name in columns.keys():
            query = query.where(columns[param_name] == param_value)

        if param_name.startswith("ilike_"):
            col = columns[param_name[6:]]
            # NOTE(review): VARCHAR columns are not matched by this check.
            if col.type.__class__.__name__ == "TEXT":
                query = query.where(col.ilike("%{}%".format(param_value)))

        if param_name.startswith("filter_d_"):
            # "filter_d_up_" / "filter_d_lo_" / "filter_d_eq_" are 12 chars long
            col = columns[param_name[12:]]
            # The value is accepted if it parses as a date OR as an integer.
            test_type = testDataType(param_value, DateTime, param_name) and testDataType(
                param_value, Integer, param_name
            )
            if test_type:
                raise UtilsSqlaError(message=test_type)
            # isinstance also matches the reflected dialect types
            # (DATE, TIMESTAMP, INTEGER, BIGINT, ...), which comparing class
            # names did not (a reflected "DATE" column was never filtered).
            if isinstance(col.type, (Date, DateTime, Integer)):
                if param_name.startswith("filter_d_up_"):
                    query = query.where(col >= param_value)
                if param_name.startswith("filter_d_lo_"):
                    query = query.where(col <= param_value)
                if param_name.startswith("filter_d_eq_"):
                    query = query.where(col == param_value)

        if param_name.startswith("filter_n_"):
            # "filter_n_up_" / "filter_n_lo_" are 12 chars long
            col = columns[param_name[12:]]
            test_type = testDataType(param_value, Numeric, param_name)
            if test_type:
                raise UtilsSqlaError(message=test_type)
            if param_name.startswith("filter_n_up_"):
                query = query.where(col >= param_value)
            if param_name.startswith("filter_n_lo_"):
                query = query.where(col <= param_value)

        return query

    def build_query_order(self, query, parameters):
        """
        Order ``query`` according to ``parameters["orderby"]``.

        Ordering is on a single column and takes the form
        ``column_name[:ASC|DESC]`` (direction is case-insensitive, ascending
        by default). Empty values and unknown columns are ignored.
        """
        order_by = parameters.get("orderby", "")
        if not order_by.replace(" ", ""):
            return query

        col_name, *sort = order_by.split(":")
        col_name = col_name.strip()
        columns = self.view.tableDef.columns
        if col_name not in columns.keys():
            return query

        # Index the collection: getattr could return a ColumnCollection
        # method for columns named e.g. "keys" or "items".
        order_col = columns[col_name]
        if (sort[0:1] or ["ASC"])[0].strip().lower() == "desc":
            order_col = order_col.desc()
        return query.order_by(order_col)

    def set_limit(self, q):
        """
        Apply pagination: ``limit`` rows starting at page ``offset``.
        """
        return q.limit(self.limit).offset(self.offset * self.limit)

    def raw_query(self, process_filter=True, with_limit=True):
        """
        Return the 'raw' query (without ``.all()``).

        - process_filter: apply the filters (and the ordering)
        - with_limit: apply the limit to the query (unless ``limit == -1``)
        """
        q = self.DB.session.query(self.view.tableDef)

        if not process_filter:
            return q

        if self.filters:
            unordered_q = self.build_query_filters(q, self.filters)
            q = self.build_query_order(unordered_q, self.filters)

        if self.limit != -1 and with_limit:
            q = self.set_limit(q)

        return q

    def query(self):
        """
        Run the query.

        Returns:
            a tuple ``(data, nb_result_without_filter, total_filtered)``:
            the rows of the requested page, the number of rows of the table,
            and the number of rows matching the filters.
        """
        nb_result_without_filter = self.DB.session.query(self.view.tableDef).count()

        q = self.raw_query(process_filter=True, with_limit=False)
        total_filtered = q.count() if self.filters else nb_result_without_filter

        # Same rule as raw_query: limit == -1 means "no limit" (a negative
        # LIMIT would be rejected by the database).
        if self.limit != -1:
            q = self.set_limit(q)
        data = q.all()

        return data, nb_result_without_filter, total_filtered

    def return_query(self):
        """
        Run the query (via ``self.query()``) and return the results in a
        standard format: a dict with ``total``, ``total_filtered``, ``page``,
        ``limit`` and ``items`` (rows serialized by ``GenericTable.as_dict``).
        """
        data, nb_result_without_filter, nb_results = self.query()

        results = [self.view.as_dict(d) for d in data]

        return {
            "total": nb_result_without_filter,
            "total_filtered": nb_results,
            "page": self.offset,
            "limit": self.limit,
            "items": results,
        }

    # Alias kept for callers using the ``as_dict`` name.
    as_dict = return_query
def serializeQuery(data, columnDef):
    """
    Turn every row of ``data`` into a dict.

    Each dict maps the ``"name"`` of every column description of
    ``columnDef`` to the row attribute of that name. ``None`` values are
    left out, and other values are returned unconverted.
    """

    def _row_to_dict(row):
        kept = {}
        for column in columnDef:
            key = column["name"]
            attr_value = getattr(row, key)
            if attr_value is None:
                continue
            kept[key] = attr_value
        return kept

    return list(map(_row_to_dict, data))
def serializeQueryOneResult(row, column_def):
    """
    Turn a single ``row`` into a dict.

    The dict maps the ``"name"`` of every column description of
    ``column_def`` to the row attribute of that name. ``None`` values are
    left out, and other values are returned unconverted.
    """
    result = {}
    for column in column_def:
        attr_name = column["name"]
        attr_value = getattr(row, attr_name)
        if attr_value is not None:
            result[attr_name] = attr_value
    return result
def serializeQueryTest(data, column_def):
    """
    Turn every row of ``data`` into a dict of JSON-friendly values.

    ``column_def`` is an iterable of column descriptions, each with a
    ``"name"`` key (row attribute to read) and a ``"type"`` key (a SQLAlchemy
    type instance). For each row and column:
        - ``None`` values are left out;
        - Date / DateTime / UUID values are converted with ``str``;
        - Numeric values are converted with ``float``;
        - values of a type whose class is named ``Geometry`` are left out;
        - any other value is kept as is.
    """
    rows = []
    for row in data:
        inter = {}
        for c in column_def:
            name = c["name"]
            value = getattr(row, name)
            if value is None:
                continue
            col_type = c["type"]
            if isinstance(col_type, (Date, DateTime, UUID)):
                inter[name] = str(value)
            elif isinstance(col_type, Numeric):
                inter[name] = float(value)
            elif col_type.__class__.__name__ != "Geometry":
                # Geometry is detected by class name so geoalchemy2 does not
                # need to be imported; every other column is kept as is
                # (previously, with this branch commented out, integers,
                # strings, booleans, ... were silently dropped).
                inter[name] = value
        rows.append(inter)
    return rows