" % (
- self.months, self.days, self.microseconds)
-
- def __eq__(self, other):
- return other is not None and isinstance(other, Interval) and \
- self.months == other.months and self.days == other.days and \
- self.microseconds == other.microseconds
-
- def __neq__(self, other):
- return not self.__eq__(other)
-
-
class PGType():
    """Base wrapper used to force a value to be sent as a specific
    PostgreSQL type.  The default encoding is ``str(value)`` in the
    connection's client encoding.
    """

    def __init__(self, value):
        self.value = value

    def encode(self, encoding):
        """Render the wrapped value as bytes for transmission."""
        text = str(self.value)
        return text.encode(encoding)
-
-
class PGEnum(PGType):
    """Wrapper marking a value to be sent as a PostgreSQL enum."""

    def __init__(self, value):
        # Accept either a plain string, or an enum.Enum-style object that
        # carries its string in a ``.value`` attribute.
        if isinstance(value, str):
            self.value = value
        else:
            self.value = value.value
-
-
class PGJson(PGType):
    """Wrapper marking a value to be sent as PostgreSQL ``json``."""

    def encode(self, encoding):
        """JSON-serialize the wrapped value, then encode the text."""
        serialized = dumps(self.value)
        return serialized.encode(encoding)
-
-
class PGJsonb(PGType):
    """Wrapper marking a value to be sent as PostgreSQL ``jsonb``."""

    def encode(self, encoding):
        """JSON-serialize the wrapped value, then encode the text."""
        serialized = dumps(self.value)
        return serialized.encode(encoding)
-
-
class PGTsvector(PGType):
    """Wrapper marking a value to be sent as PostgreSQL ``tsvector``;
    inherits PGType's str()-based encoding unchanged."""
    pass
-
-
class PGVarchar(str):
    """A str subclass marking a value to be sent as ``varchar`` rather
    than the unknown type used for plain str."""
    pass
-
-
class PGText(str):
    """A str subclass marking a value to be sent as ``text`` rather
    than the unknown type used for plain str."""
    pass
-
-
def pack_funcs(fmt):
    """Return ``(pack, unpack_from)`` callables for the given struct
    format, pre-compiled in network (big-endian) byte order."""
    compiled = Struct('!' + fmt)
    return compiled.pack, compiled.unpack_from
-
-
# Pre-compiled big-endian pack/unpack helpers, named after the struct
# format codes they handle (i = int32, h = int16, q = int64, d = float64,
# f = float32, c = single byte, b = signed byte).
i_pack, i_unpack = pack_funcs('i')
h_pack, h_unpack = pack_funcs('h')
q_pack, q_unpack = pack_funcs('q')
d_pack, d_unpack = pack_funcs('d')
f_pack, f_unpack = pack_funcs('f')
iii_pack, iii_unpack = pack_funcs('iii')
ii_pack, ii_unpack = pack_funcs('ii')
qii_pack, qii_unpack = pack_funcs('qii')
dii_pack, dii_unpack = pack_funcs('dii')
ihihih_pack, ihihih_unpack = pack_funcs('ihihih')
ci_pack, ci_unpack = pack_funcs('ci')
bh_pack, bh_unpack = pack_funcs('bh')
cccc_pack, cccc_unpack = pack_funcs('cccc')
-
-
# Two's-complement limits used to pick the smallest PostgreSQL integer
# type for a Python int.
# NOTE(review): the upper bounds are 2**n, not 2**n - 1, suggesting the
# consumers compare with strict inequalities (min < v < max) — confirm
# against the inspect_int logic before tightening.
min_int2, max_int2 = -2 ** 15, 2 ** 15
min_int4, max_int4 = -2 ** 31, 2 ** 31
min_int8, max_int8 = -2 ** 63, 2 ** 63
-
-
class Warning(Exception):
    """Generic exception raised for important database warnings like data
    truncations. This exception is not currently used by pg8000.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class Error(Exception):
    """Generic exception that is the base exception of all other error
    exceptions.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class InterfaceError(Error):
    """Generic exception raised for errors that are related to the database
    interface rather than the database itself. For example, if the interface
    attempts to use an SSL connection but the server refuses, an InterfaceError
    will be raised.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class DatabaseError(Error):
    """Generic exception raised for errors that are related to the database.
    This exception is currently never raised by pg8000.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class DataError(DatabaseError):
    """Generic exception raised for errors that are due to problems with the
    processed data. This exception is not currently raised by pg8000.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class OperationalError(DatabaseError):
    """
    Generic exception raised for errors that are related to the database's
    operation and not necessarily under the control of the programmer. This
    exception is currently never raised by pg8000.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class IntegrityError(DatabaseError):
    """
    Generic exception raised when the relational integrity of the database is
    affected. This exception is not currently raised by pg8000.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class InternalError(DatabaseError):
    """Generic exception raised when the database encounters an internal error.
    This is currently only raised when unexpected state occurs in the pg8000
    interface itself, and is typically the result of a interface bug.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class ProgrammingError(DatabaseError):
    """Generic exception raised for programming errors. For example, this
    exception is raised if more parameter fields are in a query string than
    there are available parameters.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class NotSupportedError(DatabaseError):
    """Generic exception raised in case a method or database API was used which
    is not supported by the database.

    This exception is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    pass
-
-
class ArrayContentNotSupportedError(NotSupportedError):
    """
    Raised when attempting to transmit an array where the base type is not
    supported for binary data transfer by the interface.
    """
    pass
-
-
class ArrayContentNotHomogenousError(ProgrammingError):
    """
    Raised when attempting to transmit an array that doesn't contain only a
    single type of object.
    """
    pass
-
-
class ArrayDimensionsNotConsistentError(ProgrammingError):
    """
    Raised when attempting to transmit an array that has inconsistent
    multi-dimension sizes.
    """
    pass
-
-
def Date(year, month, day):
    """Construct an object holding a date value.

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.date`
    """
    return date(year=year, month=month, day=day)
-
-
def Time(hour, minute, second):
    """Construct an object holding a time value.

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.time`
    """
    return time(hour=hour, minute=minute, second=second)
-
-
def Timestamp(year, month, day, hour, minute, second):
    """Construct an object holding a timestamp value.

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.datetime`
    """
    return Datetime(
        year=year, month=month, day=day,
        hour=hour, minute=minute, second=second)
