#!/usr/bin/env python3

# == Rapid Develop Macros in LibreOffice ==

# ~ This file is part of ZAZ.

# ~ ZAZ is free software: you can redistribute it and/or modify
# ~ it under the terms of the GNU General Public License as published by
# ~ the Free Software Foundation, either version 3 of the License, or
# ~ (at your option) any later version.

# ~ ZAZ is distributed in the hope that it will be useful,
# ~ but WITHOUT ANY WARRANTY; without even the implied warranty of
# ~ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# ~ GNU General Public License for more details.

# ~ You should have received a copy of the GNU General Public License
# ~ along with ZAZ.  If not, see <https://www.gnu.org/licenses/>.


class _StructFieldMixin:
    """Route a field's value conversion through the UNO struct helpers.

    Shared by the three ``Base*Field`` classes below. It is listed first in
    their bases so these two methods take precedence over the ones inherited
    from the ORM field class.
    """

    def db_value(self, value):
        # Python value -> database value, delegated to _date_to_struct.
        return _date_to_struct(value)

    def python_value(self, value):
        # Database value -> Python value, delegated to _struct_to_date.
        return _struct_to_date(value)


class BaseDateField(_StructFieldMixin, DateField):
    """``DateField`` whose values pass through the UNO struct helpers."""


class BaseTimeField(_StructFieldMixin, TimeField):
    """``TimeField`` whose values pass through the UNO struct helpers."""


class BaseDateTimeField(_StructFieldMixin, DateTimeField):
    """``DateTimeField`` whose values pass through the UNO struct helpers."""


class FirebirdDatabase(Database):
    """ORM ``Database`` adapter that forwards all SQL to an ``LOBase``.

    The wrapped object (``self._db``) is the ``LOBase`` passed to the
    constructor (see ``LOBase.initialize``); it owns the real SDBC
    connection.
    """

    # Column-type overrides: emit BOOLEAN / TIMESTAMP for the ORM's
    # generic BOOL / DATETIME types.
    field_types = {'BOOL': 'BOOLEAN', 'DATETIME': 'TIMESTAMP'}

    def __init__(self, database, **kwargs):
        """Wrap *database* (an ``LOBase`` instance)."""
        super().__init__(database, **kwargs)
        self._db = database

    def _connect(self):
        # The LOBase object already holds an open connection; it is handed
        # to the ORM as the "connection" object.
        return self._db

    @staticmethod
    def _table_name(model):
        """Return the lower-cased table name the ORM uses for *model*.

        Prefers ``model._meta.table_name``, so a ``table_name`` set
        explicitly in the model's Meta is honoured. Falls back to the class
        name when that attribute is absent.
        """
        meta = getattr(model, '_meta', None)
        name = getattr(meta, 'table_name', None) or model.__name__
        return name.lower()

    def create_tables(self, models, **options):
        """Create tables only for those *models* that do not exist yet.

        Existence is checked here against ``LOBase.tables`` (lower-cased
        names), so the ORM's own ``safe`` option is always forced off.
        """
        options['safe'] = False
        existing = set(self._db.tables)
        missing = [m for m in models if self._table_name(m) not in existing]
        super().create_tables(missing, **options)

    def execute_sql(self, sql, params=None, commit=True):
        """Run *sql* via ``LOBase.execute`` and return its result.

        *commit* is accepted for interface compatibility and ignored;
        ``LOBase.execute`` stores the document itself after non-SELECT
        statements.
        """
        with __exception_wrapper__:
            cursor = self._db.execute(sql, params)
        return cursor

    def last_insert_id(self, cursor, query_type=None):
        # Generated keys are not fetched from the connection; the ORM always
        # receives 0.
        return 0

    def rows_affected(self, cursor):
        # Count recorded by LOBase.execute for the last parameterised update.
        return self._db.rows_affected

    @property
    def path(self):
        """Filesystem path of the wrapped database."""
        return self._db.path


class BaseRow:
    """Plain attribute bag: one attribute per result column (see BaseQuery)."""


