zaz/source/libo.py

322 lines
8.5 KiB
Python

#!/usr/bin/env python3
# == Rapid Develop Macros in LibreOffice ==
# ~ This file is part of ZAZ.
# ~ ZAZ is free software: you can redistribute it and/or modify
# ~ it under the terms of the GNU General Public License as published by
# ~ the Free Software Foundation, either version 3 of the License, or
# ~ (at your option) any later version.
# ~ ZAZ is distributed in the hope that it will be useful,
# ~ but WITHOUT ANY WARRANTY; without even the implied warranty of
# ~ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# ~ GNU General Public License for more details.
# ~ You should have received a copy of the GNU General Public License
# ~ along with ZAZ. If not, see <https://www.gnu.org/licenses/>.
class BaseDateField(DateField):
    """Peewee ``DateField`` whose values travel as structs.

    Outgoing values are converted with ``_date_to_struct`` and incoming
    ones with ``_struct_to_date``.
    """

    def db_value(self, value):
        """Return *value* converted for the database side."""
        struct = _date_to_struct(value)
        return struct

    def python_value(self, value):
        """Return the database *value* converted back to Python."""
        converted = _struct_to_date(value)
        return converted
class BaseTimeField(TimeField):
    """Peewee ``TimeField`` whose values travel as structs.

    Outgoing values are converted with ``_date_to_struct`` and incoming
    ones with ``_struct_to_date``.
    """

    def db_value(self, value):
        """Return *value* converted for the database side."""
        struct = _date_to_struct(value)
        return struct

    def python_value(self, value):
        """Return the database *value* converted back to Python."""
        converted = _struct_to_date(value)
        return converted
class BaseDateTimeField(DateTimeField):
    """Peewee ``DateTimeField`` whose values travel as structs.

    Outgoing values are converted with ``_date_to_struct`` and incoming
    ones with ``_struct_to_date``.
    """

    def db_value(self, value):
        """Return *value* converted for the database side."""
        struct = _date_to_struct(value)
        return struct

    def python_value(self, value):
        """Return the database *value* converted back to Python."""
        converted = _struct_to_date(value)
        return converted
class FirebirdDatabase(Database):
    """Peewee database backend that delegates all SQL to a wrapped object.

    No DB-API connection is opened. The wrapped object (an ``LOBase`` when
    created by ``LOBase.initialize``) is returned as the "connection", and
    every statement is run through its ``execute`` method.
    """

    # ~ Column type overrides applied to peewee's generated DDL.
    field_types = {'BOOL': 'BOOLEAN', 'DATETIME': 'TIMESTAMP'}

    def __init__(self, database, **kwargs):
        super().__init__(database, **kwargs)
        self._db = database

    def _connect(self):
        # ~ The wrapped object doubles as the connection.
        return self._db

    def create_tables(self, models, **options):
        """Create the tables for *models*, skipping those that already exist.

        Existing tables are filtered out here by comparing each model's
        table name (lower-cased) with the names in ``self._db.tables``.
        ``safe`` is always forced to ``False``, so peewee emits plain
        ``CREATE TABLE`` statements.
        """
        # NOTE(review): safe=False presumably avoids an "IF NOT EXISTS"
        # clause the embedded Firebird engine does not accept -- confirm.
        options['safe'] = False
        existing = set(self._db.tables)
        # ~ Use the table name peewee will actually create. This honours
        # ~ Meta.table_name and snake_case naming, unlike the lower-cased
        # ~ class name, which made existing tables look missing.
        pending = [
            model for model in models
            if model._meta.table_name.lower() not in existing
        ]
        super().create_tables(pending, **options)

    def execute_sql(self, sql, params=None, commit=True):
        """Run *sql* with *params* through the wrapped object.

        Driver errors are translated by peewee's ``__exception_wrapper__``.
        ``commit`` is accepted for signature compatibility and is not used.
        Returns whatever the wrapped object's ``execute`` returns.
        """
        with __exception_wrapper__:
            cursor = self._db.execute(sql, params)
        return cursor

    def last_insert_id(self, cursor, query_type=None):
        # ~ Generated keys are not retrieved from the backend; always 0.
        return 0

    def rows_affected(self, cursor):
        # ~ The count is tracked by the wrapped object, not the cursor.
        return self._db.rows_affected

    @property
    def path(self):
        """Filesystem path reported by the wrapped object."""
        return self._db.path