-
-
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value
    (number of seconds since the epoch, interpreted in local time).

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.date`
    """
    year, month, day = localtime(ticks)[:3]
    return date(year, month, day)
-
-
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value
    (number of seconds since the epoch, interpreted in local time).

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.time`
    """
    hour, minute, second = localtime(ticks)[3:6]
    return time(hour, minute, second)
-
-
def TimestampFromTicks(ticks):
    """Construct an object holding a timestamp value from the given ticks
    value (number of seconds since the epoch, interpreted in local time).

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.datetime`
    """
    return Datetime(*localtime(ticks)[:6])
-
-
def Binary(value):
    """Construct an object holding binary data.

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_.
    """
    # bytes objects are transmitted as-is, so no wrapper type is needed.
    return value
-
-
# PostgreSQL wire-protocol format codes: 0 = text, 1 = binary.
FC_TEXT = 0
FC_BINARY = 1
-
-
def convert_paramstyle(style, query):
    """Rewrite *query* from the given DBAPI paramstyle ('qmark', 'numeric',
    'named', 'format' or 'pyformat') into PostgreSQL's native ``$1, $2, ...``
    placeholder form.

    Returns ``(converted_query, make_args)`` where ``make_args`` converts
    the caller's parameter collection (a sequence for positional styles, a
    mapping for named styles) into the positional sequence matching the
    generated ``$n`` placeholders.  Quoted strings, quoted identifiers,
    E'' escape strings and ``--`` comments are passed through untouched.
    """
    # I don't see any way to avoid scanning the query string char by char,
    # so we might as well take that careful approach and create a
    # state-based scanner.  We'll use int variables for the state.
    OUTSIDE = 0    # outside quoted string
    INSIDE_SQ = 1  # inside single-quote string '...'
    INSIDE_QI = 2  # inside quoted identifier "..."
    INSIDE_ES = 3  # inside escaped single-quote string, E'...'
    INSIDE_PN = 4  # inside parameter name eg. :name
    INSIDE_CO = 5  # inside inline comment eg. --

    in_quote_escape = False
    in_param_escape = False
    placeholders = []
    output_query = []
    param_idx = map(lambda x: "$" + str(x), count(1))
    state = OUTSIDE
    prev_c = None
    for i, c in enumerate(query):
        # One character of lookahead drives the escape/terminator logic.
        if i + 1 < len(query):
            next_c = query[i + 1]
        else:
            next_c = None

        if state == OUTSIDE:
            if c == "'":
                output_query.append(c)
                if prev_c == 'E':
                    state = INSIDE_ES
                else:
                    state = INSIDE_SQ
            elif c == '"':
                output_query.append(c)
                state = INSIDE_QI
            elif c == '-':
                output_query.append(c)
                if prev_c == '-':
                    state = INSIDE_CO
            elif style == "qmark" and c == "?":
                output_query.append(next(param_idx))
            elif style == "numeric" and c == ":" and next_c not in ':=' \
                    and prev_c != ':':
                # Treat : as beginning of parameter name if and only
                # if it's the only : around
                # Needed to properly process type conversions
                # i.e. sum(x)::float
                output_query.append("$")
            elif style == "named" and c == ":" and next_c not in ':=' \
                    and prev_c != ':':
                # Same logic for : as in numeric parameters
                state = INSIDE_PN
                placeholders.append('')
            elif style == "pyformat" and c == '%' and next_c == "(":
                state = INSIDE_PN
                placeholders.append('')
            elif style in ("format", "pyformat") and c == "%":
                # A positional %s under pyformat demotes the style to
                # plain 'format' for the rest of the query.
                style = "format"
                if in_param_escape:
                    in_param_escape = False
                    output_query.append(c)
                else:
                    if next_c == "%":
                        in_param_escape = True
                    elif next_c == "s":
                        state = INSIDE_PN
                        output_query.append(next(param_idx))
                    else:
                        raise InterfaceError(
                            "Only %s and %% are supported in the query.")
            else:
                output_query.append(c)

        elif state == INSIDE_SQ:
            if c == "'":
                if in_quote_escape:
                    in_quote_escape = False
                else:
                    if next_c == "'":
                        in_quote_escape = True
                    else:
                        state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_QI:
            if c == '"':
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_ES:
            if c == "'" and prev_c != "\\":
                # check for escaped single-quote
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_PN:
            if style == 'named':
                placeholders[-1] += c
                if next_c is None or (not next_c.isalnum() and next_c != '_'):
                    state = OUTSIDE
                    # Reuse an existing $n when the same name repeats.
                    try:
                        pidx = placeholders.index(placeholders[-1], 0, -1)
                        output_query.append("$" + str(pidx + 1))
                        del placeholders[-1]
                    except ValueError:
                        output_query.append("$" + str(len(placeholders)))
            elif style == 'pyformat':
                if prev_c == ')' and c == "s":
                    state = OUTSIDE
                    try:
                        pidx = placeholders.index(placeholders[-1], 0, -1)
                        output_query.append("$" + str(pidx + 1))
                        del placeholders[-1]
                    except ValueError:
                        output_query.append("$" + str(len(placeholders)))
                elif c in "()":
                    pass
                else:
                    placeholders[-1] += c
            elif style == 'format':
                state = OUTSIDE

        elif state == INSIDE_CO:
            output_query.append(c)
            if c == '\n':
                state = OUTSIDE

        prev_c = c

    if style in ('numeric', 'qmark', 'format'):
        def make_args(vals):
            return vals
    else:
        def make_args(vals):
            return tuple(vals[p] for p in placeholders)

    return ''.join(output_query), make_args
-
-
# PostgreSQL timestamps count from 2000-01-01, not the Unix epoch.
EPOCH = Datetime(2000, 1, 1)
EPOCH_TZ = EPOCH.replace(tzinfo=Timezone.utc)
EPOCH_SECONDS = timegm(EPOCH.timetuple())
# Sentinel microsecond values the server uses for 'infinity'/'-infinity'.
INFINITY_MICROSECONDS = 2 ** 63 - 1
MINUS_INFINITY_MICROSECONDS = -1 * INFINITY_MICROSECONDS - 1
-
-
# data is 64-bit integer representing microseconds since 2000-01-01
def timestamp_recv_integer(data, offset, length):
    """Deserialize an int64 microsecond count into a naive datetime.

    Returns the strings 'infinity' / '-infinity' for the server's two
    sentinel values, or the raw microsecond count when the value is
    outside datetime's representable range.
    """
    micros = q_unpack(data, offset)[0]
    # The sentinels (±2**63 boundary values) always overflow datetime,
    # so checking them up front is equivalent to the overflow path.
    if micros == INFINITY_MICROSECONDS:
        return 'infinity'
    if micros == MINUS_INFINITY_MICROSECONDS:
        return '-infinity'
    try:
        return EPOCH + Timedelta(microseconds=micros)
    except OverflowError:
        return micros
-
-
# data is double-precision float representing seconds since 2000-01-01
def timestamp_recv_float(data, offset, length):
    # NOTE(review): Datetime.utcfromtimestamp is deprecated since Python
    # 3.12 — migrate to fromtimestamp(..., tz=...) when dropping old
    # interpreters.
    return Datetime.utcfromtimestamp(EPOCH_SECONDS + d_unpack(data, offset)[0])
-
-
# data is 64-bit integer representing microseconds since 2000-01-01
def timestamp_send_integer(v):
    """Serialize a naive datetime as int64 microseconds since 2000-01-01,
    treating the value as UTC."""
    whole_seconds = timegm(v.timetuple()) - EPOCH_SECONDS
    return q_pack(int(whole_seconds * 1e6) + v.microsecond)
-
-
# data is double-precision float representing seconds since 2000-01-01
def timestamp_send_float(v):
    """Serialize a naive datetime as float8 seconds since 2000-01-01,
    treating the value as UTC."""
    seconds_since_unix_epoch = timegm(v.timetuple())
    # Operand order kept identical to avoid any float-rounding drift.
    return d_pack(seconds_since_unix_epoch + v.microsecond / 1e6 - EPOCH_SECONDS)
-
-
def timestamptz_send_integer(v):
    """Serialize an aware datetime in int64 format.

    Timestamps must go over the wire as UTC, so convert first and drop
    the tzinfo before delegating to the naive serializer.
    """
    as_utc = v.astimezone(Timezone.utc).replace(tzinfo=None)
    return timestamp_send_integer(as_utc)
-
-
def timestamptz_send_float(v):
    """Serialize an aware datetime in float8 format.

    Timestamps must go over the wire as UTC, so convert first and drop
    the tzinfo before delegating to the naive serializer.
    """
    as_utc = v.astimezone(Timezone.utc).replace(tzinfo=None)
    return timestamp_send_float(as_utc)
-
-
def timestamptz_recv_integer(data, offset, length):
    """Deserialize an int64 "timestamp with time zone" value.

    Returns a timezone-aware datetime whose tzinfo is always UTC (which
    still permits conversion to local time), the strings
    'infinity' / '-infinity' for the server's sentinel values, or the
    raw microsecond count when outside datetime's range.
    """
    micros = q_unpack(data, offset)[0]
    # The sentinels always overflow datetime, so checking them up front
    # is equivalent to the overflow path.
    if micros == INFINITY_MICROSECONDS:
        return 'infinity'
    if micros == MINUS_INFINITY_MICROSECONDS:
        return '-infinity'
    try:
        return EPOCH_TZ + Timedelta(microseconds=micros)
    except OverflowError:
        return micros
-
-
def timestamptz_recv_float(data, offset, length):
    """Deserialize a float8 "timestamp with time zone" value; same as
    timestamp_recv_float but with UTC tzinfo attached."""
    naive = timestamp_recv_float(data, offset, length)
    return naive.replace(tzinfo=Timezone.utc)
-
-
def interval_send_integer(v):
    """Serialize a Timedelta or Interval in the integer wire format:
    (int64 microseconds, int32 days, int32 months)."""
    micros = v.microseconds
    # timedelta has .seconds but no .months; Interval is the reverse —
    # default the missing attribute to zero in each case.
    micros += int(getattr(v, 'seconds', 0) * 1e6)
    months = getattr(v, 'months', 0)
    return qii_pack(micros, v.days, months)
-
-
def interval_send_float(v):
    """Serialize a Timedelta or Interval in the float wire format:
    (float8 seconds, int32 days, int32 months)."""
    # timedelta has .seconds but no .months; Interval is the reverse —
    # default the missing attribute to zero in each case.
    seconds = v.microseconds / 1000.0 / 1000.0 + getattr(v, 'seconds', 0)
    months = getattr(v, 'months', 0)
    return dii_pack(seconds, v.days, months)
-
-
def interval_recv_integer(data, offset, length):
    """Deserialize (int64 microseconds, int32 days, int32 months).

    Returns a datetime.timedelta when months == 0 (timedelta cannot
    represent months), otherwise pg8000's Interval type.
    """
    microseconds, days, months = qii_unpack(data, offset)
    if months != 0:
        return Interval(microseconds, days, months)
    whole_seconds, leftover_micros = divmod(microseconds, 1e6)
    return Timedelta(days, whole_seconds, leftover_micros)
-
-
def interval_recv_float(data, offset, length):
    """Deserialize an interval sent in the float wire format:
    (float8 seconds, int32 days, int32 months).

    Returns a datetime.timedelta when months == 0 (timedelta cannot
    represent months), otherwise pg8000's Interval type.
    """
    seconds, days, months = Struct('!dii').unpack_from(data, offset)
    if months == 0:
        # The wire value is seconds (a double), matching what
        # interval_send_float packs.  The previous code did
        # divmod(seconds, 1e6), which treated the seconds value as
        # microseconds and produced intervals a factor of 1e6 too small.
        return Timedelta(days=days, seconds=seconds)
    else:
        return Interval(int(seconds * 1000 * 1000), days, months)
-
-
def int8_recv(data, offset, length):
    """Deserialize a big-endian signed 64-bit integer."""
    (value,) = q_unpack(data, offset)
    return value
-
-
def int2_recv(data, offset, length):
    """Deserialize a big-endian signed 16-bit integer."""
    (value,) = h_unpack(data, offset)
    return value
-
-
def int4_recv(data, offset, length):
    """Deserialize a big-endian signed 32-bit integer."""
    (value,) = i_unpack(data, offset)
    return value
-
-
def float4_recv(data, offset, length):
    """Deserialize a big-endian single-precision float."""
    (value,) = f_unpack(data, offset)
    return value
-
-
def float8_recv(data, offset, length):
    """Deserialize a big-endian double-precision float."""
    (value,) = d_unpack(data, offset)
    return value
-
-
def bytea_send(v):
    """bytea values are transmitted as raw bytes, unchanged."""
    return v
-
-
# bytea
def bytea_recv(data, offset, length):
    """Slice the raw bytea payload out of the incoming buffer."""
    end = offset + length
    return data[offset:end]
-
-
def uuid_send(v):
    """Serialize a uuid.UUID as its 16 raw big-endian bytes."""
    return v.bytes
-
-
def uuid_recv(data, offset, length):
    """Deserialize 16 raw bytes into a uuid.UUID."""
    raw = data[offset:offset + length]
    return UUID(bytes=raw)
-
-
def bool_send(v):
    """Binary-format boolean: a single byte, 1 for true, 0 for false."""
    if v:
        return b"\x01"
    return b"\x00"
-
-
# A length word of -1 denotes SQL NULL in the wire protocol.
NULL = i_pack(-1)

NULL_BYTE = b'\x00'
-
-
def null_send(v):
    """Send SQL NULL: the value argument is ignored; only the -1 length
    word goes on the wire."""
    return NULL
-
-
def int_in(data, offset, length):
    """Parse a text-format integer from the given buffer slice."""
    digits = data[offset: offset + length]
    return int(digits)
-
-
class Cursor():
    """A cursor object is returned by the :meth:`~Connection.cursor` method of
    a connection. It has the following attributes and methods:

    .. attribute:: arraysize

        This read/write attribute specifies the number of rows to fetch at a
        time with :meth:`fetchmany`. It defaults to 1.

    .. attribute:: connection

        This read-only attribute contains a reference to the connection object
        (an instance of :class:`Connection`) on which the cursor was
        created.

        This attribute is part of a DBAPI 2.0 extension. Accessing this
        attribute will generate the following warning: ``DB-API extension
        cursor.connection used``.

    .. attribute:: rowcount

        This read-only attribute contains the number of rows that the last
        ``execute()`` or ``executemany()`` method produced (for query
        statements like ``SELECT``) or affected (for modification statements
        like ``UPDATE``).

        The value is -1 if:

        - No ``execute()`` or ``executemany()`` method has been performed yet
          on the cursor.
        - There was no rowcount associated with the last ``execute()``.
        - At least one of the statements executed as part of an
          ``executemany()`` had no row count associated with it.
        - Using a ``SELECT`` query statement on PostgreSQL server older than
          version 9.
        - Using a ``COPY`` query statement on PostgreSQL server version 8.1 or
          older.

        This attribute is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

    .. attribute:: description

        This read-only attribute is a sequence of 7-item sequences. Each value
        contains information describing one result column. The 7 items
        returned for each column are (name, type_code, display_size,
        internal_size, precision, scale, null_ok). Only the first two values
        are provided by the current implementation.

        This attribute is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.
    """

    def __init__(self, connection):
        self._c = connection  # owning Connection; set to None by close()
        self.arraysize = 1  # default batch size for fetchmany()
        self.ps = None  # prepared-statement dict for the current query
        self._row_count = -1
        self._cached_rows = deque()  # rows received but not yet fetched

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @property
    def connection(self):
        # DBAPI 2.0 extension attribute; warn on every access.
        warn("DB-API extension cursor.connection used", stacklevel=3)
        return self._c

    @property
    def rowcount(self):
        return self._row_count

    description = property(lambda self: self._getDescription())

    def _getDescription(self):
        # Build the DBAPI ``description`` 7-tuples.  Only name and
        # type_oid are populated; the remaining five slots are None.
        if self.ps is None:
            return None
        row_desc = self.ps['row_desc']
        if len(row_desc) == 0:
            return None
        columns = []
        for col in row_desc:
            columns.append(
                (col["name"], col["type_oid"], None, None, None, None, None))
        return columns

    ##
    # Executes a database operation. Parameters may be provided as a sequence
    # or mapping and will be bound to variables in the operation.
    #
    # Stability: Part of the DBAPI 2.0 specification.
    def execute(self, operation, args=None, stream=None):
        """Executes a database operation. Parameters may be provided as a
        sequence, or as a mapping, depending upon the value of
        :data:`pg8000.paramstyle`.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

        :param operation:
            The SQL statement to execute.

        :param args:
            If :data:`paramstyle` is ``qmark``, ``numeric``, or ``format``,
            this argument should be an array of parameters to bind into the
            statement. If :data:`paramstyle` is ``named``, the argument should
            be a dict mapping of parameters. If the :data:`paramstyle` is
            ``pyformat``, the argument value may be either an array or a
            mapping.

        :param stream: This is a pg8000 extension for use with the PostgreSQL
            `COPY
            <http://www.postgresql.org/docs/current/static/sql-copy.html>`_
            command. For a COPY FROM the parameter must be a readable file-like
            object, and for COPY TO it must be writable.

        .. versionadded:: 1.9.11
        """
        try:
            self.stream = stream

            # Implicitly open a transaction unless one is already active
            # or autocommit is enabled.
            if not self._c.in_transaction and not self._c.autocommit:
                self._c.execute(self, "begin transaction", None)
            self._c.execute(self, operation, args)
        except AttributeError as e:
            # _c is None after close(); _c._sock is None once the
            # connection itself has been closed — report which it was.
            if self._c is None:
                raise InterfaceError("Cursor closed")
            elif self._c._sock is None:
                raise InterfaceError("connection is closed")
            else:
                raise e
        return self

    def executemany(self, operation, param_sets):
        """Prepare a database operation, and then execute it against all
        parameter sequences or mappings provided.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

        :param operation:
            The SQL statement to execute
        :param parameter_sets:
            A sequence of parameters to execute the statement with. The values
            in the sequence should be sequences or mappings of parameters, the
            same as the args argument of the :meth:`execute` method.
        """
        rowcounts = []
        for parameters in param_sets:
            self.execute(operation, parameters)
            rowcounts.append(self._row_count)

        # Any unknown (-1) count makes the aggregate unknown too.
        self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

        :returns:
            A row as a sequence of field values, or ``None`` if no more rows
            are available.
        """
        try:
            return next(self)
        except StopIteration:
            return None
        except TypeError:
            raise ProgrammingError("attempting to use unexecuted cursor")
        except AttributeError:
            raise ProgrammingError("attempting to use unexecuted cursor")

    def fetchmany(self, num=None):
        """Fetches the next set of rows of a query result.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

        :param num:

            The number of rows to fetch when called. If not provided, the
            :attr:`arraysize` attribute value is used instead.

        :returns:

            A sequence, each entry of which is a sequence of field values
            making up a row. If no more rows are available, an empty sequence
            will be returned.
        """
        try:
            return tuple(
                islice(self, self.arraysize if num is None else num))
        except TypeError:
            raise ProgrammingError("attempting to use unexecuted cursor")

    def fetchall(self):
        """Fetches all remaining rows of a query result.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.

        :returns:

            A sequence, each entry of which is a sequence of field values
            making up a row.
        """
        try:
            return tuple(self)
        except TypeError:
            raise ProgrammingError("attempting to use unexecuted cursor")

    def close(self):
        """Closes the cursor.

        This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_.
        """
        # Dropping the connection reference marks the cursor closed;
        # subsequent use raises InterfaceError via the AttributeError path.
        self._c = None

    def __iter__(self):
        """A cursor object is iterable to retrieve the rows from a query.

        This is a DBAPI 2.0 extension.
        """
        return self

    def setinputsizes(self, sizes):
        """This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_, however, it is not
        implemented by pg8000.
        """
        pass

    def setoutputsize(self, size, column=None):
        """This method is part of the `DBAPI 2.0 specification
        <https://www.python.org/dev/peps/pep-0249/>`_, however, it is not
        implemented by pg8000.
        """
        pass

    def __next__(self):
        try:
            return self._cached_rows.popleft()
        except IndexError:
            # No buffered rows left: distinguish "never executed",
            # "statement returned no result set" and genuine exhaustion.
            if self.ps is None:
                raise ProgrammingError("A query hasn't been issued.")
            elif len(self.ps['row_desc']) == 0:
                raise ProgrammingError("no result set")
            else:
                raise StopIteration()
