Source code for aodncore.pipeline.db

import csv
import psycopg2
from psycopg2 import sql, extras
import yaml

from .exceptions import InvalidSQLConnectionError, InvalidSQLTransactionError, MissingFileError
from ..util import find_file, is_nonstring_iterable
from ..table import get_field_type, get_tableschema_descriptor

__all__ = [
    'DatabaseInteractions'
]


class DatabaseInteractions(object):
    """Database connection object.

    This class should be instantiated via the 'with DatabaseInteractions() as...' method, so that
    __enter__ opens the connection and __exit__ commits or rolls back and then closes it.

    :param config: dict of keyword arguments passed to ``psycopg2.connect``
    :param schema_base_path: base path searched for ``.sql`` and ``.yml``/``.yaml`` schema files
    :param logger: logger object; ``info`` and ``sysinfo`` methods are called on it
    """

    # NOTE(review): the public methods below interpolate object names (table, column, schema)
    # directly into SQL text with str.format, so callers must only pass trusted values.

    # private methods
    def __init__(self, config, schema_base_path, logger):
        self._conn = None
        self._cur = None
        self.config = config
        self._logger = logger
        self.schema_base_path = schema_base_path
        self.status = 'initiated'

    def __enter__(self):
        """Open the database connection and create the instance cursor.

        :return: this instance, with status set to 'connected'
        """
        self._conn = self.__connect()
        try:
            self._cur = self._conn.cursor()
        except Exception:
            # Do not leave the connection open if the cursor cannot be created
            self._conn.close()
            raise
        self.status = 'connected'
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back if the 'with' body raised, otherwise commit; then always close cursor and connection."""
        try:
            if exc_type:
                self._logger.info("Rolling back changes")
                self._conn.rollback()
                self.status = 'rolled_back'
            else:
                self._logger.info("Committing changes")
                self._conn.commit()
                self.status = 'committed'
        finally:
            # Release the cursor and the connection even when commit/rollback itself fails
            try:
                self._cur.close()
            finally:
                self._conn.close()

    def __connect(self):
        """Connect to the PostgreSQL database server.

        :return: The database connection.
        :raises InvalidSQLConnectionError: if the connection attempt fails for any reason
        """
        params = self.config
        try:
            # use ssl if it is enabled on the server
            return psycopg2.connect(sslmode='prefer', **params)
        except Exception as error:
            raise InvalidSQLConnectionError(error) from error

    def __exec(self, statement):
        """Execute an SQL transaction using the instance cursor.

        :param statement: A string containing an SQL statement or multiple statements separated by semi-colons.
        :raises InvalidSQLTransactionError: if executing or logging the statement fails
        """
        try:
            self._cur.execute(sql.SQL(statement))
            self._logger.sysinfo(self._cur.query)
            self._logger.sysinfo(self._cur.statusmessage)
        except Exception as error:
            raise InvalidSQLTransactionError(error) from error

    def __query(self, statement, many=False):
        """Execute an SQL statement using a dedicated dict cursor and return the result.

        :param statement: A string containing an SQL statement or multiple statements separated by semi-colons.
        :param many: Boolean value indicating whether to return 1 or many records.
        :return: A Dict of one record (the first) if many is False, otherwise a List of Dicts containing all records
        :raises InvalidSQLTransactionError: if executing, logging or fetching fails
        """
        dict_cur = self._conn.cursor(cursor_factory=extras.DictCursor)
        try:
            dict_cur.execute(sql.SQL(statement))
            self._logger.sysinfo(dict_cur.query)
            self._logger.sysinfo(dict_cur.statusmessage)
            return dict_cur.fetchall() if many else dict_cur.fetchone()
        except Exception as error:
            raise InvalidSQLTransactionError(error) from error
        finally:
            # The cursor is created per call, so it must be closed per call
            dict_cur.close()

    def __exec_copy(self, statement, file):
        """Execute a COPY FROM statement using the instance cursor.

        :param statement: A string containing a COPY statement.
        :param file: A readable file-like object.
        :return: None
        :raises InvalidSQLTransactionError: if the copy or the logging of the query fails
        """
        try:
            self._cur.copy_expert(sql.SQL(statement), file)
            self._logger.sysinfo(self._cur.query)
        except Exception as error:
            raise InvalidSQLTransactionError(error) from error

    # public methods
    def compare_schemas(self):
        """Placeholder for possible future implementation of schema version checking

        :return: boolean - True if schemas match, else False (currently always True)
        """
        self._logger.info("Compare schema not yet implemented...")
        return True

    def truncate_table(self, step):
        """Truncate the specified table.

        :param step: A dict containing 'name' and 'type' (at least) keys
            - step.name is the name of the database object
            - step.type is the type of database object
            - the database transaction will only be performed if type = 'table'
        """
        if step['type'] == 'table':
            self.__exec("TRUNCATE TABLE {}".format(step['name']))

    def refresh_materialized_view(self, step):
        """Refresh the specified materialized view.

        :param step: A dict containing 'name' and 'type' (at least) keys
            - step.name is the name of the database object
            - step.type is the type of database object
            - the database transaction will only be performed if type = 'materialized view'
        """
        if step['type'] == 'materialized view':
            self.__exec("REFRESH MATERIALIZED VIEW {}".format(step['name']))

    def drop_object(self, step):
        """Drop the specified database object.

        The database transaction uses the IF EXISTS parameter, so will not error if the database object does
        not exist; and also the CASCADE parameter meaning that a previous call to this method may have already
        cascaded to the current database object.

        :param step: A dict containing 'name' and 'type' (at least) keys
            - step.name is the name of the database object
            - step.type is the type of database object
        """
        self._logger.info("Dropping {type} {name}".format(**step))
        stmt = "DROP {type} IF EXISTS {name} CASCADE".format(**step)
        self.__exec(stmt)

    def load_data_from_csv(self, step):
        """Function to read a csv file prior to loading into the specified table.

        Currently uses the utf-8 encoding to read the csv file, and reads the headings into the COPY FROM
        statement - the latter may not be necessary as it is assumed that the file has been validated in a
        previous handler step. Nothing is done when 'local_path' is missing or empty.

        :param step: A dict containing 'name' and 'local_path' (at least) keys
            - step.name is the name of the target table
            - step.local_path is the full path to the source file (csv)
        :raises MissingFileError: if the file at local_path does not exist
        """
        fn = step.get('local_path', '')
        if fn:
            try:
                with open(fn, encoding="utf-8") as f:
                    self._logger.info("Loading data from {}".format(fn))
                    stmt = "COPY {} FROM STDIN WITH HEADER CSV".format(step['name'])
                    self.__exec_copy(stmt, f)
            except FileNotFoundError as e:
                raise MissingFileError(e) from e

    def execute_sql_file(self, step):
        """Function to read an SQL file prior to executing against the database.

        Nothing is executed when no matching file is found under schema_base_path.

        :param step: A dict containing 'name' (at least) key
            - step.name is the name used as part of the match regular expression
        """
        fn = find_file(self.schema_base_path, r'{name}(\..*)?\.sql'.format(name=step['name']))
        if fn:
            self._logger.info("Executing sql from {}".format(fn))
            with open(fn) as stream:
                self.__exec(stream.read())

    def create_table_from_yaml_file(self, step):
        """Function to read an yaml file and use it to build a CREATE TABLE script for execution against the database.

        Nothing is executed when no matching file is found or when step.type is not 'table'.

        :param step: A dict containing 'name' and 'type' (at least) keys
            - step.name is the name used as part of the match regular expression
            - step.type is the type of database object. Type should always be table in this context
        """
        fn = find_file(self.schema_base_path, r'{name}(\..*)?\.(?:yml|yaml)'.format(name=step['name']))
        if fn and step['type'] == 'table':
            self._logger.info("Creating {type} {name} from {fn}".format(fn=fn, **step))
            with open(fn) as stream:
                schema = get_tableschema_descriptor(yaml.safe_load(stream), 'schema')

            # One "<name> <db type>" definition per schema field, without mutating the loaded schema
            columns = [
                '{} {}'.format(field['name'], get_field_type(field['type']))
                for field in schema['fields']
            ]

            pk = schema.get('primaryKey')
            if pk:
                # primaryKey may be a single column name or an iterable of column names
                pk = pk if is_nonstring_iterable(pk) else [pk]
                columns.append("PRIMARY KEY ({})".format(','.join(pk)))

            self.__exec('CREATE TABLE {} ({})'.format(step['name'], ','.join(columns)))

    def get_spatial_extent(self, db_schema, table, column, resolution):
        """Function to retrieve spatial data from the database.

        :param db_schema: string containing name of schema
        :param table: string containing name of table
        :param column: string containing name of column
        :param resolution: int as resolution of polygons
        :return: the first record returned by the BoundingPolygonAsGml3 query
        """
        self._logger.info("Retrieving spatial extent")
        query = "SELECT BoundingPolygonAsGml3('{schema}','{table}','{column}',{resolution})".format(
            schema=db_schema, table=table, column=column, resolution=resolution
        )
        return self.__query(query)

    def get_temporal_extent(self, table, column):
        """Function to retrieve temporal data from the database.

        :param table: string containing name of table
        :param column: string containing name of column
        :return: the first record of the query, with 'min_value' and 'max_value' columns holding the
            UTC minimum and maximum of the column formatted as YYYY-MM-DDTHH:MM:SS (24-hour clock)
        """
        self._logger.info("Retrieving temporal extent")
        # In PostgreSQL TO_CHAR templates 'MM' is the month and 'HH' the 12-hour clock, so minutes must be
        # written as 'MI' and hours as 'HH24'; the literal 'T' separator is quoted to keep it verbatim.
        query = """
            SELECT
                TO_CHAR(timezone('UTC'::text, MIN("{column}")), 'YYYY-MM-DD"T"HH24:MI:SS') as min_value,
                TO_CHAR(timezone('UTC'::text, MAX("{column}")), 'YYYY-MM-DD"T"HH24:MI:SS') as max_value
            FROM {table}
        """.format(
            table=table, column=column
        )
        return self.__query(query)

    def get_vertical_extent(self, table, column):
        """Function to retrieve vertical data from the database.

        :param table: string containing name of table
        :param column: string containing name of column
        :return: the first record of the query, with 'min_value' and 'max_value' columns
        """
        self._logger.info("Retrieving vertical extent")
        query = """
            SELECT
                MIN("{column}") as "min_value",
                MAX("{column}") as "max_value"
            FROM {table}
        """.format(
            table=table, column=column
        )
        return self.__query(query)