class BaseQuery(object):
    """Fully materialised result of an SDBC query.

    Every row is read in the constructor and stored as a ``BaseRow`` whose
    attribute names are the column names.
    """

    # SDBC column type name -> getter called on the result set to read it.
    # A column of any other type raises KeyError in _to_python.
    PY_TYPES = {
        'SQL_LONG': 'getLong',
        'SQL_VARYING': 'getString',
        'SQL_FLOAT': 'getFloat',
        'SQL_BOOLEAN': 'getBoolean',
        'SQL_TYPE_DATE': 'getDate',
        'SQL_TYPE_TIME': 'getTime',
        'SQL_TIMESTAMP': 'getTimestamp',
    }
    # Column types whose raw values are converted with _struct_to_date.
    TYPES_DATE = ('SQL_TYPE_DATE', 'SQL_TYPE_TIME', 'SQL_TIMESTAMP')

    def __init__(self, query):
        """Read all rows of the SDBC result set *query*."""
        self._query = query
        self._meta = query.MetaData
        self._cols = self._meta.ColumnCount
        self._names = query.Columns.ElementNames
        self._index = 0
        self._data = self._get_data()

    def __getitem__(self, index):
        return self._data[index]

    def __iter__(self):
        # Hand out an independent iterator so nested or repeated loops over
        # the same result each see every row. The cursor used by __next__
        # is still reset for callers that drive iteration by hand.
        self._index = 0
        return iter(self._data)

    def __next__(self):
        try:
            row = self._data[self._index]
        except IndexError:
            raise StopIteration from None
        self._index += 1
        return row

    def _to_python(self, index):
        """Read column *index* (1-based) of the current row."""
        type_field = self._meta.getColumnTypeName(index)
        value = getattr(self._query, self.PY_TYPES[type_field])(index)
        if type_field in self.TYPES_DATE:
            value = _struct_to_date(value)
        return value

    def _get_row(self):
        """Copy the current result-set row into a new ``BaseRow``."""
        row = BaseRow()
        for i in range(1, self._cols + 1):  # column indexes are 1-based
            setattr(row, self._meta.getColumnName(i), self._to_python(i))
        return row

    def _get_data(self):
        """Advance through the whole result set, collecting every row."""
        data = []
        while self._query.next():
            data.append(self._get_row())
        return data

    @property
    def tuples(self):
        """All rows as a tuple of value tuples, in column order."""
        return tuple(tuple(r.__dict__.values()) for r in self._data)

    @property
    def dicts(self):
        """All rows as a tuple of dicts.

        The dicts are the rows' own ``__dict__`` objects, not copies.
        """
        return tuple(r.__dict__ for r in self._data)