-
-
# Message codes
# Single-byte message-type identifiers of the PostgreSQL frontend/backend
# protocol.  This first group is received from the backend.
NOTICE_RESPONSE = b"N"
AUTHENTICATION_REQUEST = b"R"
PARAMETER_STATUS = b"S"
BACKEND_KEY_DATA = b"K"
READY_FOR_QUERY = b"Z"
ROW_DESCRIPTION = b"T"
ERROR_RESPONSE = b"E"
DATA_ROW = b"D"
COMMAND_COMPLETE = b"C"
PARSE_COMPLETE = b"1"
BIND_COMPLETE = b"2"
CLOSE_COMPLETE = b"3"
PORTAL_SUSPENDED = b"s"
NO_DATA = b"n"
PARAMETER_DESCRIPTION = b"t"
NOTIFICATION_RESPONSE = b"A"
COPY_DONE = b"c"
COPY_DATA = b"d"
COPY_IN_RESPONSE = b"G"
COPY_OUT_RESPONSE = b"H"
EMPTY_QUERY_RESPONSE = b"I"

# Message types sent by the frontend.
BIND = b"B"
PARSE = b"P"
EXECUTE = b"E"
FLUSH = b'H'
SYNC = b'S'
PASSWORD = b'p'
DESCRIBE = b'D'
TERMINATE = b'X'
CLOSE = b'C'
-
-
def _establish_ssl(_socket, ssl_params):
    """Upgrade *_socket* to SSL via the PostgreSQL SSLRequest handshake.

    *ssl_params* may be a dict with optional 'keyfile', 'certfile' and
    'ca_certs' entries; any other (truthy) value means "SSL with defaults".
    Raises InterfaceError if the server refuses SSL or the ssl module is
    unavailable.
    """
    if not isinstance(ssl_params, dict):
        ssl_params = {}

    try:
        import ssl as sslmodule

        keyfile = ssl_params.get('keyfile')
        certfile = ssl_params.get('certfile')
        ca_certs = ssl_params.get('ca_certs')
        # Certificate verification is only enabled when a CA bundle is given.
        if ca_certs is None:
            verify_mode = sslmodule.CERT_NONE
        else:
            verify_mode = sslmodule.CERT_REQUIRED

        # Int32(8) - Message length, including self.
        # Int32(80877103) - The SSL request code.
        _socket.sendall(ii_pack(8, 80877103))
        # The server replies with a single byte: b'S' means proceed with SSL.
        resp = _socket.recv(1)
        if resp == b'S':
            # NOTE(review): ssl.wrap_socket is deprecated since Python 3.7
            # and removed in 3.12 — migrate to ssl.SSLContext.wrap_socket.
            return sslmodule.wrap_socket(
                _socket, keyfile=keyfile, certfile=certfile,
                cert_reqs=verify_mode, ca_certs=ca_certs)
        else:
            raise InterfaceError("Server refuses SSL")
    except ImportError:
        raise InterfaceError(
            "SSL required but ssl module not available in "
            "this python installation")
-
-
def create_message(code, data=b''):
    """Frame a protocol message: type byte, then an int32 length that
    includes itself (hence the +4), then the payload."""
    length_word = i_pack(len(data) + 4)
    return code + length_word + data
-
-
# Parameter-less protocol messages, built once at import time.
FLUSH_MSG = create_message(FLUSH)
SYNC_MSG = create_message(SYNC)
TERMINATE_MSG = create_message(TERMINATE)
COPY_DONE_MSG = create_message(COPY_DONE)
# Execute the unnamed portal (empty name) with no row limit (0 = all rows).
EXECUTE_MSG = create_message(EXECUTE, NULL_BYTE + i_pack(0))

# DESCRIBE constants
STATEMENT = b'S'
PORTAL = b'P'
-
# ErrorResponse codes
RESPONSE_SEVERITY = "S"  # always present
# NOTE(review): the rebinding below clobbers the line above, so
# RESPONSE_SEVERITY is effectively "V" at runtime.  "S" and "V" are two
# distinct ErrorResponse fields and probably deserve separate names —
# confirm against the consumers of RESPONSE_SEVERITY before changing.
RESPONSE_SEVERITY = "V"  # always present
RESPONSE_CODE = "C"  # always present
RESPONSE_MSG = "M"  # always present
RESPONSE_DETAIL = "D"
RESPONSE_HINT = "H"
RESPONSE_POSITION = "P"
RESPONSE__POSITION = "p"
RESPONSE__QUERY = "q"
RESPONSE_WHERE = "W"
RESPONSE_FILE = "F"
RESPONSE_LINE = "L"
RESPONSE_ROUTINE = "R"

# Backend transaction-status indicators from ReadyForQuery.
IDLE = b"I"
IDLE_IN_TRANSACTION = b"T"
IDLE_IN_FAILED_TRANSACTION = b"E"
-
-
# str.translate table for turning a Python-list repr into a PostgreSQL
# array literal: '[' -> '{', ']' -> '}', and the characters space,
# single-quote and 'u' are deleted (mapped to None).
arr_trans = dict(zip(map(ord, "[] 'u"), list('{}') + [None] * 3))
-
-
-class Connection():
-
    # DBAPI Extension: supply exceptions as attributes on the connection.
    # Each access goes through _getError, which emits the "DB-API extension
    # connection.<name> used" warning before returning the exception class.
    Warning = property(lambda self: self._getError(Warning))
    Error = property(lambda self: self._getError(Error))
    InterfaceError = property(lambda self: self._getError(InterfaceError))
    DatabaseError = property(lambda self: self._getError(DatabaseError))
    OperationalError = property(lambda self: self._getError(OperationalError))
    IntegrityError = property(lambda self: self._getError(IntegrityError))
    InternalError = property(lambda self: self._getError(InternalError))
    ProgrammingError = property(lambda self: self._getError(ProgrammingError))
    NotSupportedError = property(
        lambda self: self._getError(NotSupportedError))
-
    def __enter__(self):
        """Context-manager entry: the connection itself is the target."""
        return self
-
    def __exit__(self, exc_type, exc_value, traceback):
        """Context-manager exit: always close the connection (exceptions,
        if any, propagate since nothing is returned)."""
        self.close()
-
    def _getError(self, error):
        """Warn that a DB-API extension attribute (e.g.
        ``Connection.ProgrammingError``) was used, then return the
        exception class unchanged."""
        warn(
            "DB-API extension connection.%s used" %
            error.__name__, stacklevel=3)
        return error
