from datetime import ( timedelta as Timedelta, datetime as Datetime, date, time) from warnings import warn import socket from struct import pack from hashlib import md5 from decimal import Decimal from collections import deque, defaultdict from itertools import count, islice from uuid import UUID from copy import deepcopy from calendar import timegm from distutils.version import LooseVersion from struct import Struct from time import localtime import pg8000 from json import loads, dumps from os import getpid from .scramp import ScramClient import enum from ipaddress import ( ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network) from datetime import timezone as Timezone # Copyright (c) 2007-2009, Mathieu Fenniak # Copyright (c) The Contributors # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # * The name of the author may not be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. __author__ = "Mathieu Fenniak" ZERO = Timedelta(0) BINARY = bytes class Interval(): """An Interval represents a measurement of time. In PostgreSQL, an interval is defined in the measure of months, days, and microseconds; as such, the pg8000 interval type represents the same information. Note that values of the :attr:`microseconds`, :attr:`days` and :attr:`months` properties are independently measured and cannot be converted to each other. A month may be 28, 29, 30, or 31 days, and a day may occasionally be lengthened slightly by a leap second. .. attribute:: microseconds Measure of microseconds in the interval. The microseconds value is constrained to fit into a signed 64-bit integer. Any attempt to set a value too large or too small will result in an OverflowError being raised. .. attribute:: days Measure of days in the interval. The days value is constrained to fit into a signed 32-bit integer. Any attempt to set a value too large or too small will result in an OverflowError being raised. .. attribute:: months Measure of months in the interval. The months value is constrained to fit into a signed 32-bit integer. Any attempt to set a value too large or too small will result in an OverflowError being raised. """ def __init__(self, microseconds=0, days=0, months=0): self.microseconds = microseconds self.days = days self.months = months def _setMicroseconds(self, value): if not isinstance(value, int): raise TypeError("microseconds must be an integer type") elif not (min_int8 < value < max_int8): raise OverflowError( "microseconds must be representable as a 64-bit integer") else: self._microseconds = value def _setDays(self, value): if not isinstance(value, int): raise TypeError("days must be an integer type") elif not (min_int4 < value < max_int4): raise OverflowError( "days must be representable as a 32-bit integer") else: self._days = value def _setMonths(self, value): if not isinstance(value, int): raise TypeError("months must be an integer type") elif not (min_int4 < value < max_int4): raise OverflowError( "months must be representable as a 32-bit integer") else: self._months = value microseconds = property(lambda self: self._microseconds, _setMicroseconds) days = property(lambda self: self._days, _setDays) months = property(lambda self: self._months, _setMonths) def __repr__(self): return "" % ( self.months, self.days, self.microseconds) def __eq__(self, other): return other is not None and isinstance(other, Interval) and \ self.months == other.months and self.days == other.days and \ self.microseconds == other.microseconds def __neq__(self, other): return not self.__eq__(other) class PGType(): def __init__(self, value): self.value = value def encode(self, encoding): return str(self.value).encode(encoding) class PGEnum(PGType): def __init__(self, value): if isinstance(value, str): self.value = value else: self.value = value.value class PGJson(PGType): def encode(self, encoding): return dumps(self.value).encode(encoding) class PGJsonb(PGType): def encode(self, encoding): return dumps(self.value).encode(encoding) class PGTsvector(PGType): pass class PGVarchar(str): pass class PGText(str): pass def pack_funcs(fmt): struc = Struct('!' + fmt) return struc.pack, struc.unpack_from i_pack, i_unpack = pack_funcs('i') h_pack, h_unpack = pack_funcs('h') q_pack, q_unpack = pack_funcs('q') d_pack, d_unpack = pack_funcs('d') f_pack, f_unpack = pack_funcs('f') iii_pack, iii_unpack = pack_funcs('iii') ii_pack, ii_unpack = pack_funcs('ii') qii_pack, qii_unpack = pack_funcs('qii') dii_pack, dii_unpack = pack_funcs('dii') ihihih_pack, ihihih_unpack = pack_funcs('ihihih') ci_pack, ci_unpack = pack_funcs('ci') bh_pack, bh_unpack = pack_funcs('bh') cccc_pack, cccc_unpack = pack_funcs('cccc') min_int2, max_int2 = -2 ** 15, 2 ** 15 min_int4, max_int4 = -2 ** 31, 2 ** 31 min_int8, max_int8 = -2 ** 63, 2 ** 63 class Warning(Exception): """Generic exception raised for important database warnings like data truncations. This exception is not currently used by pg8000. This exception is part of the `DBAPI 2.0 specification `_. """ pass class Error(Exception): """Generic exception that is the base exception of all other error exceptions. This exception is part of the `DBAPI 2.0 specification `_. """ pass class InterfaceError(Error): """Generic exception raised for errors that are related to the database interface rather than the database itself. For example, if the interface attempts to use an SSL connection but the server refuses, an InterfaceError will be raised. This exception is part of the `DBAPI 2.0 specification `_. """ pass class DatabaseError(Error): """Generic exception raised for errors that are related to the database. This exception is currently never raised by pg8000. This exception is part of the `DBAPI 2.0 specification `_. """ pass class DataError(DatabaseError): """Generic exception raised for errors that are due to problems with the processed data. This exception is not currently raised by pg8000. This exception is part of the `DBAPI 2.0 specification `_. """ pass class OperationalError(DatabaseError): """ Generic exception raised for errors that are related to the database's operation and not necessarily under the control of the programmer. This exception is currently never raised by pg8000. This exception is part of the `DBAPI 2.0 specification `_. """ pass class IntegrityError(DatabaseError): """ Generic exception raised when the relational integrity of the database is affected. This exception is not currently raised by pg8000. This exception is part of the `DBAPI 2.0 specification `_. """ pass class InternalError(DatabaseError): """Generic exception raised when the database encounters an internal error. This is currently only raised when unexpected state occurs in the pg8000 interface itself, and is typically the result of a interface bug. This exception is part of the `DBAPI 2.0 specification `_. """ pass class ProgrammingError(DatabaseError): """Generic exception raised for programming errors. For example, this exception is raised if more parameter fields are in a query string than there are available parameters. This exception is part of the `DBAPI 2.0 specification `_. """ pass class NotSupportedError(DatabaseError): """Generic exception raised in case a method or database API was used which is not supported by the database. This exception is part of the `DBAPI 2.0 specification `_. """ pass class ArrayContentNotSupportedError(NotSupportedError): """ Raised when attempting to transmit an array where the base type is not supported for binary data transfer by the interface. """ pass class ArrayContentNotHomogenousError(ProgrammingError): """ Raised when attempting to transmit an array that doesn't contain only a single type of object. """ pass class ArrayDimensionsNotConsistentError(ProgrammingError): """ Raised when attempting to transmit an array that has inconsistent multi-dimension sizes. """ pass def Date(year, month, day): """Constuct an object holding a date value. This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.date` """ return date(year, month, day) def Time(hour, minute, second): """Construct an object holding a time value. This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.time` """ return time(hour, minute, second) def Timestamp(year, month, day, hour, minute, second): """Construct an object holding a timestamp value. This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.datetime` """ return Datetime(year, month, day, hour, minute, second) def DateFromTicks(ticks): """Construct an object holding a date value from the given ticks value (number of seconds since the epoch). This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.date` """ return Date(*localtime(ticks)[:3]) def TimeFromTicks(ticks): """Construct an objet holding a time value from the given ticks value (number of seconds since the epoch). This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.time` """ return Time(*localtime(ticks)[3:6]) def TimestampFromTicks(ticks): """Construct an object holding a timestamp value from the given ticks value (number of seconds since the epoch). This function is part of the `DBAPI 2.0 specification `_. :rtype: :class:`datetime.datetime` """ return Timestamp(*localtime(ticks)[:6]) def Binary(value): """Construct an object holding binary data. This function is part of the `DBAPI 2.0 specification `_. """ return value FC_TEXT = 0 FC_BINARY = 1 def convert_paramstyle(style, query): # I don't see any way to avoid scanning the query string char by char, # so we might as well take that careful approach and create a # state-based scanner. We'll use int variables for the state. OUTSIDE = 0 # outside quoted string INSIDE_SQ = 1 # inside single-quote string '...' INSIDE_QI = 2 # inside quoted identifier "..." INSIDE_ES = 3 # inside escaped single-quote string, E'...' INSIDE_PN = 4 # inside parameter name eg. :name INSIDE_CO = 5 # inside inline comment eg. -- in_quote_escape = False in_param_escape = False placeholders = [] output_query = [] param_idx = map(lambda x: "$" + str(x), count(1)) state = OUTSIDE prev_c = None for i, c in enumerate(query): if i + 1 < len(query): next_c = query[i + 1] else: next_c = None if state == OUTSIDE: if c == "'": output_query.append(c) if prev_c == 'E': state = INSIDE_ES else: state = INSIDE_SQ elif c == '"': output_query.append(c) state = INSIDE_QI elif c == '-': output_query.append(c) if prev_c == '-': state = INSIDE_CO elif style == "qmark" and c == "?": output_query.append(next(param_idx)) elif style == "numeric" and c == ":" and next_c not in ':=' \ and prev_c != ':': # Treat : as beginning of parameter name if and only # if it's the only : around # Needed to properly process type conversions # i.e. sum(x)::float output_query.append("$") elif style == "named" and c == ":" and next_c not in ':=' \ and prev_c != ':': # Same logic for : as in numeric parameters state = INSIDE_PN placeholders.append('') elif style == "pyformat" and c == '%' and next_c == "(": state = INSIDE_PN placeholders.append('') elif style in ("format", "pyformat") and c == "%": style = "format" if in_param_escape: in_param_escape = False output_query.append(c) else: if next_c == "%": in_param_escape = True elif next_c == "s": state = INSIDE_PN output_query.append(next(param_idx)) else: raise InterfaceError( "Only %s and %% are supported in the query.") else: output_query.append(c) elif state == INSIDE_SQ: if c == "'": if in_quote_escape: in_quote_escape = False else: if next_c == "'": in_quote_escape = True else: state = OUTSIDE output_query.append(c) elif state == INSIDE_QI: if c == '"': state = OUTSIDE output_query.append(c) elif state == INSIDE_ES: if c == "'" and prev_c != "\\": # check for escaped single-quote state = OUTSIDE output_query.append(c) elif state == INSIDE_PN: if style == 'named': placeholders[-1] += c if next_c is None or (not next_c.isalnum() and next_c != '_'): state = OUTSIDE try: pidx = placeholders.index(placeholders[-1], 0, -1) output_query.append("$" + str(pidx + 1)) del placeholders[-1] except ValueError: output_query.append("$" + str(len(placeholders))) elif style == 'pyformat': if prev_c == ')' and c == "s": state = OUTSIDE try: pidx = placeholders.index(placeholders[-1], 0, -1) output_query.append("$" + str(pidx + 1)) del placeholders[-1] except ValueError: output_query.append("$" + str(len(placeholders))) elif c in "()": pass else: placeholders[-1] += c elif style == 'format': state = OUTSIDE elif state == INSIDE_CO: output_query.append(c) if c == '\n': state = OUTSIDE prev_c = c if style in ('numeric', 'qmark', 'format'): def make_args(vals): return vals else: def make_args(vals): return tuple(vals[p] for p in placeholders) return ''.join(output_query), make_args EPOCH = Datetime(2000, 1, 1) EPOCH_TZ = EPOCH.replace(tzinfo=Timezone.utc) EPOCH_SECONDS = timegm(EPOCH.timetuple()) INFINITY_MICROSECONDS = 2 ** 63 - 1 MINUS_INFINITY_MICROSECONDS = -1 * INFINITY_MICROSECONDS - 1 # data is 64-bit integer representing microseconds since 2000-01-01 def timestamp_recv_integer(data, offset, length): micros = q_unpack(data, offset)[0] try: return EPOCH + Timedelta(microseconds=micros) except OverflowError: if micros == INFINITY_MICROSECONDS: return 'infinity' elif micros == MINUS_INFINITY_MICROSECONDS: return '-infinity' else: return micros # data is double-precision float representing seconds since 2000-01-01 def timestamp_recv_float(data, offset, length): return Datetime.utcfromtimestamp(EPOCH_SECONDS + d_unpack(data, offset)[0]) # data is 64-bit integer representing microseconds since 2000-01-01 def timestamp_send_integer(v): return q_pack( int((timegm(v.timetuple()) - EPOCH_SECONDS) * 1e6) + v.microsecond) # data is double-precision float representing seconds since 2000-01-01 def timestamp_send_float(v): return d_pack(timegm(v.timetuple()) + v.microsecond / 1e6 - EPOCH_SECONDS) def timestamptz_send_integer(v): # timestamps should be sent as UTC. If they have zone info, # convert them. return timestamp_send_integer( v.astimezone(Timezone.utc).replace(tzinfo=None)) def timestamptz_send_float(v): # timestamps should be sent as UTC. If they have zone info, # convert them. return timestamp_send_float( v.astimezone(Timezone.utc).replace(tzinfo=None)) # return a timezone-aware datetime instance if we're reading from a # "timestamp with timezone" type. The timezone returned will always be # UTC, but providing that additional information can permit conversion # to local. def timestamptz_recv_integer(data, offset, length): micros = q_unpack(data, offset)[0] try: return EPOCH_TZ + Timedelta(microseconds=micros) except OverflowError: if micros == INFINITY_MICROSECONDS: return 'infinity' elif micros == MINUS_INFINITY_MICROSECONDS: return '-infinity' else: return micros def timestamptz_recv_float(data, offset, length): return timestamp_recv_float(data, offset, length).replace( tzinfo=Timezone.utc) def interval_send_integer(v): microseconds = v.microseconds try: microseconds += int(v.seconds * 1e6) except AttributeError: pass try: months = v.months except AttributeError: months = 0 return qii_pack(microseconds, v.days, months) def interval_send_float(v): seconds = v.microseconds / 1000.0 / 1000.0 try: seconds += v.seconds except AttributeError: pass try: months = v.months except AttributeError: months = 0 return dii_pack(seconds, v.days, months) def interval_recv_integer(data, offset, length): microseconds, days, months = qii_unpack(data, offset) if months == 0: seconds, micros = divmod(microseconds, 1e6) return Timedelta(days, seconds, micros) else: return Interval(microseconds, days, months) def interval_recv_float(data, offset, length): seconds, days, months = dii_unpack(data, offset) if months == 0: secs, microseconds = divmod(seconds, 1e6) return Timedelta(days, secs, microseconds) else: return Interval(int(seconds * 1000 * 1000), days, months) def int8_recv(data, offset, length): return q_unpack(data, offset)[0] def int2_recv(data, offset, length): return h_unpack(data, offset)[0] def int4_recv(data, offset, length): return i_unpack(data, offset)[0] def float4_recv(data, offset, length): return f_unpack(data, offset)[0] def float8_recv(data, offset, length): return d_unpack(data, offset)[0] def bytea_send(v): return v # bytea def bytea_recv(data, offset, length): return data[offset:offset + length] def uuid_send(v): return v.bytes def uuid_recv(data, offset, length): return UUID(bytes=data[offset:offset+length]) def bool_send(v): return b"\x01" if v else b"\x00" NULL = i_pack(-1) NULL_BYTE = b'\x00' def null_send(v): return NULL def int_in(data, offset, length): return int(data[offset: offset + length]) class Cursor(): """A cursor object is returned by the :meth:`~Connection.cursor` method of a connection. It has the following attributes and methods: .. attribute:: arraysize This read/write attribute specifies the number of rows to fetch at a time with :meth:`fetchmany`. It defaults to 1. .. attribute:: connection This read-only attribute contains a reference to the connection object (an instance of :class:`Connection`) on which the cursor was created. This attribute is part of a DBAPI 2.0 extension. Accessing this attribute will generate the following warning: ``DB-API extension cursor.connection used``. .. attribute:: rowcount This read-only attribute contains the number of rows that the last ``execute()`` or ``executemany()`` method produced (for query statements like ``SELECT``) or affected (for modification statements like ``UPDATE``). The value is -1 if: - No ``execute()`` or ``executemany()`` method has been performed yet on the cursor. - There was no rowcount associated with the last ``execute()``. - At least one of the statements executed as part of an ``executemany()`` had no row count associated with it. - Using a ``SELECT`` query statement on PostgreSQL server older than version 9. - Using a ``COPY`` query statement on PostgreSQL server version 8.1 or older. This attribute is part of the `DBAPI 2.0 specification `_. .. attribute:: description This read-only attribute is a sequence of 7-item sequences. Each value contains information describing one result column. The 7 items returned for each column are (name, type_code, display_size, internal_size, precision, scale, null_ok). Only the first two values are provided by the current implementation. This attribute is part of the `DBAPI 2.0 specification `_. """ def __init__(self, connection): self._c = connection self.arraysize = 1 self.ps = None self._row_count = -1 self._cached_rows = deque() def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() @property def connection(self): warn("DB-API extension cursor.connection used", stacklevel=3) return self._c @property def rowcount(self): return self._row_count description = property(lambda self: self._getDescription()) def _getDescription(self): if self.ps is None: return None row_desc = self.ps['row_desc'] if len(row_desc) == 0: return None columns = [] for col in row_desc: columns.append( (col["name"], col["type_oid"], None, None, None, None, None)) return columns ## # Executes a database operation. Parameters may be provided as a sequence # or mapping and will be bound to variables in the operation. #