class BaseRow:
    """Empty attribute container for one result row.

    ``BaseQuery`` sets one attribute per result column on each instance.
    """
class BaseQuery(object):
    """Fully materialized result of a query run through LibreOffice Base.

    All rows are read eagerly in ``__init__``. Each row becomes a ``BaseRow``
    whose attributes are named after the columns reported by the result
    set's metadata. Rows are available by index, by iteration, and through
    the ``tuples`` and ``dicts`` views.
    """

    # ~ Column type name (from getColumnTypeName) -> result-set getter.
    PY_TYPES = {
        'SQL_LONG': 'getLong',
        'SQL_VARYING': 'getString',
        'SQL_FLOAT': 'getFloat',
        'SQL_BOOLEAN': 'getBoolean',
        'SQL_TYPE_DATE': 'getDate',
        'SQL_TYPE_TIME': 'getTime',
        'SQL_TIMESTAMP': 'getTimestamp',
    }
    # ~ Column types whose raw value is converted with _struct_to_date.
    TYPES_DATE = ('SQL_TYPE_DATE', 'SQL_TYPE_TIME', 'SQL_TIMESTAMP')

    def __init__(self, query):
        self._query = query
        self._meta = query.MetaData
        self._cols = self._meta.ColumnCount
        self._names = query.Columns.ElementNames
        self._data = self._get_data()

    def __getitem__(self, index):
        return self._data[index]

    def __iter__(self):
        # ~ The query object is its own iterator; iter() rewinds it.
        self._index = 0
        return self

    def __next__(self):
        try:
            row = self._data[self._index]
        except IndexError:
            raise StopIteration from None
        self._index += 1
        return row

    def _to_python(self, index):
        """Read column *index* (1-based) of the current row as a Python value.

        Returns ``None`` for SQL NULL. Date/time columns are converted with
        ``_struct_to_date``. Raises KeyError for column types not listed in
        ``PY_TYPES``.
        """
        type_field = self._meta.getColumnTypeName(index)
        getter = getattr(self._query, self.PY_TYPES[type_field])
        value = getter(index)
        # ~ XRow getters return a placeholder (e.g. 0) for SQL NULL.
        # ~ wasNull() must be asked right after the read to tell them apart.
        if self._query.wasNull():
            return None
        if type_field in self.TYPES_DATE:
            value = _struct_to_date(value)
        return value

    def _get_row(self):
        """Build a ``BaseRow`` from the current result-set row."""
        row = BaseRow()
        for i in range(1, self._cols + 1):
            column_name = self._meta.getColumnName(i)
            setattr(row, column_name, self._to_python(i))
        return row

    def _get_data(self):
        """Advance through the whole result set and collect every row."""
        data = []
        while self._query.next():
            data.append(self._get_row())
        return data

    @property
    def tuples(self):
        """Rows as a tuple of value tuples, in column order."""
        return tuple(tuple(r.__dict__.values()) for r in self._data)

    @property
    def dicts(self):
        """Rows as a tuple of ``{column: value}`` dicts.

        Each dict is a copy, so mutating it does not alter the stored rows.
        """
        return tuple(dict(r.__dict__) for r in self._data)