-
- def __init__(
- self, user, host, unix_sock, port, database, password, ssl,
- timeout, application_name, max_prepared_statements, tcp_keepalive):
- self._client_encoding = "utf8"
- self._commands_with_count = (
- b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY",
- b"SELECT")
- self.notifications = deque(maxlen=100)
- self.notices = deque(maxlen=100)
- self.parameter_statuses = deque(maxlen=100)
- self.max_prepared_statements = int(max_prepared_statements)
-
- if user is None:
- raise InterfaceError(
- "The 'user' connection parameter cannot be None")
-
- if isinstance(user, str):
- self.user = user.encode('utf8')
- else:
- self.user = user
-
- if isinstance(password, str):
- self.password = password.encode('utf8')
- else:
- self.password = password
-
- self.autocommit = False
- self._xid = None
-
- self._caches = {}
-
- try:
- if unix_sock is None and host is not None:
- self._usock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- elif unix_sock is not None:
- if not hasattr(socket, "AF_UNIX"):
- raise InterfaceError(
- "attempt to connect to unix socket on unsupported "
- "platform")
- self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- else:
- raise ProgrammingError(
- "one of host or unix_sock must be provided")
- if timeout is not None:
- self._usock.settimeout(timeout)
-
- if unix_sock is None and host is not None:
- self._usock.connect((host, port))
- elif unix_sock is not None:
- self._usock.connect(unix_sock)
-
- if ssl:
- self._usock = _establish_ssl(self._usock, ssl)
-
- self._sock = self._usock.makefile(mode="rwb")
- if tcp_keepalive:
- self._usock.setsockopt(
- socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- except socket.error as e:
- self._usock.close()
- raise InterfaceError("communication error", e)
- self._flush = self._sock.flush
- self._read = self._sock.read
- self._write = self._sock.write
- self._backend_key_data = None
-
- def text_out(v):
- return v.encode(self._client_encoding)
-
- def enum_out(v):
- return str(v.value).encode(self._client_encoding)
-
- def time_out(v):
- return v.isoformat().encode(self._client_encoding)
-
- def date_out(v):
- return v.isoformat().encode(self._client_encoding)
-
- def unknown_out(v):
- return str(v).encode(self._client_encoding)
-
- trans_tab = dict(zip(map(ord, '{}'), '[]'))
- glbls = {'Decimal': Decimal}
-
- def array_in(data, idx, length):
- arr = []
- prev_c = None
- for c in data[idx:idx+length].decode(
- self._client_encoding).translate(
- trans_tab).replace('NULL', 'None'):
- if c not in ('[', ']', ',', 'N') and prev_c in ('[', ','):
- arr.extend("Decimal('")
- elif c in (']', ',') and prev_c not in ('[', ']', ',', 'e'):
- arr.extend("')")
-
- arr.append(c)
- prev_c = c
- return eval(''.join(arr), glbls)
-
- def array_recv(data, idx, length):
- final_idx = idx + length
- dim, hasnull, typeoid = iii_unpack(data, idx)
- idx += 12
-
- # get type conversion method for typeoid
- conversion = self.pg_types[typeoid][1]
-
- # Read dimension info
- dim_lengths = []
- for i in range(dim):
- dim_lengths.append(ii_unpack(data, idx)[0])
- idx += 8
-
- # Read all array values
- values = []
- while idx < final_idx:
- element_len, = i_unpack(data, idx)
- idx += 4
- if element_len == -1:
- values.append(None)
- else:
- values.append(conversion(data, idx, element_len))
- idx += element_len
-
- # at this point, {{1,2,3},{4,5,6}}::int[][] looks like
- # [1,2,3,4,5,6]. go through the dimensions and fix up the array
- # contents to match expected dimensions
- for length in reversed(dim_lengths[1:]):
- values = list(map(list, zip(*[iter(values)] * length)))
- return values
-
- def vector_in(data, idx, length):
- return eval('[' + data[idx:idx+length].decode(
- self._client_encoding).replace(' ', ',') + ']')
-
- def text_recv(data, offset, length):
- return str(data[offset: offset + length], self._client_encoding)
-
- def bool_recv(data, offset, length):
- return data[offset] == 1
-
- def json_in(data, offset, length):
- return loads(
- str(data[offset: offset + length], self._client_encoding))
-
- def time_in(data, offset, length):
- hour = int(data[offset:offset + 2])
- minute = int(data[offset + 3:offset + 5])
- sec = Decimal(
- data[offset + 6:offset + length].decode(self._client_encoding))
- return time(
- hour, minute, int(sec), int((sec - int(sec)) * 1000000))
-
- def date_in(data, offset, length):
- d = data[offset:offset+length].decode(self._client_encoding)
- try:
- return date(int(d[:4]), int(d[5:7]), int(d[8:10]))
- except ValueError:
- return d
-
- def numeric_in(data, offset, length):
- return Decimal(
- data[offset: offset + length].decode(self._client_encoding))
-
- def numeric_out(d):
- return str(d).encode(self._client_encoding)
-
- self.pg_types = defaultdict(
- lambda: (FC_TEXT, text_recv), {
- 16: (FC_BINARY, bool_recv), # boolean
- 17: (FC_BINARY, bytea_recv), # bytea
- 19: (FC_BINARY, text_recv), # name type
- 20: (FC_BINARY, int8_recv), # int8
- 21: (FC_BINARY, int2_recv), # int2
- 22: (FC_TEXT, vector_in), # int2vector
- 23: (FC_BINARY, int4_recv), # int4
- 25: (FC_BINARY, text_recv), # TEXT type
- 26: (FC_TEXT, int_in), # oid
- 28: (FC_TEXT, int_in), # xid
- 114: (FC_TEXT, json_in), # json
- 700: (FC_BINARY, float4_recv), # float4
- 701: (FC_BINARY, float8_recv), # float8
- 705: (FC_BINARY, text_recv), # unknown
- 829: (FC_TEXT, text_recv), # MACADDR type
- 1000: (FC_BINARY, array_recv), # BOOL[]
- 1003: (FC_BINARY, array_recv), # NAME[]
- 1005: (FC_BINARY, array_recv), # INT2[]
- 1007: (FC_BINARY, array_recv), # INT4[]
- 1009: (FC_BINARY, array_recv), # TEXT[]
- 1014: (FC_BINARY, array_recv), # CHAR[]
- 1015: (FC_BINARY, array_recv), # VARCHAR[]
- 1016: (FC_BINARY, array_recv), # INT8[]
- 1021: (FC_BINARY, array_recv), # FLOAT4[]
- 1022: (FC_BINARY, array_recv), # FLOAT8[]
- 1042: (FC_BINARY, text_recv), # CHAR type
- 1043: (FC_BINARY, text_recv), # VARCHAR type
- 1082: (FC_TEXT, date_in), # date
- 1083: (FC_TEXT, time_in),
- 1114: (FC_BINARY, timestamp_recv_float), # timestamp w/ tz
- 1184: (FC_BINARY, timestamptz_recv_float),
- 1186: (FC_BINARY, interval_recv_integer),
- 1231: (FC_TEXT, array_in), # NUMERIC[]
- 1263: (FC_BINARY, array_recv), # cstring[]
- 1700: (FC_TEXT, numeric_in), # NUMERIC
- 2275: (FC_BINARY, text_recv), # cstring
- 2950: (FC_BINARY, uuid_recv), # uuid
- 3802: (FC_TEXT, json_in), # jsonb
- })
-
- self.py_types = {
- type(None): (-1, FC_BINARY, null_send), # null
- bool: (16, FC_BINARY, bool_send),
- bytearray: (17, FC_BINARY, bytea_send), # bytea
- 20: (20, FC_BINARY, q_pack), # int8
- 21: (21, FC_BINARY, h_pack), # int2
- 23: (23, FC_BINARY, i_pack), # int4
- PGText: (25, FC_TEXT, text_out), # text
- float: (701, FC_BINARY, d_pack), # float8
- PGEnum: (705, FC_TEXT, enum_out),
- date: (1082, FC_TEXT, date_out), # date
- time: (1083, FC_TEXT, time_out), # time
- 1114: (1114, FC_BINARY, timestamp_send_integer), # timestamp
- # timestamp w/ tz
- PGVarchar: (1043, FC_TEXT, text_out), # varchar
- 1184: (1184, FC_BINARY, timestamptz_send_integer),
- PGJson: (114, FC_TEXT, text_out),
- PGJsonb: (3802, FC_TEXT, text_out),
- Timedelta: (1186, FC_BINARY, interval_send_integer),
- Interval: (1186, FC_BINARY, interval_send_integer),
- Decimal: (1700, FC_TEXT, numeric_out), # Decimal
- PGTsvector: (3614, FC_TEXT, text_out),
- UUID: (2950, FC_BINARY, uuid_send)} # uuid
-
- self.inspect_funcs = {
- Datetime: self.inspect_datetime,
- list: self.array_inspect,
- tuple: self.array_inspect,
- int: self.inspect_int}
-
- self.py_types[bytes] = (17, FC_BINARY, bytea_send) # bytea
- self.py_types[str] = (705, FC_TEXT, text_out) # unknown
- self.py_types[enum.Enum] = (705, FC_TEXT, enum_out)
-
- def inet_out(v):
- return str(v).encode(self._client_encoding)
-
- def inet_in(data, offset, length):
- inet_str = data[offset: offset + length].decode(
- self._client_encoding)
- if '/' in inet_str:
- return ip_network(inet_str, False)
- else:
- return ip_address(inet_str)
-
- self.py_types[IPv4Address] = (869, FC_TEXT, inet_out) # inet
- self.py_types[IPv6Address] = (869, FC_TEXT, inet_out) # inet
- self.py_types[IPv4Network] = (869, FC_TEXT, inet_out) # inet
- self.py_types[IPv6Network] = (869, FC_TEXT, inet_out) # inet
- self.pg_types[869] = (FC_TEXT, inet_in) # inet
-
- self.message_types = {
- NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
- AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
- PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
- BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
- READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
- ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
- ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
- EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
- DATA_ROW: self.handle_DATA_ROW,
- COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
- PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
- BIND_COMPLETE: self.handle_BIND_COMPLETE,
- CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
- PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
- NO_DATA: self.handle_NO_DATA,
- PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
- NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
- COPY_DONE: self.handle_COPY_DONE,
- COPY_DATA: self.handle_COPY_DATA,
- COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
- COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}
-
- # Int32 - Message length, including self.
- # Int32(196608) - Protocol version number. Version 3.0.
- # Any number of key/value pairs, terminated by a zero byte:
- # String - A parameter name (user, database, or options)
- # String - Parameter value
- protocol = 196608
- val = bytearray(
- i_pack(protocol) + b"user\x00" + self.user + NULL_BYTE)
- if database is not None:
- if isinstance(database, str):
- database = database.encode('utf8')
- val.extend(b"database\x00" + database + NULL_BYTE)
- if application_name is not None:
- if isinstance(application_name, str):
- application_name = application_name.encode('utf8')
- val.extend(
- b"application_name\x00" + application_name + NULL_BYTE)
- val.append(0)
- self._write(i_pack(len(val) + 4))
- self._write(val)
- self._flush()
-
- self._cursor = self.cursor()
- code = self.error = None
- while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
- code, data_len = ci_unpack(self._read(5))
- self.message_types[code](self._read(data_len - 4), None)
- if self.error is not None:
- raise self.error
-
- self.in_transaction = False
-
- def handle_ERROR_RESPONSE(self, data, ps):
- msg = dict(
- (
- s[:1].decode(self._client_encoding),
- s[1:].decode(self._client_encoding)) for s in
- data.split(NULL_BYTE) if s != b'')
-
- response_code = msg[RESPONSE_CODE]
- if response_code == '28000':
- cls = InterfaceError
- elif response_code == '23505':
- cls = IntegrityError
- else:
- cls = ProgrammingError
-
- self.error = cls(msg)
-
- def handle_EMPTY_QUERY_RESPONSE(self, data, ps):
- self.error = ProgrammingError("query was empty")
-
- def handle_CLOSE_COMPLETE(self, data, ps):
- pass
-
- def handle_PARSE_COMPLETE(self, data, ps):
- # Byte1('1') - Identifier.
- # Int32(4) - Message length, including self.
- pass
-
- def handle_BIND_COMPLETE(self, data, ps):
- pass
-
- def handle_PORTAL_SUSPENDED(self, data, cursor):
- pass
-
- def handle_PARAMETER_DESCRIPTION(self, data, ps):
- # Well, we don't really care -- we're going to send whatever we
- # want and let the database deal with it. But thanks anyways!
-
- # count = h_unpack(data)[0]
- # type_oids = unpack_from("!" + "i" * count, data, 2)
- pass
-
- def handle_COPY_DONE(self, data, ps):
- self._copy_done = True
-
- def handle_COPY_OUT_RESPONSE(self, data, ps):
- # Int8(1) - 0 textual, 1 binary
- # Int16(2) - Number of columns
- # Int16(N) - Format codes for each column (0 text, 1 binary)
-
- is_binary, num_cols = bh_unpack(data)
- # column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
- if ps.stream is None:
- raise InterfaceError(
- "An output stream is required for the COPY OUT response.")
-
- def handle_COPY_DATA(self, data, ps):
- ps.stream.write(data)
-
- def handle_COPY_IN_RESPONSE(self, data, ps):
- # Int16(2) - Number of columns
- # Int16(N) - Format codes for each column (0 text, 1 binary)
- is_binary, num_cols = bh_unpack(data)
- # column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
- if ps.stream is None:
- raise InterfaceError(
- "An input stream is required for the COPY IN response.")
-
- bffr = bytearray(8192)
- while True:
- bytes_read = ps.stream.readinto(bffr)
- if bytes_read == 0:
- break
- self._write(COPY_DATA + i_pack(bytes_read + 4))
- self._write(bffr[:bytes_read])
- self._flush()
-
- # Send CopyDone
- # Byte1('c') - Identifier.
- # Int32(4) - Message length, including self.
- self._write(COPY_DONE_MSG)
- self._write(SYNC_MSG)
- self._flush()
-
- def handle_NOTIFICATION_RESPONSE(self, data, ps):
- ##
- # A message sent if this connection receives a NOTIFY that it was
- # LISTENing for.
- #
- # Stability: Added in pg8000 v1.03. When limited to accessing
- # properties from a notification event dispatch, stability is
- # guaranteed for v1.xx.
- backend_pid = i_unpack(data)[0]
- idx = 4
- null = data.find(NULL_BYTE, idx) - idx
- condition = data[idx:idx + null].decode("ascii")
- idx += null + 1
- null = data.find(NULL_BYTE, idx) - idx
- # additional_info = data[idx:idx + null]
-
- self.notifications.append((backend_pid, condition))
-
- def cursor(self):
- """Creates a :class:`Cursor` object bound to this
- connection.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- return Cursor(self)
-
- def commit(self):
- """Commits the current database transaction.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- self.execute(self._cursor, "commit", None)
-
- def rollback(self):
- """Rolls back the current database transaction.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- if not self.in_transaction:
- return
- self.execute(self._cursor, "rollback", None)
-
- def close(self):
- """Closes the database connection.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- try:
- # Byte1('X') - Identifies the message as a terminate message.
- # Int32(4) - Message length, including self.
- self._write(TERMINATE_MSG)
- self._flush()
- self._sock.close()
- except AttributeError:
- raise InterfaceError("connection is closed")
- except ValueError:
- raise InterfaceError("connection is closed")
- except socket.error:
- pass
- finally:
- self._usock.close()
- self._sock = None
-
- def handle_AUTHENTICATION_REQUEST(self, data, cursor):
- # Int32 - An authentication code that represents different
- # authentication messages:
- # 0 = AuthenticationOk
- # 5 = MD5 pwd
- # 2 = Kerberos v5 (not supported by pg8000)
- # 3 = Cleartext pwd
- # 4 = crypt() pwd (not supported by pg8000)
- # 6 = SCM credential (not supported by pg8000)
- # 7 = GSSAPI (not supported by pg8000)
- # 8 = GSSAPI data (not supported by pg8000)
- # 9 = SSPI (not supported by pg8000)
- # Some authentication messages have additional data following the
- # authentication code. That data is documented in the appropriate
- # class.
- auth_code = i_unpack(data)[0]
- if auth_code == 0:
- pass
- elif auth_code == 3:
- if self.password is None:
- raise InterfaceError(
- "server requesting password authentication, but no "
- "password was provided")
- self._send_message(PASSWORD, self.password + NULL_BYTE)
- self._flush()
- elif auth_code == 5:
- ##
- # A message representing the backend requesting an MD5 hashed
- # password response. The response will be sent as
- # md5(md5(pwd + login) + salt).
-
- # Additional message data:
- # Byte4 - Hash salt.
- salt = b"".join(cccc_unpack(data, 4))
- if self.password is None:
- raise InterfaceError(
- "server requesting MD5 password authentication, but no "
- "password was provided")
- pwd = b"md5" + md5(
- md5(self.password + self.user).hexdigest().encode("ascii") +
- salt).hexdigest().encode("ascii")
- # Byte1('p') - Identifies the message as a password message.
- # Int32 - Message length including self.
- # String - The password. Password may be encrypted.
- self._send_message(PASSWORD, pwd + NULL_BYTE)
- self._flush()
-
- elif auth_code == 10:
- # AuthenticationSASL
- mechanisms = [
- m.decode('ascii') for m in data[4:-1].split(NULL_BYTE)]
-
- self.auth = ScramClient(
- mechanisms, self.user.decode('utf8'),
- self.password.decode('utf8'))
-
- init = self.auth.get_client_first().encode('utf8')
-
- # SASLInitialResponse
- self._write(
- create_message(
- PASSWORD,
- b'SCRAM-SHA-256' + NULL_BYTE + i_pack(len(init)) + init))
- self._flush()
-
- elif auth_code == 11:
- # AuthenticationSASLContinue
- self.auth.set_server_first(data[4:].decode('utf8'))
-
- # SASLResponse
- msg = self.auth.get_client_final().encode('utf8')
- self._write(create_message(PASSWORD, msg))
- self._flush()
-
- elif auth_code == 12:
- # AuthenticationSASLFinal
- self.auth.set_server_final(data[4:].decode('utf8'))
-
- elif auth_code in (2, 4, 6, 7, 8, 9):
- raise InterfaceError(
- "Authentication method " + str(auth_code) +
- " not supported by pg8000.")
- else:
- raise InterfaceError(
- "Authentication method " + str(auth_code) +
- " not recognized by pg8000.")
-
- def handle_READY_FOR_QUERY(self, data, ps):
- # Byte1 - Status indicator.
- self.in_transaction = data != IDLE
-
- def handle_BACKEND_KEY_DATA(self, data, ps):
- self._backend_key_data = data
-
- def inspect_datetime(self, value):
- if value.tzinfo is None:
- return self.py_types[1114] # timestamp
- else:
- return self.py_types[1184] # send as timestamptz
-
- def inspect_int(self, value):
- if min_int2 < value < max_int2:
- return self.py_types[21]
- if min_int4 < value < max_int4:
- return self.py_types[23]
- if min_int8 < value < max_int8:
- return self.py_types[20]
-
- def make_params(self, values):
- params = []
- for value in values:
- typ = type(value)
- try:
- params.append(self.py_types[typ])
- except KeyError:
- try:
- params.append(self.inspect_funcs[typ](value))
- except KeyError as e:
- param = None
- for k, v in self.py_types.items():
- try:
- if isinstance(value, k):
- param = v
- break
- except TypeError:
- pass
-
- if param is None:
- for k, v in self.inspect_funcs.items():
- try:
- if isinstance(value, k):
- param = v(value)
- break
- except TypeError:
- pass
- except KeyError:
- pass
-
- if param is None:
- raise NotSupportedError(
- "type " + str(e) + " not mapped to pg type")
- else:
- params.append(param)
-
- return tuple(params)
-
- def handle_ROW_DESCRIPTION(self, data, cursor):
- count = h_unpack(data)[0]
- idx = 2
- for i in range(count):
- name = data[idx:data.find(NULL_BYTE, idx)]
- idx += len(name) + 1
- field = dict(
- zip((
- "table_oid", "column_attrnum", "type_oid", "type_size",
- "type_modifier", "format"), ihihih_unpack(data, idx)))
- field['name'] = name
- idx += 18
- cursor.ps['row_desc'].append(field)
- field['pg8000_fc'], field['func'] = \
- self.pg_types[field['type_oid']]
-
- def execute(self, cursor, operation, vals):
- if vals is None:
- vals = ()
-
- paramstyle = pg8000.paramstyle
- pid = getpid()
- try:
- cache = self._caches[paramstyle][pid]
- except KeyError:
- try:
- param_cache = self._caches[paramstyle]
- except KeyError:
- param_cache = self._caches[paramstyle] = {}
-
- try:
- cache = param_cache[pid]
- except KeyError:
- cache = param_cache[pid] = {'statement': {}, 'ps': {}}
-
- try:
- statement, make_args = cache['statement'][operation]
- except KeyError:
- statement, make_args = cache['statement'][operation] = \
- convert_paramstyle(paramstyle, operation)
-
- args = make_args(vals)
- params = self.make_params(args)
- key = operation, params
-
- try:
- ps = cache['ps'][key]
- cursor.ps = ps
- except KeyError:
- statement_nums = [0]
- for style_cache in self._caches.values():
- try:
- pid_cache = style_cache[pid]
- for csh in pid_cache['ps'].values():
- statement_nums.append(csh['statement_num'])
- except KeyError:
- pass
-
- statement_num = sorted(statement_nums)[-1] + 1
- statement_name = '_'.join(
- ("pg8000", "statement", str(pid), str(statement_num)))
- statement_name_bin = statement_name.encode('ascii') + NULL_BYTE
- ps = {
- 'statement_name_bin': statement_name_bin,
- 'pid': pid,
- 'statement_num': statement_num,
- 'row_desc': [],
- 'param_funcs': tuple(x[2] for x in params)}
- cursor.ps = ps
-
- param_fcs = tuple(x[1] for x in params)
-
- # Byte1('P') - Identifies the message as a Parse command.
- # Int32 - Message length, including self.
- # String - Prepared statement name. An empty string selects the
- # unnamed prepared statement.
- # String - The query string.
- # Int16 - Number of parameter data types specified (can be zero).
- # For each parameter:
- # Int32 - The OID of the parameter data type.
- val = bytearray(statement_name_bin)
- val.extend(statement.encode(self._client_encoding) + NULL_BYTE)
- val.extend(h_pack(len(params)))
- for oid, fc, send_func in params:
- # Parse message doesn't seem to handle the -1 type_oid for NULL
- # values that other messages handle. So we'll provide type_oid
- # 705, the PG "unknown" type.
- val.extend(i_pack(705 if oid == -1 else oid))
-
- # Byte1('D') - Identifies the message as a describe command.
- # Int32 - Message length, including self.
- # Byte1 - 'S' for prepared statement, 'P' for portal.
- # String - The name of the item to describe.
- self._send_message(PARSE, val)
- self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
- self._write(SYNC_MSG)
-
- try:
- self._flush()
- except AttributeError as e:
- if self._sock is None:
- raise InterfaceError("connection is closed")
- else:
- raise e
-
- self.handle_messages(cursor)
-
- # We've got row_desc that allows us to identify what we're
- # going to get back from this statement.
- output_fc = tuple(
- self.pg_types[f['type_oid']][0] for f in ps['row_desc'])
-
- ps['input_funcs'] = tuple(f['func'] for f in ps['row_desc'])
- # Byte1('B') - Identifies the Bind command.
- # Int32 - Message length, including self.
- # String - Name of the destination portal.
- # String - Name of the source prepared statement.
- # Int16 - Number of parameter format codes.
- # For each parameter format code:
- # Int16 - The parameter format code.
- # Int16 - Number of parameter values.
- # For each parameter value:
- # Int32 - The length of the parameter value, in bytes, not
- # including this length. -1 indicates a NULL parameter
- # value, in which no value bytes follow.
- # Byte[n] - Value of the parameter.
- # Int16 - The number of result-column format codes.
- # For each result-column format code:
- # Int16 - The format code.
- ps['bind_1'] = NULL_BYTE + statement_name_bin + \
- h_pack(len(params)) + \
- pack("!" + "h" * len(param_fcs), *param_fcs) + \
- h_pack(len(params))
-
- ps['bind_2'] = h_pack(len(output_fc)) + \
- pack("!" + "h" * len(output_fc), *output_fc)
-
- if len(cache['ps']) > self.max_prepared_statements:
- for p in cache['ps'].values():
- self.close_prepared_statement(p['statement_name_bin'])
- cache['ps'].clear()
-
- cache['ps'][key] = ps
-
- cursor._cached_rows.clear()
- cursor._row_count = -1
-
- # Byte1('B') - Identifies the Bind command.
- # Int32 - Message length, including self.
- # String - Name of the destination portal.
- # String - Name of the source prepared statement.
- # Int16 - Number of parameter format codes.
- # For each parameter format code:
- # Int16 - The parameter format code.
- # Int16 - Number of parameter values.
- # For each parameter value:
- # Int32 - The length of the parameter value, in bytes, not
- # including this length. -1 indicates a NULL parameter
- # value, in which no value bytes follow.
- # Byte[n] - Value of the parameter.
- # Int16 - The number of result-column format codes.
- # For each result-column format code:
- # Int16 - The format code.
- retval = bytearray(ps['bind_1'])
- for value, send_func in zip(args, ps['param_funcs']):
- if value is None:
- val = NULL
- else:
- val = send_func(value)
- retval.extend(i_pack(len(val)))
- retval.extend(val)
- retval.extend(ps['bind_2'])
-
- self._send_message(BIND, retval)
- self.send_EXECUTE(cursor)
- self._write(SYNC_MSG)
- self._flush()
- self.handle_messages(cursor)
-
- def _send_message(self, code, data):
- try:
- self._write(code)
- self._write(i_pack(len(data) + 4))
- self._write(data)
- self._write(FLUSH_MSG)
- except ValueError as e:
- if str(e) == "write to closed file":
- raise InterfaceError("connection is closed")
- else:
- raise e
- except AttributeError:
- raise InterfaceError("connection is closed")
-
- def send_EXECUTE(self, cursor):
- # Byte1('E') - Identifies the message as an execute message.
- # Int32 - Message length, including self.
- # String - The name of the portal to execute.
- # Int32 - Maximum number of rows to return, if portal
- # contains a query # that returns rows.
- # 0 = no limit.
- self._write(EXECUTE_MSG)
- self._write(FLUSH_MSG)
-
- def handle_NO_DATA(self, msg, ps):
- pass
-
- def handle_COMMAND_COMPLETE(self, data, cursor):
- values = data[:-1].split(b' ')
- command = values[0]
- if command in self._commands_with_count:
- row_count = int(values[-1])
- if cursor._row_count == -1:
- cursor._row_count = row_count
- else:
- cursor._row_count += row_count
-
- if command in (b"ALTER", b"CREATE"):
- for scache in self._caches.values():
- for pcache in scache.values():
- for ps in pcache['ps'].values():
- self.close_prepared_statement(ps['statement_name_bin'])
- pcache['ps'].clear()
-
- def handle_DATA_ROW(self, data, cursor):
- data_idx = 2
- row = []
- for func in cursor.ps['input_funcs']:
- vlen = i_unpack(data, data_idx)[0]
- data_idx += 4
- if vlen == -1:
- row.append(None)
- else:
- row.append(func(data, data_idx, vlen))
- data_idx += vlen
- cursor._cached_rows.append(row)
-
- def handle_messages(self, cursor):
- code = self.error = None
-
- while code != READY_FOR_QUERY:
- code, data_len = ci_unpack(self._read(5))
- self.message_types[code](self._read(data_len - 4), cursor)
-
- if self.error is not None:
- raise self.error
-
- # Byte1('C') - Identifies the message as a close command.
- # Int32 - Message length, including self.
- # Byte1 - 'S' for prepared statement, 'P' for portal.
- # String - The name of the item to close.
- def close_prepared_statement(self, statement_name_bin):
- self._send_message(CLOSE, STATEMENT + statement_name_bin)
- self._write(SYNC_MSG)
- self._flush()
- self.handle_messages(self._cursor)
-
- # Byte1('N') - Identifier
- # Int32 - Message length
- # Any number of these, followed by a zero byte:
- # Byte1 - code identifying the field type (see responseKeys)
- # String - field value
- def handle_NOTICE_RESPONSE(self, data, ps):
- self.notices.append(
- dict((s[0:1], s[1:]) for s in data.split(NULL_BYTE)))
-
- def handle_PARAMETER_STATUS(self, data, ps):
- pos = data.find(NULL_BYTE)
- key, value = data[:pos], data[pos + 1:-1]
- self.parameter_statuses.append((key, value))
- if key == b"client_encoding":
- encoding = value.decode("ascii").lower()
- self._client_encoding = pg_to_py_encodings.get(encoding, encoding)
-
- elif key == b"integer_datetimes":
- if value == b'on':
-
- self.py_types[1114] = (1114, FC_BINARY, timestamp_send_integer)
- self.pg_types[1114] = (FC_BINARY, timestamp_recv_integer)
-
- self.py_types[1184] = (
- 1184, FC_BINARY, timestamptz_send_integer)
- self.pg_types[1184] = (FC_BINARY, timestamptz_recv_integer)
-
- self.py_types[Interval] = (
- 1186, FC_BINARY, interval_send_integer)
- self.py_types[Timedelta] = (
- 1186, FC_BINARY, interval_send_integer)
- self.pg_types[1186] = (FC_BINARY, interval_recv_integer)
- else:
- self.py_types[1114] = (1114, FC_BINARY, timestamp_send_float)
- self.pg_types[1114] = (FC_BINARY, timestamp_recv_float)
- self.py_types[1184] = (1184, FC_BINARY, timestamptz_send_float)
- self.pg_types[1184] = (FC_BINARY, timestamptz_recv_float)
-
- self.py_types[Interval] = (
- 1186, FC_BINARY, interval_send_float)
- self.py_types[Timedelta] = (
- 1186, FC_BINARY, interval_send_float)
- self.pg_types[1186] = (FC_BINARY, interval_recv_float)
-
- elif key == b"server_version":
- self._server_version = LooseVersion(value.decode('ascii'))
- if self._server_version < LooseVersion('8.2.0'):
- self._commands_with_count = (
- b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH")
- elif self._server_version < LooseVersion('9.0.0'):
- self._commands_with_count = (
- b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH",
- b"COPY")
-
- def array_inspect(self, value):
- # Check if array has any values. If empty, we can just assume it's an
- # array of strings
- first_element = array_find_first_element(value)
- if first_element is None:
- oid = 25
- # Use binary ARRAY format to avoid having to properly
- # escape text in the array literals
- fc = FC_BINARY
- array_oid = pg_array_types[oid]
- else:
- # supported array output
- typ = type(first_element)
-
- if issubclass(typ, int):
- # special int array support -- send as smallest possible array
- # type
- typ = int
- int2_ok, int4_ok, int8_ok = True, True, True
- for v in array_flatten(value):
- if v is None:
- continue
- if min_int2 < v < max_int2:
- continue
- int2_ok = False
- if min_int4 < v < max_int4:
- continue
- int4_ok = False
- if min_int8 < v < max_int8:
- continue
- int8_ok = False
- if int2_ok:
- array_oid = 1005 # INT2[]
- oid, fc, send_func = (21, FC_BINARY, h_pack)
- elif int4_ok:
- array_oid = 1007 # INT4[]
- oid, fc, send_func = (23, FC_BINARY, i_pack)
- elif int8_ok:
- array_oid = 1016 # INT8[]
- oid, fc, send_func = (20, FC_BINARY, q_pack)
- else:
- raise ArrayContentNotSupportedError(
- "numeric not supported as array contents")
- else:
- try:
- oid, fc, send_func = self.make_params((first_element,))[0]
-
- # If unknown or string, assume it's a string array
- if oid in (705, 1043, 25):
- oid = 25
- # Use binary ARRAY format to avoid having to properly
- # escape text in the array literals
- fc = FC_BINARY
- array_oid = pg_array_types[oid]
- except KeyError:
- raise ArrayContentNotSupportedError(
- "oid " + str(oid) + " not supported as array contents")
- except NotSupportedError:
- raise ArrayContentNotSupportedError(
- "type " + str(typ) +
- " not supported as array contents")
- if fc == FC_BINARY:
- def send_array(arr):
- # check that all array dimensions are consistent
- array_check_dimensions(arr)
-
- has_null = array_has_null(arr)
- dim_lengths = array_dim_lengths(arr)
- data = bytearray(iii_pack(len(dim_lengths), has_null, oid))
- for i in dim_lengths:
- data.extend(ii_pack(i, 1))
- for v in array_flatten(arr):
- if v is None:
- data += i_pack(-1)
- elif isinstance(v, typ):
- inner_data = send_func(v)
- data += i_pack(len(inner_data))
- data += inner_data
- else:
- raise ArrayContentNotHomogenousError(
- "not all array elements are of type " + str(typ))
- return data
- else:
- def send_array(arr):
- array_check_dimensions(arr)
- ar = deepcopy(arr)
- for a, i, v in walk_array(ar):
- if v is None:
- a[i] = 'NULL'
- elif isinstance(v, typ):
- a[i] = send_func(v).decode('ascii')
- else:
- raise ArrayContentNotHomogenousError(
- "not all array elements are of type " + str(typ))
- return str(ar).translate(arr_trans).encode('ascii')
-
- return (array_oid, fc, send_array)
-
- def xid(self, format_id, global_transaction_id, branch_qualifier):
- """Create a Transaction IDs (only global_transaction_id is used in pg)
- format_id and branch_qualifier are not used in postgres
- global_transaction_id may be any string identifier supported by
- postgres returns a tuple
- (format_id, global_transaction_id, branch_qualifier)"""
- return (format_id, global_transaction_id, branch_qualifier)
-
- def tpc_begin(self, xid):
- """Begins a TPC transaction with the given transaction ID xid.
-
- This method should be called outside of a transaction (i.e. nothing may
- have executed since the last .commit() or .rollback()).
-
- Furthermore, it is an error to call .commit() or .rollback() within the
- TPC transaction. A ProgrammingError is raised, if the application calls
- .commit() or .rollback() during an active TPC transaction.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- self._xid = xid
- if self.autocommit:
- self.execute(self._cursor, "begin transaction", None)
-
- def tpc_prepare(self):
- """Performs the first phase of a transaction started with .tpc_begin().
- A ProgrammingError is be raised if this method is called outside of a
- TPC transaction.
-
- After calling .tpc_prepare(), no statements can be executed until
- .tpc_commit() or .tpc_rollback() have been called.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- q = "PREPARE TRANSACTION '%s';" % (self._xid[1],)
- self.execute(self._cursor, q, None)
-
- def tpc_commit(self, xid=None):
- """When called with no arguments, .tpc_commit() commits a TPC
- transaction previously prepared with .tpc_prepare().
-
- If .tpc_commit() is called prior to .tpc_prepare(), a single phase
- commit is performed. A transaction manager may choose to do this if
- only a single resource is participating in the global transaction.
-
- When called with a transaction ID xid, the database commits the given
- transaction. If an invalid transaction ID is provided, a
- ProgrammingError will be raised. This form should be called outside of
- a transaction, and is intended for use in recovery.
-
- On return, the TPC transaction is ended.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- if xid is None:
- xid = self._xid
-
- if xid is None:
- raise ProgrammingError(
- "Cannot tpc_commit() without a TPC transaction!")
-
- try:
- previous_autocommit_mode = self.autocommit
- self.autocommit = True
- if xid in self.tpc_recover():
- self.execute(
- self._cursor, "COMMIT PREPARED '%s';" % (xid[1], ),
- None)
- else:
- # a single-phase commit
- self.commit()
- finally:
- self.autocommit = previous_autocommit_mode
- self._xid = None
-
- def tpc_rollback(self, xid=None):
- """When called with no arguments, .tpc_rollback() rolls back a TPC
- transaction. It may be called before or after .tpc_prepare().
-
- When called with a transaction ID xid, it rolls back the given
- transaction. If an invalid transaction ID is provided, a
- ProgrammingError is raised. This form should be called outside of a
- transaction, and is intended for use in recovery.
-
- On return, the TPC transaction is ended.
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- if xid is None:
- xid = self._xid
-
- if xid is None:
- raise ProgrammingError(
- "Cannot tpc_rollback() without a TPC prepared transaction!")
-
- try:
- previous_autocommit_mode = self.autocommit
- self.autocommit = True
- if xid in self.tpc_recover():
- # a two-phase rollback
- self.execute(
- self._cursor, "ROLLBACK PREPARED '%s';" % (xid[1],),
- None)
- else:
- # a single-phase rollback
- self.rollback()
- finally:
- self.autocommit = previous_autocommit_mode
- self._xid = None
-
- def tpc_recover(self):
- """Returns a list of pending transaction IDs suitable for use with
- .tpc_commit(xid) or .tpc_rollback(xid).
-
- This function is part of the `DBAPI 2.0 specification
- `_.
- """
- try:
- previous_autocommit_mode = self.autocommit
- self.autocommit = True
- curs = self.cursor()
- curs.execute("select gid FROM pg_prepared_xacts")
- return [self.xid(0, row[0], '') for row in curs]
- finally:
- self.autocommit = previous_autocommit_mode
-
-
-# pg element oid -> pg array typeoid
-pg_array_types = {
- 16: 1000,
- 25: 1009, # TEXT[]
- 701: 1022,
- 1043: 1009,
- 1700: 1231, # NUMERIC[]
-}
-
-
-# PostgreSQL encodings:
-# http://www.postgresql.org/docs/8.3/interactive/multibyte.html
-# Python encodings:
-# http://www.python.org/doc/2.4/lib/standard-encodings.html
-#
-# Commented out encodings don't require a name change between PostgreSQL and
-# Python. If the py side is None, then the encoding isn't supported.
-pg_to_py_encodings = {
- # Not supported:
- "mule_internal": None,
- "euc_tw": None,
-
- # Name fine as-is:
- # "euc_jp",
- # "euc_jis_2004",
- # "euc_kr",
- # "gb18030",
- # "gbk",
- # "johab",
- # "sjis",
- # "shift_jis_2004",
- # "uhc",
- # "utf8",
-
- # Different name:
- "euc_cn": "gb2312",
- "iso_8859_5": "is8859_5",
- "iso_8859_6": "is8859_6",
- "iso_8859_7": "is8859_7",
- "iso_8859_8": "is8859_8",
- "koi8": "koi8_r",
- "latin1": "iso8859-1",
- "latin2": "iso8859_2",
- "latin3": "iso8859_3",
- "latin4": "iso8859_4",
- "latin5": "iso8859_9",
- "latin6": "iso8859_10",
- "latin7": "iso8859_13",
- "latin8": "iso8859_14",
- "latin9": "iso8859_15",
- "sql_ascii": "ascii",
- "win866": "cp886",
- "win874": "cp874",
- "win1250": "cp1250",
- "win1251": "cp1251",
- "win1252": "cp1252",
- "win1253": "cp1253",
- "win1254": "cp1254",
- "win1255": "cp1255",
- "win1256": "cp1256",
- "win1257": "cp1257",
- "win1258": "cp1258",
- "unicode": "utf-8", # Needed for Amazon Redshift
-}
-
-
-def walk_array(arr):
- for i, v in enumerate(arr):
- if isinstance(v, list):
- for a, i2, v2 in walk_array(v):
- yield a, i2, v2
- else:
- yield arr, i, v
-
-
-def array_find_first_element(arr):
- for v in array_flatten(arr):
- if v is not None:
- return v
- return None
-
-
-def array_flatten(arr):
- for v in arr:
- if isinstance(v, list):
- for v2 in array_flatten(v):
- yield v2
- else:
- yield v
-
-
-def array_check_dimensions(arr):
- if len(arr) > 0:
- v0 = arr[0]
- if isinstance(v0, list):
- req_len = len(v0)
- req_inner_lengths = array_check_dimensions(v0)
- for v in arr:
- inner_lengths = array_check_dimensions(v)
- if len(v) != req_len or inner_lengths != req_inner_lengths:
- raise ArrayDimensionsNotConsistentError(
- "array dimensions not consistent")
- retval = [req_len]
- retval.extend(req_inner_lengths)
- return retval
- else:
- # make sure nothing else at this level is a list
- for v in arr:
- if isinstance(v, list):
- raise ArrayDimensionsNotConsistentError(
- "array dimensions not consistent")
- return []
-
-
-def array_has_null(arr):
- for v in array_flatten(arr):
- if v is None:
- return True
- return False
-
-
-def array_dim_lengths(arr):
- len_arr = len(arr)
- retval = [len_arr]
- if len_arr > 0:
- v0 = arr[0]
- if isinstance(v0, list):
- retval.extend(array_dim_lengths(v0))
- return retval
diff --git a/source/libraries/pg8000/scramp/__init__.py b/source/libraries/pg8000/scramp/__init__.py
deleted file mode 100644
index d679ef4..0000000
--- a/source/libraries/pg8000/scramp/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from ..scramp.core import ScramClient, ScramServer, ScramException
-
-__all__ = [ScramClient, ScramServer, ScramException]
diff --git a/source/libraries/pg8000/scramp/_version.py b/source/libraries/pg8000/scramp/_version.py
deleted file mode 100644
index b425b2a..0000000
--- a/source/libraries/pg8000/scramp/_version.py
+++ /dev/null
@@ -1,520 +0,0 @@
-
-# This file helps to compute a version number in source trees obtained from
-# git-archive tarball (such as those provided by githubs download-from-tag
-# feature). Distribution tarballs (built by setup.py sdist) and build
-# directories (produced by setup.py build) will contain a much shorter file
-# that just contains the computed version number.
-
-# This file is released into the public domain. Generated by
-# versioneer-0.18 (https://github.com/warner/python-versioneer)
-
-"""Git implementation of _version.py."""
-
-import errno
-import os
-import re
-import subprocess
-import sys
-
-
-def get_keywords():
- """Get the keywords needed to look up the version information."""
- # these strings will be replaced by git during git-archive.
- # setup.py/versioneer.py will grep for the variable names, so they must
- # each be defined on a line of their own. _version.py will just call
- # get_keywords().
- git_refnames = "$Format:%d$"
- git_full = "$Format:%H$"
- git_date = "$Format:%ci$"
- keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
- return keywords
-
-
-class VersioneerConfig:
- """Container for Versioneer configuration parameters."""
-
-
-def get_config():
- """Create, populate and return the VersioneerConfig() object."""
- # these strings are filled in when 'setup.py versioneer' creates
- # _version.py
- cfg = VersioneerConfig()
- cfg.VCS = "git"
- cfg.style = "pep440"
- cfg.tag_prefix = ""
- cfg.parentdir_prefix = "scramp-"
- cfg.versionfile_source = "scramp/_version.py"
- cfg.verbose = False
- return cfg
-
-
-class NotThisMethod(Exception):
- """Exception raised if a method is not valid for the current scenario."""
-
-
-LONG_VERSION_PY = {}
-HANDLERS = {}
-
-
-def register_vcs_handler(vcs, method): # decorator
- """Decorator to mark a method as the handler for a particular VCS."""
- def decorate(f):
- """Store f in HANDLERS[vcs][method]."""
- if vcs not in HANDLERS:
- HANDLERS[vcs] = {}
- HANDLERS[vcs][method] = f
- return f
- return decorate
-
-
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
- """Call the given command(s)."""
- assert isinstance(commands, list)
- p = None
- for c in commands:
- try:
- dispcmd = str([c] + args)
- # remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen([c] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
- break
- except EnvironmentError:
- e = sys.exc_info()[1]
- if e.errno == errno.ENOENT:
- continue
- if verbose:
- print("unable to run %s" % dispcmd)
- print(e)
- return None, None
- else:
- if verbose:
- print("unable to find command, tried %s" % (commands,))
- return None, None
- stdout = p.communicate()[0].strip()
- if sys.version_info[0] >= 3:
- stdout = stdout.decode()
- if p.returncode != 0:
- if verbose:
- print("unable to run %s (error)" % dispcmd)
- print("stdout was %s" % stdout)
- return None, p.returncode
- return stdout, p.returncode
-
-
-def versions_from_parentdir(parentdir_prefix, root, verbose):
- """Try to determine the version from the parent directory name.
-
- Source tarballs conventionally unpack into a directory that includes both
- the project name and a version string. We will also support searching up
- two directory levels for an appropriately named parent directory
- """
- rootdirs = []
-
- for i in range(3):
- dirname = os.path.basename(root)
- if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
- else:
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
-
- if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
- raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
-
-
-@register_vcs_handler("git", "get_keywords")
-def git_get_keywords(versionfile_abs):
- """Extract version information from the given file."""
- # the code embedded in _version.py can just fetch the value of these
- # keywords. When used from setup.py, we don't want to import _version.py,
- # so we do it with a regexp instead. This function is not used from
- # _version.py.
- keywords = {}
- try:
- f = open(versionfile_abs, "r")
- for line in f.readlines():
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- f.close()
- except EnvironmentError:
- pass
- return keywords
-
-
-@register_vcs_handler("git", "keywords")
-def git_versions_from_keywords(keywords, tag_prefix, verbose):
- """Get version information from git keywords."""
- if not keywords:
- raise NotThisMethod("no keywords at all, weird")
- date = keywords.get("date")
- if date is not None:
- # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
- # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
- # -like" string, which we must then edit to make compliant), because
- # it's been around since git-1.5.3, and it's too difficult to
- # discover which version we're using, or to work around using an
- # older one.
- date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
- refnames = keywords["refnames"].strip()
- if refnames.startswith("$Format"):
- if verbose:
- print("keywords are unexpanded, not using")
- raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = set([r.strip() for r in refnames.strip("()").split(",")])
- # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
- # just "foo-1.0". If we see a "tag: " prefix, prefer those.
- TAG = "tag: "
- tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
- if not tags:
- # Either we're using git < 1.8.3, or there really are no tags. We use
- # a heuristic: assume all version tags have a digit. The old git %d
- # expansion behaves like git log --decorate=short and strips out the
- # refs/heads/ and refs/tags/ prefixes that would let us distinguish
- # between branches and tags. By ignoring refnames without digits, we
- # filter out many common branch names like "release" and
- # "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r'\d', r)])
- if verbose:
- print("discarding '%s', no digits" % ",".join(refs - tags))
- if verbose:
- print("likely tags: %s" % ",".join(sorted(tags)))
- for ref in sorted(tags):
- # sorting will prefer e.g. "2.0" over "2.0rc1"
- if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
- if verbose:
- print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
- # no suitable tags, so version is "0+unknown", but full hex is still there
- if verbose:
- print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
-
-
-@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
- """Get version from 'git describe' in the root of the source tree.
-
- This only gets called if the git-archive 'subst' keywords were *not*
- expanded, and _version.py hasn't already been rewritten with a short
- version string, meaning we're inside a checked out source tree.
- """
- GITS = ["git"]
- if sys.platform == "win32":
- GITS = ["git.cmd", "git.exe"]
-
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
- if rc != 0:
- if verbose:
- print("Directory %s not under git control" % root)
- raise NotThisMethod("'git rev-parse --git-dir' returned error")
-
- # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
- # if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match", "%s*" % tag_prefix],
- cwd=root)
- # --long was added in git-1.5.5
- if describe_out is None:
- raise NotThisMethod("'git describe' failed")
- describe_out = describe_out.strip()
- full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
- if full_out is None:
- raise NotThisMethod("'git rev-parse' failed")
- full_out = full_out.strip()
-
- pieces = {}
- pieces["long"] = full_out
- pieces["short"] = full_out[:7] # maybe improved later
- pieces["error"] = None
-
- # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
- # TAG might have hyphens.
- git_describe = describe_out
-
- # look for -dirty suffix
- dirty = git_describe.endswith("-dirty")
- pieces["dirty"] = dirty
- if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
-
- # now we have TAG-NUM-gHEX or HEX
-
- if "-" in git_describe:
- # TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
- if not mo:
- # unparseable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
- return pieces
-
- # tag
- full_tag = mo.group(1)
- if not full_tag.startswith(tag_prefix):
- if verbose:
- fmt = "tag '%s' doesn't start with prefix '%s'"
- print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
- return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
-
- # distance: number of commits since tag
- pieces["distance"] = int(mo.group(2))
-
- # commit: short hex revision ID
- pieces["short"] = mo.group(3)
-
- else:
- # HEX: no tags
- pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
- cwd=root)
- pieces["distance"] = int(count_out) # total number of commits
-
- # commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
- cwd=root)[0].strip()
- pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
-
- return pieces
-
-
-def plus_or_dot(pieces):
- """Return a + if we don't already have one, else return a ."""
- if "+" in pieces.get("closest-tag", ""):
- return "."
- return "+"
-
-
-def render_pep440(pieces):
- """Build up version string, with post-release "local version identifier".
-
- Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
- get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
-
- Exceptions:
- 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += plus_or_dot(pieces)
- rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_pre(pieces):
- """TAG[.post.devDISTANCE] -- No -dirty.
-
- Exceptions:
- 1: no tags. 0.post.devDISTANCE
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"]:
- rendered += ".post.dev%d" % pieces["distance"]
- else:
- # exception #1
- rendered = "0.post.dev%d" % pieces["distance"]
- return rendered
-
-
-def render_pep440_post(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX] .
-
- The ".dev0" means dirty. Note that .dev0 sorts backwards
- (a dirty tree will appear "older" than the corresponding clean one),
- but you shouldn't be releasing software with -dirty anyways.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%s" % pieces["short"]
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += "+g%s" % pieces["short"]
- return rendered
-
-
-def render_pep440_old(pieces):
- """TAG[.postDISTANCE[.dev0]] .
-
- The ".dev0" means dirty.
-
- Eexceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- return rendered
-
-
-def render_git_describe(pieces):
- """TAG[-DISTANCE-gHEX][-dirty].
-
- Like 'git describe --tags --dirty --always'.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"]:
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render_git_describe_long(pieces):
- """TAG-DISTANCE-gHEX[-dirty].
-
- Like 'git describe --tags --dirty --always -long'.
- The distance/hash is unconditional.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render(pieces, style):
- """Render the given version pieces into the requested style."""
- if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
-
- if not style or style == "default":
- style = "pep440" # the default
-
- if style == "pep440":
- rendered = render_pep440(pieces)
- elif style == "pep440-pre":
- rendered = render_pep440_pre(pieces)
- elif style == "pep440-post":
- rendered = render_pep440_post(pieces)
- elif style == "pep440-old":
- rendered = render_pep440_old(pieces)
- elif style == "git-describe":
- rendered = render_git_describe(pieces)
- elif style == "git-describe-long":
- rendered = render_git_describe_long(pieces)
- else:
- raise ValueError("unknown style '%s'" % style)
-
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
-
-
-def get_versions():
- """Get version information or return default if unable to do so."""
- # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
- # __file__, we can work backwards from there to the root. Some
- # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
- # case we can only use expanded keywords.
-
- cfg = get_config()
- verbose = cfg.verbose
-
- try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
- except NotThisMethod:
- pass
-
- try:
- root = os.path.realpath(__file__)
- # versionfile_source is the relative path from the top of the source
- # tree (where the .git directory might live) to this file. Invert
- # this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'):
- root = os.path.dirname(root)
- except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
-
- try:
- pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
- return render(pieces, cfg.style)
- except NotThisMethod:
- pass
-
- try:
- if cfg.parentdir_prefix:
- return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
- except NotThisMethod:
- pass
-
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
diff --git a/source/libraries/pg8000/scramp/core.py b/source/libraries/pg8000/scramp/core.py
deleted file mode 100644
index 20e4ea2..0000000
--- a/source/libraries/pg8000/scramp/core.py
+++ /dev/null
@@ -1,354 +0,0 @@
-import hmac
-from uuid import uuid4
-from base64 import b64encode, b64decode
-import hashlib
-from stringprep import (
- in_table_a1, in_table_b1, in_table_c21_c22, in_table_c3, in_table_c4,
- in_table_c5, in_table_c6, in_table_c7, in_table_c8, in_table_c9,
- in_table_c12, in_table_d1, in_table_d2)
-import unicodedata
-from os import urandom
-from enum import IntEnum, unique
-
-# https://tools.ietf.org/html/rfc5802
-# https://www.rfc-editor.org/rfc/rfc7677.txt
-
-
-@unique
-class ClientStage(IntEnum):
- get_client_first = 1
- set_server_first = 2
- get_client_final = 3
- set_server_final = 4
-
-
-@unique
-class ServerStage(IntEnum):
- set_client_first = 1
- get_server_first = 2
- set_client_final = 3
- get_server_final = 4
-
-
-def _check_stage(Stages, current_stage, next_stage):
- if current_stage is None:
- if next_stage != 1:
- raise ScramException(
- "The method " + Stages(1).name + " must be called first.")
- elif current_stage == 4:
- raise ScramException(
- "The authentication sequence has already finished.")
- elif next_stage != current_stage + 1:
- raise ScramException(
- "The next method to be called is " + Stages(current_stage + 1) +
- ", not this method.")
-
-
-class ScramException(Exception):
- pass
-
-
-MECHANISMS = ('SCRAM-SHA-1', 'SCRAM-SHA-256')
-
-
-HASHES = {
- 'SCRAM-SHA-1': hashlib.sha1,
- 'SCRAM-SHA-256': hashlib.sha256
-}
-
-
-class ScramClient():
- def __init__(self, mechanisms, username, password, c_nonce=None):
- self.mech = None
- for mech in MECHANISMS:
- if mech in mechanisms:
- self.mech = mech
-
- if self.mech is None:
- raise ScramException(
- "The only recognized mechanisms are " + str(MECHANISMS) +
- "and none of those can be found in " + mechanisms + ".")
-
- self.hf = HASHES[self.mech]
-
- if c_nonce is None:
- self.c_nonce = _make_nonce()
- else:
- self.c_nonce = c_nonce
-
- self.username = username
- self.password = password
- self.stage = None
-
- def _set_stage(self, next_stage):
- _check_stage(ClientStage, self.stage, next_stage)
- self.stage = next_stage
-
- def get_client_first(self):
- self._set_stage(ClientStage.get_client_first)
- self.client_first_bare, client_first = _get_client_first(
- self.username, self.c_nonce)
- return client_first
-
- def set_server_first(self, message):
- self._set_stage(ClientStage.set_server_first)
- self.server_first = message
- self.auth_message, self.nonce, self.salt, self.iterations = \
- _set_server_first(message, self.c_nonce, self.client_first_bare)
-
- def get_client_final(self):
- self._set_stage(ClientStage.get_client_final)
- self.server_signature, cfinal = _get_client_final(
- self.hf, self.password, self.salt, self.iterations, self.nonce,
- self.auth_message)
- return cfinal
-
- def set_server_final(self, message):
- self._set_stage(ClientStage.set_server_final)
- _set_server_final(message, self.server_signature)
-
-
-class ScramServer():
- def __init__(
- self, password_fn, s_nonce=None, iterations=4096, salt=None,
- mechanism='SCRAM-SHA-256'):
- if mechanism not in MECHANISMS:
- raise ScramException(
- "The only recognized mechanisms are " + str(MECHANISMS) +
- ".")
- self.mechanism = mechanism
- self.hf = HASHES[self.mechanism]
-
- if s_nonce is None:
- self.s_nonce = _make_nonce()
- else:
- self.s_nonce = s_nonce
-
- if salt is None:
- self.salt = _b64enc(urandom(16))
- else:
- self.salt = salt
-
- self.password_fn = password_fn
- self.iterations = iterations
- self.stage = None
-
- def _set_stage(self, next_stage):
- _check_stage(ServerStage, self.stage, next_stage)
- self.stage = next_stage
-
- def set_client_first(self, client_first):
- self._set_stage(ServerStage.set_client_first)
- self.nonce, self.user, self.client_first_bare = _set_client_first(
- client_first, self.s_nonce)
- self.password = self.password_fn(self.user)
-
- def get_server_first(self):
- self._set_stage(ServerStage.get_server_first)
- self.auth_message, server_first = _get_server_first(
- self.nonce, self.salt, self.iterations, self.client_first_bare)
- return server_first
-
- def set_client_final(self, client_final):
- self._set_stage(ServerStage.set_client_final)
- self.server_signature = _set_client_final(
- self.hf, client_final, self.s_nonce, self.password, self.salt,
- self.iterations, self.auth_message)
-
- def get_server_final(self):
- self._set_stage(ServerStage.get_server_final)
- return _get_server_final(self.server_signature)
-
-
-def _make_nonce():
- return str(uuid4()).replace('-', '')
-
-
-def _make_auth_message(nonce, client_first_bare, server_first):
- msg = client_first_bare, server_first, 'c=' + _b64enc(b'n,,'), 'r=' + nonce
- return ','.join(msg)
-
-
-def _proof_signature(hf, password, salt, iterations, auth_msg):
- salted_password = _hi(
- hf, _uenc(saslprep(password)), _b64dec(salt), iterations)
- client_key = _hmac(hf, salted_password, b"Client Key")
- stored_key = _h(hf, client_key)
-
- client_signature = _hmac(hf, stored_key, _uenc(auth_msg))
- client_proof = _xor(client_key, client_signature)
-
- server_key = _hmac(hf, salted_password, b"Server Key")
- server_signature = _hmac(hf, server_key, _uenc(auth_msg))
- return _b64enc(client_proof), _b64enc(server_signature)
-
-
-def _hmac(hf, key, msg):
- return hmac.new(key, msg=msg, digestmod=hf).digest()
-
-
-def _h(hf, msg):
- return hf(msg).digest()
-
-
-def _hi(hf, password, salt, iterations):
- u = ui = _hmac(hf, password, salt + b'\x00\x00\x00\x01')
- for i in range(iterations - 1):
- ui = _hmac(hf, password, ui)
- u = _xor(u, ui)
- return u
-
-
-def _hi_iter(password, mac, iterations):
- if iterations == 0:
- return mac
- else:
- new_mac = _hmac(password, mac)
- return _xor(_hi_iter(password, new_mac, iterations-1), mac)
-
-
-def _parse_message(msg):
- return dict((e[0], e[2:]) for e in msg.split(',') if len(e) > 1)
-
-
-def _b64enc(binary):
- return b64encode(binary).decode('utf8')
-
-
-def _b64dec(string):
- return b64decode(string)
-
-
-def _uenc(string):
- return string.encode('utf-8')
-
-
-def _xor(bytes1, bytes2):
- return bytes(a ^ b for a, b in zip(bytes1, bytes2))
-
-
-def _get_client_first(username, c_nonce):
- bare = ','.join(('n=' + saslprep(username), 'r=' + c_nonce))
- return bare, 'n,,' + bare
-
-
-def _set_client_first(client_first, s_nonce):
- msg = _parse_message(client_first)
- c_nonce = msg['r']
- nonce = c_nonce + s_nonce
- user = msg['n']
- client_first_bare = client_first[3:]
-
- return nonce, user, client_first_bare
-
-
-def _get_server_first(nonce, salt, iterations, client_first_bare):
- sfirst = ','.join(('r=' + nonce, 's=' + salt, 'i=' + str(iterations)))
- auth_msg = _make_auth_message(nonce, client_first_bare, sfirst)
- return auth_msg, sfirst
-
-
-def _set_server_first(server_first, c_nonce, client_first_bare):
- msg = _parse_message(server_first)
- nonce = msg['r']
- salt = msg['s']
- iterations = int(msg['i'])
-
- if not nonce.startswith(c_nonce):
- raise ScramException("Client nonce doesn't match.")
-
- auth_msg = _make_auth_message(nonce, client_first_bare, server_first)
- return auth_msg, nonce, salt, iterations
-
-
-def _get_client_final(hf, password, salt, iterations, nonce, auth_msg):
- client_proof, server_signature = _proof_signature(
- hf, password, salt, iterations, auth_msg)
-
- message = ['c=' + _b64enc(b'n,,'), 'r=' + nonce, 'p=' + client_proof]
- return server_signature, ','.join(message)
-
-
-def _set_client_final(
- hf, client_final, s_nonce, password, salt, iterations, auth_msg):
-
- msg = _parse_message(client_final)
- nonce = msg['r']
- proof = msg['p']
-
- if not nonce.endswith(s_nonce):
- raise ScramException("Server nonce doesn't match.")
-
- client_proof, server_signature = _proof_signature(
- hf, password, salt, iterations, auth_msg)
-
- if client_proof != proof:
- raise ScramException("The proofs don't match")
-
- return server_signature
-
-
-def _get_server_final(server_signature):
- return 'v=' + server_signature
-
-
-def _set_server_final(message, server_signature):
- msg = _parse_message(message)
- if server_signature != msg['v']:
- raise ScramException("The server signature doesn't match.")
-
-
-def saslprep(source):
- # mapping stage
- # - map non-ascii spaces to U+0020 (stringprep C.1.2)
- # - strip 'commonly mapped to nothing' chars (stringprep B.1)
- data = ''.join(
- ' ' if in_table_c12(c) else c for c in source if not in_table_b1(c))
-
- # normalize to KC form
- data = unicodedata.normalize('NFKC', data)
- if not data:
- return ''
-
- # check for invalid bi-directional strings.
- # stringprep requires the following:
- # - chars in C.8 must be prohibited.
- # - if any R/AL chars in string:
- # - no L chars allowed in string
- # - first and last must be R/AL chars
- # this checks if start/end are R/AL chars. if so, prohibited loop
- # will forbid all L chars. if not, prohibited loop will forbid all
- # R/AL chars instead. in both cases, prohibited loop takes care of C.8.
- is_ral_char = in_table_d1
- if is_ral_char(data[0]):
- if not is_ral_char(data[-1]):
- raise ValueError("malformed bidi sequence")
- # forbid L chars within R/AL sequence.
- is_forbidden_bidi_char = in_table_d2
- else:
- # forbid R/AL chars if start not setup correctly; L chars allowed.
- is_forbidden_bidi_char = is_ral_char
-
- # check for prohibited output
- # stringprep tables A.1, B.1, C.1.2, C.2 - C.9
- for c in data:
- # check for chars mapping stage should have removed
- assert not in_table_b1(c), "failed to strip B.1 in mapping stage"
- assert not in_table_c12(c), "failed to replace C.1.2 in mapping stage"
-
- # check for forbidden chars
- for f, msg in (
- (in_table_a1, "unassigned code points forbidden"),
- (in_table_c21_c22, "control characters forbidden"),
- (in_table_c3, "private use characters forbidden"),
- (in_table_c4, "non-char code points forbidden"),
- (in_table_c5, "surrogate codes forbidden"),
- (in_table_c6, "non-plaintext chars forbidden"),
- (in_table_c7, "non-canonical chars forbidden"),
- (in_table_c8, "display-modifying/deprecated chars forbidden"),
- (in_table_c9, "tagged characters forbidden"),
- (is_forbidden_bidi_char, "forbidden bidi character")):
- if f(c):
- raise ValueError(msg)
-
- return data
diff --git a/source/zaz.py b/source/zaz.py
index 71fe021..dee3df0 100644
--- a/source/zaz.py
+++ b/source/zaz.py
@@ -44,6 +44,15 @@ from conf import (
class LiboXML(object):
+ CONTEXT = {
+ 'calc': 'com.sun.star.sheet.SpreadsheetDocument',
+ 'writer': 'com.sun.star.text.TextDocument',
+ 'impress': 'com.sun.star.presentation.PresentationDocument',
+ 'draw': 'com.sun.star.drawing.DrawingDocument',
+ 'base': 'com.sun.star.sdb.OfficeDatabaseDocument',
+ 'math': 'com.sun.star.formula.FormulaProperties',
+ 'basic': 'com.sun.star.script.BasicIDE',
+ }
TYPES = {
'py': 'application/vnd.sun.star.uno-component;type=Python',
'zip': 'application/binary',
@@ -63,10 +72,21 @@ class LiboXML(object):
'xmlns:xlink': 'http://www.w3.org/1999/xlink',
'xmlns:d': 'http://openoffice.org/extensions/description/2006',
}
+ NS_ADDONS = {
+ 'xmlns:xs': 'http://www.w3.org/2001/XMLSchema',
+ 'xmlns:oor': 'http://openoffice.org/2001/registry',
+ }
+ NS_UPDATE = {
+ 'xmlns': 'http://openoffice.org/extensions/update/2006',
+ 'xmlns:d': 'http://openoffice.org/extensions/description/2006',
+ 'xmlns:xlink': 'http://www.w3.org/1999/xlink',
+ }
def __init__(self):
self._manifest = None
self._paths = []
+ self._path_images = ''
+ self._toolbars = []
def _save_path(self, attr):
self._paths.append(attr['{{{}}}full-path'.format(self.NS_MANIFEST['manifest'])])
@@ -173,6 +193,158 @@ class LiboXML(object):
}
ET.SubElement(node, 'license-text', attr)
+ if data['update']:
+ node = ET.SubElement(doc, 'update-information')
+ ET.SubElement(node, 'src', {'xlink:href': data['update']})
+
+ return self._get_xml(doc)
+
+ def _get_context(self, args):
+ if not args:
+ return ''
+ context = ','.join([self.CONTEXT[v] for v in args.split(',')])
+ return context
+
+ def _add_node_value(self, node, name, value='_self'):
+ attr = {'oor:name': name, 'oor:type': 'xs:string'}
+ sn = ET.SubElement(node, 'prop', attr)
+ sn = ET.SubElement(sn, 'value')
+ sn.text = value
+ return
+
+ def _add_menu(self, id_extension, node, index, menu):
+ attr = {
+ 'oor:name': index,
+ 'oor:op': 'replace',
+ }
+ subnode = ET.SubElement(node, 'node', attr)
+ attr = {'oor:name': 'Title', 'oor:type': 'xs:string'}
+ sn1 = ET.SubElement(subnode, 'prop', attr)
+ for k, v in menu['title'].items():
+ sn2 = ET.SubElement(sn1, 'value', {'xml:lang': k})
+ sn2.text = v
+ value = self._get_context(menu['context'])
+ self._add_node_value(subnode, 'Context', value)
+
+ if 'submenu' in menu:
+ sn = ET.SubElement(subnode, 'node', {'oor:name': 'Submenu'})
+ for i, m in enumerate(menu['submenu']):
+ self._add_menu(id_extension, sn, f'{index}.s{i}', m)
+ if m.get('toolbar', False):
+ self._toolbars.append(m)
+ return
+
+ value = f"service:{id_extension}?{menu['argument']}"
+ self._add_node_value(subnode, 'URL', value)
+ self._add_node_value(subnode, 'Target')
+ value = f"%origin%/{self._path_images}/{menu['icon']}"
+ self._add_node_value(subnode, 'ImageIdentifier', value)
+ return
+
+ def new_addons(self, id_extension, data):
+ self._path_images = data['images']
+ attr = {
+ 'oor:name': 'Addons',
+ 'oor:package': 'org.openoffice.Office',
+ }
+ attr.update(self.NS_ADDONS)
+ doc = ET.Element('oor:component-data', attr)
+ parent = ET.SubElement(doc, 'node', {'oor:name': 'AddonUI'})
+ node = ET.SubElement(parent, 'node', {'oor:name': data['parent']})
+
+ op = 'fuse'
+ if data['parent'] == 'OfficeMenuBar':
+ op = 'replace'
+
+ attr = {'oor:name': id_extension, 'oor:op': op}
+ node = ET.SubElement(node, 'node', attr)
+
+ if data['parent'] == 'OfficeMenuBar':
+ attr = {'oor:name': 'Title', 'oor:type': 'xs:string'}
+ subnode = ET.SubElement(node, 'prop', attr)
+ for k, v in data['main'].items():
+ sn = ET.SubElement(subnode, 'value', {'xml:lang': k})
+ sn.text = v
+
+ self._add_node_value(node, 'Target')
+ node = ET.SubElement(node, 'node', {'oor:name': 'Submenu'})
+
+ for i, menu in enumerate(data['menus']):
+ self._add_menu(id_extension, node, f'm{i}', menu)
+ if menu.get('toolbar', False):
+ self._toolbars.append(menu)
+
+ if self._toolbars:
+ attr = {'oor:name': 'OfficeToolBar'}
+ toolbar = ET.SubElement(parent, 'node', attr)
+ attr = {'oor:name': id_extension, 'oor:op': 'replace'}
+ toolbar = ET.SubElement(toolbar, 'node', attr)
+ for t, menu in enumerate(self._toolbars):
+ self._add_menu(id_extension, toolbar, f't{t}', menu)
+
+ return self._get_xml(doc)
+
+ def _add_shortcut(self, node, key, id_extension, arg):
+ attr = {'oor:name': key, 'oor:op': 'fuse'}
+ subnode = ET.SubElement(node, 'node', attr)
+ subnode = ET.SubElement(subnode, 'prop', {'oor:name': 'Command'})
+ subnode = ET.SubElement(subnode, 'value', {'xml:lang': 'en-US'})
+ subnode.text = f"service:{id_extension}?{arg}"
+ return
+
+ def _get_acceleartors(self, menu):
+ if 'submenu' in menu:
+ for m in menu['submenu']:
+ return self._get_acceleartors(m)
+
+ if not menu.get('shortcut', ''):
+ return ''
+
+ return menu
+
+ def new_accelerators(self, id_extension, menus):
+ attr = {
+ 'oor:name': 'Accelerators',
+ 'oor:package': 'org.openoffice.Office',
+ }
+ attr.update(self.NS_ADDONS)
+ doc = ET.Element('oor:component-data', attr)
+ parent = ET.SubElement(doc, 'node', {'oor:name': 'PrimaryKeys'})
+
+ data = []
+ for m in menus:
+ info = self._get_acceleartors(m)
+ if info:
+ data.append(info)
+
+ node_global = None
+ node_modules = None
+ for m in data:
+ if m['context']:
+ if node_modules is None:
+ node_modules = ET.SubElement(
+ parent, 'node', {'oor:name': 'Modules'})
+ for app in m['context'].split(','):
+ node = ET.SubElement(
+ node_modules, 'node', {'oor:name': self.CONTEXT[app]})
+ self._add_shortcut(
+ node, m['shortcut'], id_extension, m['argument'])
+ else:
+ if node_global is None:
+ node_global = ET.SubElement(
+ parent, 'node', {'oor:name': 'Global'})
+ self._add_shortcut(
+ node_global, m['shortcut'], id_extension, m['argument'])
+
+ return self._get_xml(doc)
+
+ def new_update(self, extension, url_oxt):
+ doc = ET.Element('description', self.NS_UPDATE)
+ ET.SubElement(doc, 'identifier', {'value': extension['id']})
+ ET.SubElement(doc, 'version', {'value': extension['version']})
+ node = ET.SubElement(doc, 'update-download')
+ ET.SubElement(node, 'src', {'xlink:href': url_oxt})
+ node = ET.SubElement(doc, 'release-notes')
return self._get_xml(doc)
def _get_xml(self, doc):
@@ -214,11 +386,7 @@ def _get_files(path, filters=''):
def _compress_oxt():
log.info('Compress OXT extension...')
- path = DIRS['files']
- if not _exists(path):
- _mkdir(path)
-
- path_oxt = _join(path, FILES['oxt'])
+ path_oxt = _join(DIRS['files'], FILES['oxt'])
z = zipfile.ZipFile(path_oxt, 'w', compression=zipfile.ZIP_DEFLATED)
root_len = len(os.path.abspath(DIRS['source']))
@@ -232,10 +400,6 @@ def _compress_oxt():
z.write(fullpath, file_name, zipfile.ZIP_DEFLATED)
z.close()
- if DATA['update']:
- path_xml = _join(path, FILES['update'])
- _save(path_xml, DATA['update'])
-
log.info('Extension OXT created sucesfully...')
return
@@ -353,6 +517,10 @@ def _compile_idl():
def _update_files():
+ path_files = DIRS['files']
+ if not _exists(path_files):
+ _mkdir(path_files)
+
path_source = DIRS['source']
for k, v in INFO.items():
@@ -388,27 +556,33 @@ def _update_files():
data = xml.new_description(DATA['description'])
_save(path, data)
+ if TYPE_EXTENSION == 1:
+ path = _join(path_source, FILES['addons'])
+ data = xml.new_addons(EXTENSION['id'], DATA['addons'])
+ _save(path, data)
+ path = _join(path_source, DIRS['office'])
+ _mkdir(path)
+ path = _join(path_source, DIRS['office'], FILES['shortcut'])
+ data = xml.new_accelerators(EXTENSION['id'], DATA['addons']['menus'])
+ _save(path, data)
- path = _join(path_source, DIRS['office'])
- _mkdir(path)
- path = _join(path_source, DIRS['office'], FILES['shortcut'])
- _save(path, DATA['shortcut'])
-
- path = _join(path_source, FILES['addons'])
- _save(path, DATA['addons'])
-
if TYPE_EXTENSION == 3:
path = _join(path_source, FILES['addin'])
_save(path, DATA['addin'])
if USE_LOCALES:
msg = "Don't forget generate DOMAIN.pot for locales"
- log.info(msg)
for lang in EXTENSION['languages']:
path = _join(path_source, DIRS['locales'], lang, 'LC_MESSAGES')
Path(path).mkdir(parents=True, exist_ok=True)
+ log.info(msg)
+
+ if DATA['update']:
+ path_xml = _join(path_files, FILES['update'])
+ data = xml.new_update(EXTENSION, DATA['update'])
+ _save(path_xml, data)
_compile_idl()
return