# Stability: Part of the DBAPI 2.0 specification. def execute(self, operation, args=None, stream=None): """Executes a database operation. Parameters may be provided as a sequence, or as a mapping, depending upon the value of :data:`pg8000.paramstyle`. This method is part of the `DBAPI 2.0 specification `_. :param operation: The SQL statement to execute. :param args: If :data:`paramstyle` is ``qmark``, ``numeric``, or ``format``, this argument should be an array of parameters to bind into the statement. If :data:`paramstyle` is ``named``, the argument should be a dict mapping of parameters. If the :data:`paramstyle` is ``pyformat``, the argument value may be either an array or a mapping. :param stream: This is a pg8000 extension for use with the PostgreSQL `COPY `_ command. For a COPY FROM the parameter must be a readable file-like object, and for COPY TO it must be writable. .. versionadded:: 1.9.11 """ try: self.stream = stream if not self._c.in_transaction and not self._c.autocommit: self._c.execute(self, "begin transaction", None) self._c.execute(self, operation, args) except AttributeError as e: if self._c is None: raise InterfaceError("Cursor closed") elif self._c._sock is None: raise InterfaceError("connection is closed") else: raise e return self def executemany(self, operation, param_sets): """Prepare a database operation, and then execute it against all parameter sequences or mappings provided. This method is part of the `DBAPI 2.0 specification `_. :param operation: The SQL statement to execute :param parameter_sets: A sequence of parameters to execute the statement with. The values in the sequence should be sequences or mappings of parameters, the same as the args argument of the :meth:`execute` method. """ rowcounts = [] for parameters in param_sets: self.execute(operation, parameters) rowcounts.append(self._row_count) self._row_count = -1 if -1 in rowcounts else sum(rowcounts) return self def fetchone(self): """Fetch the next row of a query result set. This method is part of the `DBAPI 2.0 specification `_. :returns: A row as a sequence of field values, or ``None`` if no more rows are available. """ try: return next(self) except StopIteration: return None except TypeError: raise ProgrammingError("attempting to use unexecuted cursor") except AttributeError: raise ProgrammingError("attempting to use unexecuted cursor") def fetchmany(self, num=None): """Fetches the next set of rows of a query result. This method is part of the `DBAPI 2.0 specification `_. :param size: The number of rows to fetch when called. If not provided, the :attr:`arraysize` attribute value is used instead. :returns: A sequence, each entry of which is a sequence of field values making up a row. If no more rows are available, an empty sequence will be returned. """ try: return tuple( islice(self, self.arraysize if num is None else num)) except TypeError: raise ProgrammingError("attempting to use unexecuted cursor") def fetchall(self): """Fetches all remaining rows of a query result. This method is part of the `DBAPI 2.0 specification `_. :returns: A sequence, each entry of which is a sequence of field values making up a row. """ try: return tuple(self) except TypeError: raise ProgrammingError("attempting to use unexecuted cursor") def close(self): """Closes the cursor. This method is part of the `DBAPI 2.0 specification `_. """ self._c = None def __iter__(self): """A cursor object is iterable to retrieve the rows from a query. This is a DBAPI 2.0 extension. """ return self def setinputsizes(self, sizes): """This method is part of the `DBAPI 2.0 specification `_, however, it is not implemented by pg8000. """ pass def setoutputsize(self, size, column=None): """This method is part of the `DBAPI 2.0 specification `_, however, it is not implemented by pg8000. """ pass def __next__(self): try: return self._cached_rows.popleft() except IndexError: if self.ps is None: raise ProgrammingError("A query hasn't been issued.") elif len(self.ps['row_desc']) == 0: raise ProgrammingError("no result set") else: raise StopIteration() # Message codes NOTICE_RESPONSE = b"N" AUTHENTICATION_REQUEST = b"R" PARAMETER_STATUS = b"S" BACKEND_KEY_DATA = b"K" READY_FOR_QUERY = b"Z" ROW_DESCRIPTION = b"T" ERROR_RESPONSE = b"E" DATA_ROW = b"D" COMMAND_COMPLETE = b"C" PARSE_COMPLETE = b"1" BIND_COMPLETE = b"2" CLOSE_COMPLETE = b"3" PORTAL_SUSPENDED = b"s" NO_DATA = b"n" PARAMETER_DESCRIPTION = b"t" NOTIFICATION_RESPONSE = b"A" COPY_DONE = b"c" COPY_DATA = b"d" COPY_IN_RESPONSE = b"G" COPY_OUT_RESPONSE = b"H" EMPTY_QUERY_RESPONSE = b"I" BIND = b"B" PARSE = b"P" EXECUTE = b"E" FLUSH = b'H' SYNC = b'S' PASSWORD = b'p' DESCRIBE = b'D' TERMINATE = b'X' CLOSE = b'C' def _establish_ssl(_socket, ssl_params): if not isinstance(ssl_params, dict): ssl_params = {} try: import ssl as sslmodule keyfile = ssl_params.get('keyfile') certfile = ssl_params.get('certfile') ca_certs = ssl_params.get('ca_certs') if ca_certs is None: verify_mode = sslmodule.CERT_NONE else: verify_mode = sslmodule.CERT_REQUIRED # Int32(8) - Message length, including self. # Int32(80877103) - The SSL request code. _socket.sendall(ii_pack(8, 80877103)) resp = _socket.recv(1) if resp == b'S': return sslmodule.wrap_socket( _socket, keyfile=keyfile, certfile=certfile, cert_reqs=verify_mode, ca_certs=ca_certs) else: raise InterfaceError("Server refuses SSL") except ImportError: raise InterfaceError( "SSL required but ssl module not available in " "this python installation") def create_message(code, data=b''): return code + i_pack(len(data) + 4) + data FLUSH_MSG = create_message(FLUSH) SYNC_MSG = create_message(SYNC) TERMINATE_MSG = create_message(TERMINATE) COPY_DONE_MSG = create_message(COPY_DONE) EXECUTE_MSG = create_message(EXECUTE, NULL_BYTE + i_pack(0)) # DESCRIBE constants STATEMENT = b'S' PORTAL = b'P' # ErrorResponse codes RESPONSE_SEVERITY = "S" # always present RESPONSE_SEVERITY = "V" # always present RESPONSE_CODE = "C" # always present RESPONSE_MSG = "M" # always present RESPONSE_DETAIL = "D" RESPONSE_HINT = "H" RESPONSE_POSITION = "P" RESPONSE__POSITION = "p" RESPONSE__QUERY = "q" RESPONSE_WHERE = "W" RESPONSE_FILE = "F" RESPONSE_LINE = "L" RESPONSE_ROUTINE = "R" IDLE = b"I" IDLE_IN_TRANSACTION = b"T" IDLE_IN_FAILED_TRANSACTION = b"E" arr_trans = dict(zip(map(ord, "[] 'u"), list('{}') + [None] * 3)) class Connection(): # DBAPI Extension: supply exceptions as attributes on the connection Warning = property(lambda self: self._getError(Warning)) Error = property(lambda self: self._getError(Error)) InterfaceError = property(lambda self: self._getError(InterfaceError)) DatabaseError = property(lambda self: self._getError(DatabaseError)) OperationalError = property(lambda self: self._getError(OperationalError)) IntegrityError = property(lambda self: self._getError(IntegrityError)) InternalError = property(lambda self: self._getError(InternalError)) ProgrammingError = property(lambda self: self._getError(ProgrammingError)) NotSupportedError = property( lambda self: self._getError(NotSupportedError)) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() def _getError(self, error): warn( "DB-API extension connection.%s used" % error.__name__, stacklevel=3) return error def __init__( self, user, host, unix_sock, port, database, password, ssl, timeout, application_name, max_prepared_statements, tcp_keepalive): self._client_encoding = "utf8" self._commands_with_count = ( b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY", b"SELECT") self.notifications = deque(maxlen=100) self.notices = deque(maxlen=100) self.parameter_statuses = deque(maxlen=100) self.max_prepared_statements = int(max_prepared_statements) if user is None: raise InterfaceError( "The 'user' connection parameter cannot be None") if isinstance(user, str): self.user = user.encode('utf8') else: self.user = user if isinstance(password, str): self.password = password.encode('utf8') else: self.password = password self.autocommit = False self._xid = None self._caches = {} try: if unix_sock is None and host is not None: self._usock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) elif unix_sock is not None: if not hasattr(socket, "AF_UNIX"): raise InterfaceError( "attempt to connect to unix socket on unsupported " "platform") self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: raise ProgrammingError( "one of host or unix_sock must be provided") if timeout is not None: self._usock.settimeout(timeout) if unix_sock is None and host is not None: self._usock.connect((host, port)) elif unix_sock is not None: self._usock.connect(unix_sock) if ssl: self._usock = _establish_ssl(self._usock, ssl) self._sock = self._usock.makefile(mode="rwb") if tcp_keepalive: self._usock.setsockopt( socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) except socket.error as e: self._usock.close() raise InterfaceError("communication error", e) self._flush = self._sock.flush self._read = self._sock.read self._write = self._sock.write self._backend_key_data = None def text_out(v): return v.encode(self._client_encoding) def enum_out(v): return str(v.value).encode(self._client_encoding) def time_out(v): return v.isoformat().encode(self._client_encoding) def date_out(v): return v.isoformat().encode(self._client_encoding) def unknown_out(v): return str(v).encode(self._client_encoding) trans_tab = dict(zip(map(ord, '{}'), '[]')) glbls = {'Decimal': Decimal} def array_in(data, idx, length): arr = [] prev_c = None for c in data[idx:idx+length].decode( self._client_encoding).translate( trans_tab).replace('NULL', 'None'): if c not in ('[', ']', ',', 'N') and prev_c in ('[', ','): arr.extend("Decimal('") elif c in (']', ',') and prev_c not in ('[', ']', ',', 'e'): arr.extend("')") arr.append(c) prev_c = c return eval(''.join(arr), glbls) def array_recv(data, idx, length): final_idx = idx + length dim, hasnull, typeoid = iii_unpack(data, idx) idx += 12 # get type conversion method for typeoid conversion = self.pg_types[typeoid][1] # Read dimension info dim_lengths = [] for i in range(dim): dim_lengths.append(ii_unpack(data, idx)[0]) idx += 8 # Read all array values values = [] while idx < final_idx: element_len, = i_unpack(data, idx) idx += 4 if element_len == -1: values.append(None) else: values.append(conversion(data, idx, element_len)) idx += element_len # at this point, {{1,2,3},{4,5,6}}::int[][] looks like # [1,2,3,4,5,6]. go through the dimensions and fix up the array # contents to match expected dimensions for length in reversed(dim_lengths[1:]): values = list(map(list, zip(*[iter(values)] * length))) return values def vector_in(data, idx, length): return eval('[' + data[idx:idx+length].decode( self._client_encoding).replace(' ', ',') + ']') def text_recv(data, offset, length): return str(data[offset: offset + length], self._client_encoding) def bool_recv(data, offset, length): return data[offset] == 1 def json_in(data, offset, length): return loads( str(data[offset: offset + length], self._client_encoding)) def time_in(data, offset, length): hour = int(data[offset:offset + 2]) minute = int(data[offset + 3:offset + 5]) sec = Decimal( data[offset + 6:offset + length].decode(self._client_encoding)) return time( hour, minute, int(sec), int((sec - int(sec)) * 1000000)) def date_in(data, offset, length): d = data[offset:offset+length].decode(self._client_encoding) try: return date(int(d[:4]), int(d[5:7]), int(d[8:10])) except ValueError: return d def numeric_in(data, offset, length): return Decimal( data[offset: offset + length].decode(self._client_encoding)) def numeric_out(d): return str(d).encode(self._client_encoding) self.pg_types = defaultdict( lambda: (FC_TEXT, text_recv), { 16: (FC_BINARY, bool_recv), # boolean 17: (FC_BINARY, bytea_recv), # bytea 19: (FC_BINARY, text_recv), # name type 20: (FC_BINARY, int8_recv), # int8 21: (FC_BINARY, int2_recv), # int2 22: (FC_TEXT, vector_in), # int2vector 23: (FC_BINARY, int4_recv), # int4 25: (FC_BINARY, text_recv), # TEXT type 26: (FC_TEXT, int_in), # oid 28: (FC_TEXT, int_in), # xid 114: (FC_TEXT, json_in), # json 700: (FC_BINARY, float4_recv), # float4 701: (FC_BINARY, float8_recv), # float8 705: (FC_BINARY, text_recv), # unknown 829: (FC_TEXT, text_recv), # MACADDR type 1000: (FC_BINARY, array_recv), # BOOL[] 1003: (FC_BINARY, array_recv), # NAME[] 1005: (FC_BINARY, array_recv), # INT2[] 1007: (FC_BINARY, array_recv), # INT4[] 1009: (FC_BINARY, array_recv), # TEXT[] 1014: (FC_BINARY, array_recv), # CHAR[] 1015: (FC_BINARY, array_recv), # VARCHAR[] 1016: (FC_BINARY, array_recv), # INT8[] 1021: (FC_BINARY, array_recv), # FLOAT4[] 1022: (FC_BINARY, array_recv), # FLOAT8[] 1042: (FC_BINARY, text_recv), # CHAR type 1043: (FC_BINARY, text_recv), # VARCHAR type 1082: (FC_TEXT, date_in), # date 1083: (FC_TEXT, time_in), 1114: (FC_BINARY, timestamp_recv_float), # timestamp w/ tz 1184: (FC_BINARY, timestamptz_recv_float), 1186: (FC_BINARY, interval_recv_integer), 1231: (FC_TEXT, array_in), # NUMERIC[] 1263: (FC_BINARY, array_recv), # cstring[] 1700: (FC_TEXT, numeric_in), # NUMERIC 2275: (FC_BINARY, text_recv), # cstring 2950: (FC_BINARY, uuid_recv), # uuid 3802: (FC_TEXT, json_in), # jsonb }) self.py_types = { type(None): (-1, FC_BINARY, null_send), # null bool: (16, FC_BINARY, bool_send), bytearray: (17, FC_BINARY, bytea_send), # bytea 20: (20, FC_BINARY, q_pack), # int8 21: (21, FC_BINARY, h_pack), # int2 23: (23, FC_BINARY, i_pack), # int4 PGText: (25, FC_TEXT, text_out), # text float: (701, FC_BINARY, d_pack), # float8 PGEnum: (705, FC_TEXT, enum_out), date: (1082, FC_TEXT, date_out), # date time: (1083, FC_TEXT, time_out), # time 1114: (1114, FC_BINARY, timestamp_send_integer), # timestamp # timestamp w/ tz PGVarchar: (1043, FC_TEXT, text_out), # varchar 1184: (1184, FC_BINARY, timestamptz_send_integer), PGJson: (114, FC_TEXT, text_out), PGJsonb: (3802, FC_TEXT, text_out), Timedelta: (1186, FC_BINARY, interval_send_integer), Interval: (1186, FC_BINARY, interval_send_integer), Decimal: (1700, FC_TEXT, numeric_out), # Decimal PGTsvector: (3614, FC_TEXT, text_out), UUID: (2950, FC_BINARY, uuid_send)} # uuid self.inspect_funcs = { Datetime: self.inspect_datetime, list: self.array_inspect, tuple: self.array_inspect, int: self.inspect_int} self.py_types[bytes] = (17, FC_BINARY, bytea_send) # bytea self.py_types[str] = (705, FC_TEXT, text_out) # unknown self.py_types[enum.Enum] = (705, FC_TEXT, enum_out) def inet_out(v): return str(v).encode(self._client_encoding) def inet_in(data, offset, length): inet_str = data[offset: offset + length].decode( self._client_encoding) if '/' in inet_str: return ip_network(inet_str, False) else: return ip_address(inet_str) self.py_types[IPv4Address] = (869, FC_TEXT, inet_out) # inet self.py_types[IPv6Address] = (869, FC_TEXT, inet_out) # inet self.py_types[IPv4Network] = (869, FC_TEXT, inet_out) # inet self.py_types[IPv6Network] = (869, FC_TEXT, inet_out) # inet self.pg_types[869] = (FC_TEXT, inet_in) # inet self.message_types = { NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE, AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST, PARAMETER_STATUS: self.handle_PARAMETER_STATUS, BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA, READY_FOR_QUERY: self.handle_READY_FOR_QUERY, ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION, ERROR_RESPONSE: self.handle_ERROR_RESPONSE, EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE, DATA_ROW: self.handle_DATA_ROW, COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE, PARSE_COMPLETE: self.handle_PARSE_COMPLETE, BIND_COMPLETE: self.handle_BIND_COMPLETE, CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE, PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED, NO_DATA: self.handle_NO_DATA, PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION, NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE, COPY_DONE: self.handle_COPY_DONE, COPY_DATA: self.handle_COPY_DATA, COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE, COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE} # Int32 - Message length, including self. # Int32(196608) - Protocol version number. Version 3.0. # Any number of key/value pairs, terminated by a zero byte: # String - A parameter name (user, database, or options) # String - Parameter value protocol = 196608 val = bytearray( i_pack(protocol) + b"user\x00" + self.user + NULL_BYTE) if database is not None: if isinstance(database, str): database = database.encode('utf8') val.extend(b"database\x00" + database + NULL_BYTE) if application_name is not None: if isinstance(application_name, str): application_name = application_name.encode('utf8') val.extend( b"application_name\x00" + application_name + NULL_BYTE) val.append(0) self._write(i_pack(len(val) + 4)) self._write(val) self._flush() self._cursor = self.cursor() code = self.error = None while code not in (READY_FOR_QUERY, ERROR_RESPONSE): code, data_len = ci_unpack(self._read(5)) self.message_types[code](self._read(data_len - 4), None) if self.error is not None: raise self.error self.in_transaction = False def handle_ERROR_RESPONSE(self, data, ps): msg = dict( ( s[:1].decode(self._client_encoding), s[1:].decode(self._client_encoding)) for s in data.split(NULL_BYTE) if s != b'') response_code = msg[RESPONSE_CODE] if response_code == '28000': cls = InterfaceError elif response_code == '23505': cls = IntegrityError else: cls = ProgrammingError self.error = cls(msg) def handle_EMPTY_QUERY_RESPONSE(self, data, ps): self.error = ProgrammingError("query was empty") def handle_CLOSE_COMPLETE(self, data, ps): pass def handle_PARSE_COMPLETE(self, data, ps): # Byte1('1') - Identifier. # Int32(4) - Message length, including self. pass def handle_BIND_COMPLETE(self, data, ps): pass def handle_PORTAL_SUSPENDED(self, data, cursor): pass def handle_PARAMETER_DESCRIPTION(self, data, ps): # Well, we don't really care -- we're going to send whatever we # want and let the database deal with it. But thanks anyways! # count = h_unpack(data)[0] # type_oids = unpack_from("!" + "i" * count, data, 2) pass def handle_COPY_DONE(self, data, ps): self._copy_done = True def handle_COPY_OUT_RESPONSE(self, data, ps): # Int8(1) - 0 textual, 1 binary # Int16(2) - Number of columns # Int16(N) - Format codes for each column (0 text, 1 binary) is_binary, num_cols = bh_unpack(data) # column_formats = unpack_from('!' + 'h' * num_cols, data, 3) if ps.stream is None: raise InterfaceError( "An output stream is required for the COPY OUT response.") def handle_COPY_DATA(self, data, ps): ps.stream.write(data) def handle_COPY_IN_RESPONSE(self, data, ps): # Int16(2) - Number of columns # Int16(N) - Format codes for each column (0 text, 1 binary) is_binary, num_cols = bh_unpack(data) # column_formats = unpack_from('!' + 'h' * num_cols, data, 3) if ps.stream is None: raise InterfaceError( "An input stream is required for the COPY IN response.") bffr = bytearray(8192) while True: bytes_read = ps.stream.readinto(bffr) if bytes_read == 0: break self._write(COPY_DATA + i_pack(bytes_read + 4)) self._write(bffr[:bytes_read]) self._flush() # Send CopyDone # Byte1('c') - Identifier. # Int32(4) - Message length, including self. self._write(COPY_DONE_MSG) self._write(SYNC_MSG) self._flush() def handle_NOTIFICATION_RESPONSE(self, data, ps): ## # A message sent if this connection receives a NOTIFY that it was # LISTENing for. #