class LOBase(object):
    """Wrapper around a LibreOffice Base data source.

    *obj* is a registered data-source name or the path of an existing
    database file. When ``args['path']`` is given instead, a new embedded
    Firebird database is created there and registered.
    """

    # Python type -> statement setter used to bind a parameter. Lookup is by
    # exact type (type(v)), so e.g. bool values use setBoolean, not setInt.
    DB_TYPES = {
        str: 'setString',
        int: 'setInt',
        float: 'setFloat',
        bool: 'setBoolean',
        Date: 'setDate',
        Time: 'setTime',
        DateTime: 'setTimestamp',
    }
    # Setters not mapped yet: setArray, setBinaryStream, setBlob, setByte,
    # setBytes, setCharacterStream, setClob, setNull, setObject,
    # setObjectNull, setObjectWithInfo, setPropertyValue, setRef.

    def __init__(self, obj, args=None):
        """Open (creating/registering if needed) the database and connect.

        :param obj: registered data-source name, or path of an existing
            database file. Ignored when ``args['path']`` is given.
        :param args: optional dict. ``args['path']`` creates a new embedded
            Firebird database at that path and registers it.
        """
        args = args or {}
        self._obj = obj
        self._type = BASE
        self._path = args.get('path', '')
        self._dbc = create_instance('com.sun.star.sdb.DatabaseContext')
        self._rows_affected = 0

        if self._path:
            # Create a brand-new embedded Firebird database file.
            self._name = Path(self._path).name
            path_url = _path_url(self._path)
            db = self._dbc.createInstance()
            db.URL = 'sdbc:embedded:firebird'
            db.DatabaseDocument.storeAsURL(path_url, ())
            self.register()
            self._obj = db
        else:
            self._name = self._obj
            if Path(self._name).exists():
                # A file path was given: use its file name as the
                # data-source name.
                self._path = self._name
                self._name = Path(self._path).name

            if self.is_registered:
                db = self._dbc.getByName(self.name)
                self._path = _path_system(
                    self._dbc.getDatabaseLocation(self.name))
            else:
                # NOTE(review): if obj was neither a registered name nor an
                # existing file, self._path is '' here.
                path_url = _path_url(self._path)
                self._dbc.registerDatabaseLocation(self.name, path_url)
                db = self._dbc.getByName(self.name)
            # Store the data source in both branches: save() reads
            # self.obj.DatabaseDocument, which a plain name string lacks.
            self._obj = db

        self._con = db.getConnection('', '')

    def __contains__(self, item):
        # Compared against lower-cased table names (see ``tables``).
        return item in self.tables

    @property
    def obj(self):
        """The underlying data-source object."""
        return self._obj

    @property
    def name(self):
        """Registered data-source name."""
        return self._name

    @property
    def path(self):
        """Filesystem path of the database file ('' if unknown)."""
        return self._path

    @property
    def is_registered(self):
        """True if a data source named ``self.name`` is registered."""
        return self._dbc.hasRegisteredDatabase(self.name)

    @property
    def tables(self):
        """Lower-cased names of the tables visible through the connection."""
        return [t.Name.lower() for t in self._con.getTables()]

    @property
    def rows_affected(self):
        """Row count returned by the last parameterised update."""
        return self._rows_affected

    def register(self):
        """Register ``self.path`` under ``self.name`` unless already registered."""
        if not self.is_registered:
            path_url = _path_url(self._path)
            self._dbc.registerDatabaseLocation(self.name, path_url)

    def revoke(self, name):
        """Unregister data source *name*; always returns True."""
        self._dbc.revokeDatabaseLocation(name)
        return True

    def save(self):
        """Store the database document, then refresh the table list."""
        self.obj.DatabaseDocument.store()
        self.refresh()

    def close(self):
        """Close the connection."""
        self._con.close()

    def refresh(self):
        """Refresh the connection's table container."""
        self._con.getTables().refresh()

    def initialize(self, database_proxy, tables):
        """Bind *database_proxy* to this database and create missing *tables*."""
        db = FirebirdDatabase(self)
        database_proxy.initialize(db)
        db.create_tables(tables)

    @staticmethod
    def _sql_literal(value):
        """Render *value* as an SQL literal for inlining into a statement."""
        if value is None:
            return 'NULL'
        # Double embedded single quotes so the value cannot close the literal.
        text = str(value).replace("'", "''")
        return f"'{text}'"

    def _validate_sql(self, sql, params):
        """Inline *params* into *sql* and translate ``LIMIT`` for Firebird.

        Placeholders ``?`` are replaced left to right by the quoted literal
        of the matching parameter; placeholders beyond ``len(params)`` are
        left untouched. Everything from ``' LIMIT '`` onwards is removed and
        expressed as ``SELECT FIRST n`` (plus ``SKIP m`` for a following
        ``OFFSET m``) on the first SELECT.

        NOTE(review): non-string values are inlined via str(); UNO date
        structs produced by the Base*Field classes may not render as valid
        Firebird literals — confirm.
        """
        limit = ' LIMIT '
        params = tuple(params or ())

        # Split once, so '?' characters inside inserted values are never
        # mistaken for placeholders.
        pieces = sql.split('?')
        parts = [pieces[0]]
        for i, piece in enumerate(pieces[1:]):
            parts.append(self._sql_literal(params[i]) if i < len(params) else '?')
            parts.append(piece)
        sql = ''.join(parts)

        if limit in sql:
            sql, _, tail = sql.partition(limit)
            tokens = [t.strip("'") for t in tail.split()]
            if tokens:
                clause = f'SELECT FIRST {tokens[0]}'
                if len(tokens) >= 3 and tokens[1].upper() == 'OFFSET':
                    clause += f' SKIP {tokens[2]}'
                # Only the first (outermost) SELECT receives FIRST/SKIP.
                sql = sql.replace('SELECT', clause, 1)
        return sql

    def cursor(self, sql, params):
        """Return a statement object ready to run *sql*.

        SELECT statements get their parameters inlined (``_validate_sql``).
        Other statements with parameters are prepared and bound through
        ``DB_TYPES``. Statements without parameters use a plain statement.
        """
        if sql.startswith('SELECT'):
            return self._con.prepareStatement(self._validate_sql(sql, params))

        if not params:
            return self._con.createStatement()

        cursor = self._con.prepareStatement(sql)
        for i, v in enumerate(params, 1):
            t = type(v)
            if t not in self.DB_TYPES:
                # Reported via error(); the DB_TYPES lookup below then
                # raises KeyError for the unsupported type.
                error('Type not support')
            debug((i, t, v, self.DB_TYPES[t]))
            getattr(cursor, self.DB_TYPES[t])(i, v)
        return cursor

    def execute(self, sql, params):
        """Execute *sql* and return the driver's result.

        SELECT returns the result set. A parameterised statement returns its
        update count, which is also stored in ``rows_affected``. Anything
        else returns the result of ``execute``. Non-SELECT statements save
        the document afterwards.
        """
        debug(sql, params)
        cursor = self.cursor(sql, params)

        if sql.startswith('SELECT'):
            result = cursor.executeQuery()
        elif params:
            result = cursor.executeUpdate()
            self._rows_affected = result
            self.save()
        else:
            result = cursor.execute(sql)
            self.save()

        return result

    def select(self, sql):
        """Run a SELECT and return a ``BaseQuery``; other SQL returns ()."""
        debug('SELECT', sql)
        if not sql.startswith('SELECT'):
            return ()
        cursor = self._con.prepareStatement(sql)
        query = cursor.executeQuery()
        return BaseQuery(query)

    def get_query(self, query):
        """Render an ORM *query* (via its ``sql()`` method) and select it."""
        sql, args = query.sql()
        sql = self._validate_sql(sql, args)
        return self.select(sql)