class LOBase(object):
    """Wrapper around a LibreOffice Base data source and its connection.

    By default ``obj`` identifies an existing data source. It may be a name
    registered in ``com.sun.star.sdb.DatabaseContext``, or the path of an
    existing file, which is registered under its file name when no
    registration with that name exists yet.

    If ``args`` contains a non-empty ``'path'``, a new embedded Firebird
    document is instead created at that path and registered, and ``obj`` is
    not used. In every case a connection is opened on construction.
    """

    # ~ Python value type -> statement parameter setter used by cursor().
    DB_TYPES = {
        str: 'setString',
        int: 'setInt',
        float: 'setFloat',
        bool: 'setBoolean',
        Date: 'setDate',
        Time: 'setTime',
        DateTime: 'setTimestamp',
    }
    # ~ Parameter setters not mapped yet: setArray, setBinaryStream, setBlob,
    # ~ setByte, setBytes, setCharacterStream, setClob, setNull, setObject,
    # ~ setObjectNull, setObjectWithInfo, setPropertyValue, setRef.

    def __init__(self, obj, args=None):
        # ~ None instead of a shared mutable {} default; only 'path' is read.
        args = args or {}
        self._obj = obj
        self._type = BASE
        self._path = args.get('path', '')
        self._dbc = create_instance('com.sun.star.sdb.DatabaseContext')
        self._rows_affected = 0

        if self._path:
            # ~ Create a new embedded Firebird document at the given path.
            self._name = Path(self._path).name
            path_url = _path_url(self._path)
            db = self._dbc.createInstance()
            db.URL = 'sdbc:embedded:firebird'
            db.DatabaseDocument.storeAsURL(path_url, ())
            self.register()
            self._obj = db
        else:
            self._name = self._obj
            if Path(self._name).exists():
                # ~ An existing file: register it under its file name.
                self._path = self._name
                self._name = Path(self._path).name
            if self.is_registered:
                db = self._dbc.getByName(self.name)
                self._path = _path_system(
                    self._dbc.getDatabaseLocation(self.name))
            else:
                # NOTE(review): if obj is neither an existing file nor a
                # registered name, self._path is '' here -- confirm intent.
                path_url = _path_url(self._path)
                self._dbc.registerDatabaseLocation(self.name, path_url)
                db = self._dbc.getByName(self.name)
            # ~ Both branches must store the data source. Before this fix the
            # ~ unregistered branch left obj as the name string, so save()
            # ~ failed.
            self._obj = db

        self._con = db.getConnection('', '')

    def __contains__(self, item):
        return item in self.tables

    @property
    def obj(self):
        """The underlying data source object."""
        return self._obj

    @property
    def name(self):
        """Name the data source is registered under."""
        return self._name

    @property
    def path(self):
        """Filesystem path of the database document ('' if unknown)."""
        return self._path

    @property
    def is_registered(self):
        """Whether ``name`` is registered in the database context."""
        return self._dbc.hasRegisteredDatabase(self.name)

    @property
    def tables(self):
        """Lower-cased names of the tables visible through the connection."""
        return [t.Name.lower() for t in self._con.getTables()]

    @property
    def rows_affected(self):
        """Row count of the last parameterized non-SELECT ``execute``."""
        return self._rows_affected

    def register(self):
        """Register ``path`` under ``name`` unless already registered."""
        if not self.is_registered:
            path_url = _path_url(self._path)
            self._dbc.registerDatabaseLocation(self.name, path_url)

    def revoke(self, name):
        """Remove the registration *name* from the database context."""
        self._dbc.revokeDatabaseLocation(name)
        return True

    def save(self):
        """Store the database document and refresh the table list."""
        self.obj.DatabaseDocument.store()
        self.refresh()

    def close(self):
        """Close the connection opened in ``__init__``."""
        self._con.close()

    def refresh(self):
        """Refresh the connection's table container."""
        self._con.getTables().refresh()

    def initialize(self, database_proxy, tables):
        """Bind a peewee proxy to this database and create missing tables."""
        db = FirebirdDatabase(self)
        database_proxy.initialize(db)
        db.create_tables(tables)

    @staticmethod
    def _sql_literal(value):
        """Render *value* as a single-quoted SQL literal (``None`` -> NULL).

        Embedded single quotes are doubled, so the value cannot terminate
        the literal early and inject SQL.
        """
        if value is None:
            return 'NULL'
        text = str(value).replace("'", "''")
        return f"'{text}'"

    def _validate_sql(self, sql, params):
        """Inline *params* into *sql* and rewrite LIMIT/OFFSET.

        Each ``?`` is replaced, in order, by the literal of the matching
        parameter. Surplus ``?`` are left as they are and surplus parameters
        are ignored.

        When the text contains `` LIMIT ``, everything from there on is
        dropped. The first ``SELECT`` then becomes ``SELECT FIRST n``, plus
        ``SKIP m`` when ``OFFSET m`` follows the limit, which is Firebird's
        row-limiting syntax.
        """
        params = params or ()
        literals = [self._sql_literal(p) for p in params]
        # ~ Split once up front, so a '?' inside an already-inlined value is
        # ~ never mistaken for the next placeholder.
        pieces = sql.split('?')
        parts = [pieces[0]]
        for i, piece in enumerate(pieces[1:]):
            parts.append(literals[i] if i < len(literals) else '?')
            parts.append(piece)
        sql = ''.join(parts)

        head, found, tail = sql.partition(' LIMIT ')
        if not found:
            return sql
        tokens = tail.split()
        if not tokens:
            return head
        # ~ Take the limit from the SQL itself rather than the last
        # ~ parameter. With "LIMIT ? OFFSET ?" the last parameter is the
        # ~ offset, and a literal "LIMIT 10" has no parameter at all.
        clause = 'SELECT FIRST ' + tokens[0].strip("'")
        if len(tokens) > 2 and tokens[1].upper() == 'OFFSET':
            clause += ' SKIP ' + tokens[2].strip("'")
        # ~ Only the outer (first) SELECT; subqueries keep their own text.
        return head.replace('SELECT', clause, 1)

    def cursor(self, sql, params):
        """Return a statement ready to run *sql* with *params*.

        SELECTs get their parameters inlined by ``_validate_sql``. A
        statement without parameters uses a plain statement. Any other
        statement is prepared and its parameters bound through ``DB_TYPES``.
        Raises KeyError for a parameter type missing from ``DB_TYPES``.
        """
        if sql.startswith('SELECT'):
            sql = self._validate_sql(sql, params)
            return self._con.prepareStatement(sql)
        if not params:
            return self._con.createStatement()
        cursor = self._con.prepareStatement(sql)
        for i, v in enumerate(params, 1):
            t = type(v)
            if t not in self.DB_TYPES:
                error('Type not support')
                # ~ Previously the lookup below raised this implicitly.
                raise KeyError(t)
            setter = self.DB_TYPES[t]
            debug((i, t, v, setter))
            getattr(cursor, setter)(i, v)
        return cursor

    def execute(self, sql, params):
        """Run *sql* and return the driver result.

        SELECT returns the result set. A parameterized statement returns the
        update count, which is also stored in ``rows_affected``. Anything
        else returns the statement's ``execute`` result. Non-SELECT
        statements save the document afterwards.
        """
        debug(sql, params)
        cursor = self.cursor(sql, params)
        if sql.startswith('SELECT'):
            result = cursor.executeQuery()
        elif params:
            result = cursor.executeUpdate()
            self._rows_affected = result
            self.save()
        else:
            result = cursor.execute(sql)
            self.save()
        return result

    def select(self, sql):
        """Run a SELECT and return a ``BaseQuery``.

        Returns ``()`` if *sql* does not start with SELECT.
        """
        debug('SELECT', sql)
        if not sql.startswith('SELECT'):
            return ()
        cursor = self._con.prepareStatement(sql)
        query = cursor.executeQuery()
        return BaseQuery(query)

    def get_query(self, query):
        """Run a peewee *query* (anything with ``.sql()``) via ``select``."""
        sql, args = query.sql()
        sql = self._validate_sql(sql, args)
        return self.select(sql)