# Stability: Added in pg8000 v1.03. When limited to accessing # properties from a notification event dispatch, stability is # guaranteed for v1.xx. backend_pid = i_unpack(data)[0] idx = 4 null = data.find(NULL_BYTE, idx) - idx condition = data[idx:idx + null].decode("ascii") idx += null + 1 null = data.find(NULL_BYTE, idx) - idx # additional_info = data[idx:idx + null] self.notifications.append((backend_pid, condition)) def cursor(self): """Creates a :class:`Cursor` object bound to this connection. This function is part of the `DBAPI 2.0 specification `_. """ return Cursor(self) def commit(self): """Commits the current database transaction. This function is part of the `DBAPI 2.0 specification `_. """ self.execute(self._cursor, "commit", None) def rollback(self): """Rolls back the current database transaction. This function is part of the `DBAPI 2.0 specification `_. """ if not self.in_transaction: return self.execute(self._cursor, "rollback", None) def close(self): """Closes the database connection. This function is part of the `DBAPI 2.0 specification `_. """ try: # Byte1('X') - Identifies the message as a terminate message. # Int32(4) - Message length, including self. self._write(TERMINATE_MSG) self._flush() self._sock.close() except AttributeError: raise InterfaceError("connection is closed") except ValueError: raise InterfaceError("connection is closed") except socket.error: pass finally: self._usock.close() self._sock = None def handle_AUTHENTICATION_REQUEST(self, data, cursor): # Int32 - An authentication code that represents different # authentication messages: # 0 = AuthenticationOk # 5 = MD5 pwd # 2 = Kerberos v5 (not supported by pg8000) # 3 = Cleartext pwd # 4 = crypt() pwd (not supported by pg8000) # 6 = SCM credential (not supported by pg8000) # 7 = GSSAPI (not supported by pg8000) # 8 = GSSAPI data (not supported by pg8000) # 9 = SSPI (not supported by pg8000) # Some authentication messages have additional data following the # authentication code. That data is documented in the appropriate # class. auth_code = i_unpack(data)[0] if auth_code == 0: pass elif auth_code == 3: if self.password is None: raise InterfaceError( "server requesting password authentication, but no " "password was provided") self._send_message(PASSWORD, self.password + NULL_BYTE) self._flush() elif auth_code == 5: ## # A message representing the backend requesting an MD5 hashed # password response. The response will be sent as # md5(md5(pwd + login) + salt). # Additional message data: # Byte4 - Hash salt. salt = b"".join(cccc_unpack(data, 4)) if self.password is None: raise InterfaceError( "server requesting MD5 password authentication, but no " "password was provided") pwd = b"md5" + md5( md5(self.password + self.user).hexdigest().encode("ascii") + salt).hexdigest().encode("ascii") # Byte1('p') - Identifies the message as a password message. # Int32 - Message length including self. # String - The password. Password may be encrypted. self._send_message(PASSWORD, pwd + NULL_BYTE) self._flush() elif auth_code == 10: # AuthenticationSASL mechanisms = [ m.decode('ascii') for m in data[4:-1].split(NULL_BYTE)] self.auth = ScramClient( mechanisms, self.user.decode('utf8'), self.password.decode('utf8')) init = self.auth.get_client_first().encode('utf8') # SASLInitialResponse self._write( create_message( PASSWORD, b'SCRAM-SHA-256' + NULL_BYTE + i_pack(len(init)) + init)) self._flush() elif auth_code == 11: # AuthenticationSASLContinue self.auth.set_server_first(data[4:].decode('utf8')) # SASLResponse msg = self.auth.get_client_final().encode('utf8') self._write(create_message(PASSWORD, msg)) self._flush() elif auth_code == 12: # AuthenticationSASLFinal self.auth.set_server_final(data[4:].decode('utf8')) elif auth_code in (2, 4, 6, 7, 8, 9): raise InterfaceError( "Authentication method " + str(auth_code) + " not supported by pg8000.") else: raise InterfaceError( "Authentication method " + str(auth_code) + " not recognized by pg8000.") def handle_READY_FOR_QUERY(self, data, ps): # Byte1 - Status indicator. self.in_transaction = data != IDLE def handle_BACKEND_KEY_DATA(self, data, ps): self._backend_key_data = data def inspect_datetime(self, value): if value.tzinfo is None: return self.py_types[1114] # timestamp else: return self.py_types[1184] # send as timestamptz def inspect_int(self, value): if min_int2 < value < max_int2: return self.py_types[21] if min_int4 < value < max_int4: return self.py_types[23] if min_int8 < value < max_int8: return self.py_types[20] def make_params(self, values): params = [] for value in values: typ = type(value) try: params.append(self.py_types[typ]) except KeyError: try: params.append(self.inspect_funcs[typ](value)) except KeyError as e: param = None for k, v in self.py_types.items(): try: if isinstance(value, k): param = v break except TypeError: pass if param is None: for k, v in self.inspect_funcs.items(): try: if isinstance(value, k): param = v(value) break except TypeError: pass except KeyError: pass if param is None: raise NotSupportedError( "type " + str(e) + " not mapped to pg type") else: params.append(param) return tuple(params) def handle_ROW_DESCRIPTION(self, data, cursor): count = h_unpack(data)[0] idx = 2 for i in range(count): name = data[idx:data.find(NULL_BYTE, idx)] idx += len(name) + 1 field = dict( zip(( "table_oid", "column_attrnum", "type_oid", "type_size", "type_modifier", "format"), ihihih_unpack(data, idx))) field['name'] = name idx += 18 cursor.ps['row_desc'].append(field) field['pg8000_fc'], field['func'] = \ self.pg_types[field['type_oid']] def execute(self, cursor, operation, vals): if vals is None: vals = () paramstyle = pg8000.paramstyle pid = getpid() try: cache = self._caches[paramstyle][pid] except KeyError: try: param_cache = self._caches[paramstyle] except KeyError: param_cache = self._caches[paramstyle] = {} try: cache = param_cache[pid] except KeyError: cache = param_cache[pid] = {'statement': {}, 'ps': {}} try: statement, make_args = cache['statement'][operation] except KeyError: statement, make_args = cache['statement'][operation] = \ convert_paramstyle(paramstyle, operation) args = make_args(vals) params = self.make_params(args) key = operation, params try: ps = cache['ps'][key] cursor.ps = ps except KeyError: statement_nums = [0] for style_cache in self._caches.values(): try: pid_cache = style_cache[pid] for csh in pid_cache['ps'].values(): statement_nums.append(csh['statement_num']) except KeyError: pass statement_num = sorted(statement_nums)[-1] + 1 statement_name = '_'.join( ("pg8000", "statement", str(pid), str(statement_num))) statement_name_bin = statement_name.encode('ascii') + NULL_BYTE ps = { 'statement_name_bin': statement_name_bin, 'pid': pid, 'statement_num': statement_num, 'row_desc': [], 'param_funcs': tuple(x[2] for x in params)} cursor.ps = ps param_fcs = tuple(x[1] for x in params) # Byte1('P') - Identifies the message as a Parse command. # Int32 - Message length, including self. # String - Prepared statement name. An empty string selects the # unnamed prepared statement. # String - The query string. # Int16 - Number of parameter data types specified (can be zero). # For each parameter: # Int32 - The OID of the parameter data type. val = bytearray(statement_name_bin) val.extend(statement.encode(self._client_encoding) + NULL_BYTE) val.extend(h_pack(len(params))) for oid, fc, send_func in params: # Parse message doesn't seem to handle the -1 type_oid for NULL # values that other messages handle. So we'll provide type_oid # 705, the PG "unknown" type. val.extend(i_pack(705 if oid == -1 else oid)) # Byte1('D') - Identifies the message as a describe command. # Int32 - Message length, including self. # Byte1 - 'S' for prepared statement, 'P' for portal. # String - The name of the item to describe. self._send_message(PARSE, val) self._send_message(DESCRIBE, STATEMENT + statement_name_bin) self._write(SYNC_MSG) try: self._flush() except AttributeError as e: if self._sock is None: raise InterfaceError("connection is closed") else: raise e self.handle_messages(cursor) # We've got row_desc that allows us to identify what we're # going to get back from this statement. output_fc = tuple( self.pg_types[f['type_oid']][0] for f in ps['row_desc']) ps['input_funcs'] = tuple(f['func'] for f in ps['row_desc']) # Byte1('B') - Identifies the Bind command. # Int32 - Message length, including self. # String - Name of the destination portal. # String - Name of the source prepared statement. # Int16 - Number of parameter format codes. # For each parameter format code: # Int16 - The parameter format code. # Int16 - Number of parameter values. # For each parameter value: # Int32 - The length of the parameter value, in bytes, not # including this length. -1 indicates a NULL parameter # value, in which no value bytes follow. # Byte[n] - Value of the parameter. # Int16 - The number of result-column format codes. # For each result-column format code: # Int16 - The format code. ps['bind_1'] = NULL_BYTE + statement_name_bin + \ h_pack(len(params)) + \ pack("!" + "h" * len(param_fcs), *param_fcs) + \ h_pack(len(params)) ps['bind_2'] = h_pack(len(output_fc)) + \ pack("!" + "h" * len(output_fc), *output_fc) if len(cache['ps']) > self.max_prepared_statements: for p in cache['ps'].values(): self.close_prepared_statement(p['statement_name_bin']) cache['ps'].clear() cache['ps'][key] = ps cursor._cached_rows.clear() cursor._row_count = -1 # Byte1('B') - Identifies the Bind command. # Int32 - Message length, including self. # String - Name of the destination portal. # String - Name of the source prepared statement. # Int16 - Number of parameter format codes. # For each parameter format code: # Int16 - The parameter format code. # Int16 - Number of parameter values. # For each parameter value: # Int32 - The length of the parameter value, in bytes, not # including this length. -1 indicates a NULL parameter # value, in which no value bytes follow. # Byte[n] - Value of the parameter. # Int16 - The number of result-column format codes. # For each result-column format code: # Int16 - The format code. retval = bytearray(ps['bind_1']) for value, send_func in zip(args, ps['param_funcs']): if value is None: val = NULL else: val = send_func(value) retval.extend(i_pack(len(val))) retval.extend(val) retval.extend(ps['bind_2']) self._send_message(BIND, retval) self.send_EXECUTE(cursor) self._write(SYNC_MSG) self._flush() self.handle_messages(cursor) def _send_message(self, code, data): try: self._write(code) self._write(i_pack(len(data) + 4)) self._write(data) self._write(FLUSH_MSG) except ValueError as e: if str(e) == "write to closed file": raise InterfaceError("connection is closed") else: raise e except AttributeError: raise InterfaceError("connection is closed") def send_EXECUTE(self, cursor): # Byte1('E') - Identifies the message as an execute message. # Int32 - Message length, including self. # String - The name of the portal to execute. # Int32 - Maximum number of rows to return, if portal # contains a query # that returns rows. # 0 = no limit. self._write(EXECUTE_MSG) self._write(FLUSH_MSG) def handle_NO_DATA(self, msg, ps): pass def handle_COMMAND_COMPLETE(self, data, cursor): values = data[:-1].split(b' ') command = values[0] if command in self._commands_with_count: row_count = int(values[-1]) if cursor._row_count == -1: cursor._row_count = row_count else: cursor._row_count += row_count if command in (b"ALTER", b"CREATE"): for scache in self._caches.values(): for pcache in scache.values(): for ps in pcache['ps'].values(): self.close_prepared_statement(ps['statement_name_bin']) pcache['ps'].clear() def handle_DATA_ROW(self, data, cursor): data_idx = 2 row = [] for func in cursor.ps['input_funcs']: vlen = i_unpack(data, data_idx)[0] data_idx += 4 if vlen == -1: row.append(None) else: row.append(func(data, data_idx, vlen)) data_idx += vlen cursor._cached_rows.append(row) def handle_messages(self, cursor): code = self.error = None while code != READY_FOR_QUERY: code, data_len = ci_unpack(self._read(5)) self.message_types[code](self._read(data_len - 4), cursor) if self.error is not None: raise self.error # Byte1('C') - Identifies the message as a close command. # Int32 - Message length, including self. # Byte1 - 'S' for prepared statement, 'P' for portal. # String - The name of the item to close. def close_prepared_statement(self, statement_name_bin): self._send_message(CLOSE, STATEMENT + statement_name_bin) self._write(SYNC_MSG) self._flush() self.handle_messages(self._cursor) # Byte1('N') - Identifier # Int32 - Message length # Any number of these, followed by a zero byte: # Byte1 - code identifying the field type (see responseKeys) # String - field value def handle_NOTICE_RESPONSE(self, data, ps): self.notices.append( dict((s[0:1], s[1:]) for s in data.split(NULL_BYTE))) def handle_PARAMETER_STATUS(self, data, ps): pos = data.find(NULL_BYTE) key, value = data[:pos], data[pos + 1:-1] self.parameter_statuses.append((key, value)) if key == b"client_encoding": encoding = value.decode("ascii").lower() self._client_encoding = pg_to_py_encodings.get(encoding, encoding) elif key == b"integer_datetimes": if value == b'on': self.py_types[1114] = (1114, FC_BINARY, timestamp_send_integer) self.pg_types[1114] = (FC_BINARY, timestamp_recv_integer) self.py_types[1184] = ( 1184, FC_BINARY, timestamptz_send_integer) self.pg_types[1184] = (FC_BINARY, timestamptz_recv_integer) self.py_types[Interval] = ( 1186, FC_BINARY, interval_send_integer) self.py_types[Timedelta] = ( 1186, FC_BINARY, interval_send_integer) self.pg_types[1186] = (FC_BINARY, interval_recv_integer) else: self.py_types[1114] = (1114, FC_BINARY, timestamp_send_float) self.pg_types[1114] = (FC_BINARY, timestamp_recv_float) self.py_types[1184] = (1184, FC_BINARY, timestamptz_send_float) self.pg_types[1184] = (FC_BINARY, timestamptz_recv_float) self.py_types[Interval] = ( 1186, FC_BINARY, interval_send_float) self.py_types[Timedelta] = ( 1186, FC_BINARY, interval_send_float) self.pg_types[1186] = (FC_BINARY, interval_recv_float) elif key == b"server_version": self._server_version = LooseVersion(value.decode('ascii')) if self._server_version < LooseVersion('8.2.0'): self._commands_with_count = ( b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH") elif self._server_version < LooseVersion('9.0.0'): self._commands_with_count = ( b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY") def array_inspect(self, value): # Check if array has any values. If empty, we can just assume it's an # array of strings first_element = array_find_first_element(value) if first_element is None: oid = 25 # Use binary ARRAY format to avoid having to properly # escape text in the array literals fc = FC_BINARY array_oid = pg_array_types[oid] else: # supported array output typ = type(first_element) if issubclass(typ, int): # special int array support -- send as smallest possible array # type typ = int int2_ok, int4_ok, int8_ok = True, True, True for v in array_flatten(value): if v is None: continue if min_int2 < v < max_int2: continue int2_ok = False if min_int4 < v < max_int4: continue int4_ok = False if min_int8 < v < max_int8: continue int8_ok = False if int2_ok: array_oid = 1005 # INT2[] oid, fc, send_func = (21, FC_BINARY, h_pack) elif int4_ok: array_oid = 1007 # INT4[] oid, fc, send_func = (23, FC_BINARY, i_pack) elif int8_ok: array_oid = 1016 # INT8[] oid, fc, send_func = (20, FC_BINARY, q_pack) else: raise ArrayContentNotSupportedError( "numeric not supported as array contents") else: try: oid, fc, send_func = self.make_params((first_element,))[0] # If unknown or string, assume it's a string array if oid in (705, 1043, 25): oid = 25 # Use binary ARRAY format to avoid having to properly # escape text in the array literals fc = FC_BINARY array_oid = pg_array_types[oid] except KeyError: raise ArrayContentNotSupportedError( "oid " + str(oid) + " not supported as array contents") except NotSupportedError: raise ArrayContentNotSupportedError( "type " + str(typ) + " not supported as array contents") if fc == FC_BINARY: def send_array(arr): # check that all array dimensions are consistent array_check_dimensions(arr) has_null = array_has_null(arr) dim_lengths = array_dim_lengths(arr) data = bytearray(iii_pack(len(dim_lengths), has_null, oid)) for i in dim_lengths: data.extend(ii_pack(i, 1)) for v in array_flatten(arr): if v is None: data += i_pack(-1) elif isinstance(v, typ): inner_data = send_func(v) data += i_pack(len(inner_data)) data += inner_data else: raise ArrayContentNotHomogenousError( "not all array elements are of type " + str(typ)) return data else: def send_array(arr): array_check_dimensions(arr) ar = deepcopy(arr) for a, i, v in walk_array(ar): if v is None: a[i] = 'NULL' elif isinstance(v, typ): a[i] = send_func(v).decode('ascii') else: raise ArrayContentNotHomogenousError( "not all array elements are of type " + str(typ)) return str(ar).translate(arr_trans).encode('ascii') return (array_oid, fc, send_array) def xid(self, format_id, global_transaction_id, branch_qualifier): """Create a Transaction IDs (only global_transaction_id is used in pg) format_id and branch_qualifier are not used in postgres global_transaction_id may be any string identifier supported by postgres returns a tuple (format_id, global_transaction_id, branch_qualifier)""" return (format_id, global_transaction_id, branch_qualifier) def tpc_begin(self, xid): """Begins a TPC transaction with the given transaction ID xid. This method should be called outside of a transaction (i.e. nothing may have executed since the last .commit() or .rollback()). Furthermore, it is an error to call .commit() or .rollback() within the TPC transaction. A ProgrammingError is raised, if the application calls .commit() or .rollback() during an active TPC transaction. This function is part of the `DBAPI 2.0 specification `_. """ self._xid = xid if self.autocommit: self.execute(self._cursor, "begin transaction", None) def tpc_prepare(self): """Performs the first phase of a transaction started with .tpc_begin(). A ProgrammingError is be raised if this method is called outside of a TPC transaction. After calling .tpc_prepare(), no statements can be executed until .tpc_commit() or .tpc_rollback() have been called. This function is part of the `DBAPI 2.0 specification `_. """ q = "PREPARE TRANSACTION '%s';" % (self._xid[1],) self.execute(self._cursor, q, None) def tpc_commit(self, xid=None): """When called with no arguments, .tpc_commit() commits a TPC transaction previously prepared with .tpc_prepare(). If .tpc_commit() is called prior to .tpc_prepare(), a single phase commit is performed. A transaction manager may choose to do this if only a single resource is participating in the global transaction. When called with a transaction ID xid, the database commits the given transaction. If an invalid transaction ID is provided, a ProgrammingError will be raised. This form should be called outside of a transaction, and is intended for use in recovery. On return, the TPC transaction is ended. This function is part of the `DBAPI 2.0 specification `_. """ if xid is None: xid = self._xid if xid is None: raise ProgrammingError( "Cannot tpc_commit() without a TPC transaction!") try: previous_autocommit_mode = self.autocommit self.autocommit = True if xid in self.tpc_recover(): self.execute( self._cursor, "COMMIT PREPARED '%s';" % (xid[1], ), None) else: # a single-phase commit self.commit() finally: self.autocommit = previous_autocommit_mode self._xid = None def tpc_rollback(self, xid=None): """When called with no arguments, .tpc_rollback() rolls back a TPC transaction. It may be called before or after .tpc_prepare(). When called with a transaction ID xid, it rolls back the given transaction. If an invalid transaction ID is provided, a ProgrammingError is raised. This form should be called outside of a transaction, and is intended for use in recovery. On return, the TPC transaction is ended. This function is part of the `DBAPI 2.0 specification `_. """ if xid is None: xid = self._xid if xid is None: raise ProgrammingError( "Cannot tpc_rollback() without a TPC prepared transaction!") try: previous_autocommit_mode = self.autocommit self.autocommit = True if xid in self.tpc_recover(): # a two-phase rollback self.execute( self._cursor, "ROLLBACK PREPARED '%s';" % (xid[1],), None) else: # a single-phase rollback self.rollback() finally: self.autocommit = previous_autocommit_mode self._xid = None def tpc_recover(self): """Returns a list of pending transaction IDs suitable for use with .tpc_commit(xid) or .tpc_rollback(xid). This function is part of the `DBAPI 2.0 specification `_. """ try: previous_autocommit_mode = self.autocommit self.autocommit = True curs = self.cursor() curs.execute("select gid FROM pg_prepared_xacts") return [self.xid(0, row[0], '') for row in curs] finally: self.autocommit = previous_autocommit_mode # pg element oid -> pg array typeoid pg_array_types = { 16: 1000, 25: 1009, # TEXT[] 701: 1022, 1043: 1009, 1700: 1231, # NUMERIC[] } # PostgreSQL encodings: # http://www.postgresql.org/docs/8.3/interactive/multibyte.html # Python encodings: # http://www.python.org/doc/2.4/lib/standard-encodings.html # # Commented out encodings don't require a name change between PostgreSQL and # Python. If the py side is None, then the encoding isn't supported. pg_to_py_encodings = { # Not supported: "mule_internal": None, "euc_tw": None, # Name fine as-is: # "euc_jp", # "euc_jis_2004", # "euc_kr", # "gb18030", # "gbk", # "johab", # "sjis", # "shift_jis_2004", # "uhc", # "utf8", # Different name: "euc_cn": "gb2312", "iso_8859_5": "is8859_5", "iso_8859_6": "is8859_6", "iso_8859_7": "is8859_7", "iso_8859_8": "is8859_8", "koi8": "koi8_r", "latin1": "iso8859-1", "latin2": "iso8859_2", "latin3": "iso8859_3", "latin4": "iso8859_4", "latin5": "iso8859_9", "latin6": "iso8859_10", "latin7": "iso8859_13", "latin8": "iso8859_14", "latin9": "iso8859_15", "sql_ascii": "ascii", "win866": "cp886", "win874": "cp874", "win1250": "cp1250", "win1251": "cp1251", "win1252": "cp1252", "win1253": "cp1253", "win1254": "cp1254", "win1255": "cp1255", "win1256": "cp1256", "win1257": "cp1257", "win1258": "cp1258", "unicode": "utf-8", # Needed for Amazon Redshift } def walk_array(arr): for i, v in enumerate(arr): if isinstance(v, list): for a, i2, v2 in walk_array(v): yield a, i2, v2 else: yield arr, i, v def array_find_first_element(arr): for v in array_flatten(arr): if v is not None: return v return None def array_flatten(arr): for v in arr: if isinstance(v, list): for v2 in array_flatten(v): yield v2 else: yield v def array_check_dimensions(arr): if len(arr) > 0: v0 = arr[0] if isinstance(v0, list): req_len = len(v0) req_inner_lengths = array_check_dimensions(v0) for v in arr: inner_lengths = array_check_dimensions(v) if len(v) != req_len or inner_lengths != req_inner_lengths: raise ArrayDimensionsNotConsistentError( "array dimensions not consistent") retval = [req_len] retval.extend(req_inner_lengths) return retval else: # make sure nothing else at this level is a list for v in arr: if isinstance(v, list): raise ArrayDimensionsNotConsistentError( "array dimensions not consistent") return [] def array_has_null(arr): for v in array_flatten(arr): if v is None: return True return False def array_dim_lengths(arr): len_arr = len(arr) retval = [len_arr] if len_arr > 0: v0 = arr[0] if isinstance(v0, list): retval.extend(array_dim_lengths(v0)) return retval