2400 lines
82 KiB
Python
2400 lines
82 KiB
Python
from datetime import (
|
|
timedelta as Timedelta, datetime as Datetime, date, time)
|
|
from warnings import warn
|
|
import socket
|
|
from struct import pack
|
|
from hashlib import md5
|
|
from decimal import Decimal
|
|
from collections import deque, defaultdict
|
|
from itertools import count, islice
|
|
from uuid import UUID
|
|
from copy import deepcopy
|
|
from calendar import timegm
|
|
from distutils.version import LooseVersion
|
|
from struct import Struct
|
|
from time import localtime
|
|
import pg8000
|
|
from json import loads, dumps
|
|
from os import getpid
|
|
from .scramp import ScramClient
|
|
import enum
|
|
from ipaddress import (
|
|
ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network)
|
|
from datetime import timezone as Timezone
|
|
|
|
|
|
# Copyright (c) 2007-2009, Mathieu Fenniak
|
|
# Copyright (c) The Contributors
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are
|
|
# met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice,
|
|
# this list of conditions and the following disclaimer.
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
# * The name of the author may not be used to endorse or promote products
|
|
# derived from this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
|
# POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
__author__ = "Mathieu Fenniak"

# A zero-length duration, shared as a module constant.
ZERO = Timedelta()

# Python type used for binary values (the name matches the DBAPI 2.0
# ``BINARY`` type object).
BINARY = bytes
|
|
|
|
|
|
class Interval():
    """An Interval represents a measurement of time. In PostgreSQL, an
    interval is defined in the measure of months, days, and microseconds; as
    such, the pg8000 interval type represents the same information.

    Note that values of the :attr:`microseconds`, :attr:`days` and
    :attr:`months` properties are independently measured and cannot be
    converted to each other. A month may be 28, 29, 30, or 31 days, and a day
    may occasionally be lengthened slightly by a leap second.

    .. attribute:: microseconds

        Measure of microseconds in the interval.

        The microseconds value is constrained to fit into a signed 64-bit
        integer. Any attempt to set a value too large or too small will result
        in an OverflowError being raised.

    .. attribute:: days

        Measure of days in the interval.

        The days value is constrained to fit into a signed 32-bit integer.
        Any attempt to set a value too large or too small will result in an
        OverflowError being raised.

    .. attribute:: months

        Measure of months in the interval.

        The months value is constrained to fit into a signed 32-bit integer.
        Any attempt to set a value too large or too small will result in an
        OverflowError being raised.

    Assigning a non-integer to any of the three attributes raises a
    TypeError.
    """

    def __init__(self, microseconds=0, days=0, months=0):
        # Assign through the properties so every value is validated.
        self.microseconds = microseconds
        self.days = days
        self.months = months

    def _setMicroseconds(self, value):
        if not isinstance(value, int):
            raise TypeError("microseconds must be an integer type")
        # The lower bound is inclusive: min_int8 is itself a valid signed
        # 64-bit value. max_int8 is one past the largest valid value.
        elif not (min_int8 <= value < max_int8):
            raise OverflowError(
                "microseconds must be representable as a 64-bit integer")
        else:
            self._microseconds = value

    def _setDays(self, value):
        if not isinstance(value, int):
            raise TypeError("days must be an integer type")
        # Inclusive lower bound, exclusive upper bound (see above).
        elif not (min_int4 <= value < max_int4):
            raise OverflowError(
                "days must be representable as a 32-bit integer")
        else:
            self._days = value

    def _setMonths(self, value):
        if not isinstance(value, int):
            raise TypeError("months must be an integer type")
        # Inclusive lower bound, exclusive upper bound (see above).
        elif not (min_int4 <= value < max_int4):
            raise OverflowError(
                "months must be representable as a 32-bit integer")
        else:
            self._months = value

    microseconds = property(lambda self: self._microseconds, _setMicroseconds)
    days = property(lambda self: self._days, _setDays)
    months = property(lambda self: self._months, _setMonths)

    def __repr__(self):
        return "<Interval %s months %s days %s microseconds>" % (
            self.months, self.days, self.microseconds)

    def __eq__(self, other):
        """Two intervals are equal when all three components match; any
        object that is not an Interval compares unequal.
        """
        return other is not None and isinstance(other, Interval) and \
            self.months == other.months and self.days == other.days and \
            self.microseconds == other.microseconds

    def __ne__(self, other):
        return not self.__eq__(other)

    # The method used to be spelt ``__neq__``, which Python never calls for
    # ``!=``. The old name is kept as an alias for any direct callers.
    __neq__ = __ne__
|
|
|
|
|
|
class PGType():
    """Wrapper that carries a Python value to be sent to the server as text.

    The wrapped object is stored on :attr:`value`; :meth:`encode` turns it
    into bytes using its ``str()`` form.
    """

    def __init__(self, value):
        self.value = value

    def encode(self, encoding):
        """Return ``str(self.value)`` encoded with the given ``encoding``."""
        text = str(self.value)
        return text.encode(encoding)
|
|
|
|
|
|
class PGEnum(PGType):
    """PGType that accepts either a plain string or an object exposing a
    ``value`` attribute (for example an ``enum.Enum`` member).
    """

    def __init__(self, value):
        # Strings are stored as they are; anything else is unwrapped.
        self.value = value if isinstance(value, str) else value.value
|
|
|
|
|
|
class PGJson(PGType):
    """PGType whose value is serialised with :func:`json.dumps`."""

    def encode(self, encoding):
        """Return the JSON text of ``self.value`` encoded as bytes."""
        document = dumps(self.value)
        return document.encode(encoding)
|
|
|
|
|
|
class PGJsonb(PGType):
    """PGType whose value is serialised with :func:`json.dumps`; the encoding
    step is the same as for :class:`PGJson`.
    """

    def encode(self, encoding):
        """Return the JSON text of ``self.value`` encoded as bytes."""
        document = dumps(self.value)
        return document.encode(encoding)
|
|
|
|
|
|
class PGTsvector(PGType):
    """Marker subclass of PGType; it adds no behaviour of its own."""
|
|
|
|
|
|
class PGVarchar(str):
    """Marker subclass of ``str``; it adds no behaviour of its own."""
|
|
|
|
|
|
class PGText(str):
    """Marker subclass of ``str``; it adds no behaviour of its own."""
|
|
|
|
|
|
def pack_funcs(fmt):
    """Compile ``fmt`` as a network-byte-order struct layout.

    :param fmt:
        Struct format characters without a byte-order prefix; ``'!'`` is
        prepended here.

    :returns:
        A ``(pack, unpack_from)`` pair of bound methods of the compiled
        :class:`struct.Struct`.
    """
    compiled = Struct('!' + fmt)
    packer = compiled.pack
    unpacker = compiled.unpack_from
    return packer, unpacker
|
|
|
|
|
|
# Ready-made (pack, unpack_from) pairs for the struct layouts used in this
# module. Each name spells out its struct format characters; all of them use
# network byte order because pack_funcs() adds the '!' prefix.

# Single values
i_pack, i_unpack = pack_funcs('i')
h_pack, h_unpack = pack_funcs('h')
q_pack, q_unpack = pack_funcs('q')
d_pack, d_unpack = pack_funcs('d')
f_pack, f_unpack = pack_funcs('f')

# Composite layouts
iii_pack, iii_unpack = pack_funcs('iii')
ii_pack, ii_unpack = pack_funcs('ii')
qii_pack, qii_unpack = pack_funcs('qii')
dii_pack, dii_unpack = pack_funcs('dii')
ihihih_pack, ihihih_unpack = pack_funcs('ihihih')
ci_pack, ci_unpack = pack_funcs('ci')
bh_pack, bh_unpack = pack_funcs('bh')
cccc_pack, cccc_unpack = pack_funcs('cccc')
|
|
|
|
|
|
# Bounds for integer range checks. Each ``min_*`` is the smallest value of
# the signed type; each ``max_*`` is one more than the largest value, so it
# is meant to be used as an exclusive upper bound.
min_int2 = -(2 ** 15)
max_int2 = 2 ** 15

min_int4 = -(2 ** 31)
max_int4 = 2 ** 31

min_int8 = -(2 ** 63)
max_int8 = 2 ** 63
|
|
|
|
|
|
class Warning(Exception):
    """Exception for important database warnings such as data truncation.
    pg8000 does not currently raise it.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class Error(Exception):
    """Root of the error exception hierarchy; every other error exception
    derives from it.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class InterfaceError(Error):
    """Exception for errors that come from the database interface rather
    than from the database itself. For example, it is raised when the
    interface asks for an SSL connection and the server refuses.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class DatabaseError(Error):
    """Exception for errors that relate to the database itself. pg8000
    currently never raises it directly.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class DataError(DatabaseError):
    """Exception for errors caused by problems with the processed data.
    pg8000 does not currently raise it.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class OperationalError(DatabaseError):
    """Exception for errors that relate to the database's operation and are
    not necessarily under the control of the programmer. pg8000 currently
    never raises it.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class IntegrityError(DatabaseError):
    """Exception for cases where the relational integrity of the database is
    affected. pg8000 does not currently raise it.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class InternalError(DatabaseError):
    """Exception for internal errors encountered by the database. At present
    it is only raised when the pg8000 interface itself reaches an unexpected
    state, which is typically the result of an interface bug.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class ProgrammingError(DatabaseError):
    """Exception for programming errors. For example, it is raised when a
    query string contains more parameter fields than there are parameters
    available.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class NotSupportedError(DatabaseError):
    """Exception for use of a method or database API that the database does
    not support.

    This exception is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
|
|
|
|
|
|
class ArrayContentNotSupportedError(NotSupportedError):
    """Raised on an attempt to transmit an array whose base type the
    interface cannot send as binary data.
    """
|
|
|
|
|
|
class ArrayContentNotHomogenousError(ProgrammingError):
    """Raised on an attempt to transmit an array that contains more than one
    type of object.
    """
|
|
|
|
|
|
class ArrayDimensionsNotConsistentError(ProgrammingError):
    """Raised on an attempt to transmit a multi-dimensional array whose
    dimension sizes are inconsistent.
    """
|
|
|
|
|
|
def Date(year, month, day):
    """Construct an object holding a date value.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.date`
    """
    result = date(year, month, day)
    return result
|
|
|
|
|
|
def Time(hour, minute, second):
    """Construct an object holding a time value.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.time`
    """
    result = time(hour, minute, second)
    return result
|
|
|
|
|
|
def Timestamp(year, month, day, hour, minute, second):
    """Construct an object holding a timestamp value.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.datetime`
    """
    result = Datetime(year, month, day, hour, minute, second)
    return result
|
|
|
|
|
|
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value
    (number of seconds since the epoch). The ticks are interpreted in the
    local time zone.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.date`
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
|
|
|
|
|
|
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value
    (number of seconds since the epoch). The ticks are interpreted in the
    local time zone.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.time`
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
|
|
|
|
|
|
def TimestampFromTicks(ticks):
    """Construct an object holding a timestamp value from the given ticks
    value (number of seconds since the epoch). The ticks are interpreted in
    the local time zone.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.

    :rtype: :class:`datetime.datetime`
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
|
|
|
|
|
|
def Binary(value):
    """Construct an object holding binary data.

    The argument is handed back unchanged; no copy or conversion is made.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    binary_value = value
    return binary_value
|
|
|
|
|
|
# Format codes. NOTE(review): presumably the PostgreSQL wire-protocol codes
# for text and binary values -- confirm against the call sites.
FC_TEXT, FC_BINARY = 0, 1
|
|
|
|
|
|
def convert_paramstyle(style, query):
    """Rewrite ``query`` from a DBAPI paramstyle into ``$1, $2, ...`` form.

    The query is scanned one character at a time by a small state machine so
    that placeholder characters inside string literals, quoted identifiers
    and ``--`` comments are copied through untouched.

    :param style:
        One of ``qmark``, ``numeric``, ``named``, ``format`` or ``pyformat``.
        A ``pyformat`` query that contains a bare ``%s`` or ``%%`` is handled
        as ``format`` from that point on.

    :param query:
        The SQL text containing placeholders in the given style.

    :returns:
        A ``(converted_query, make_args)`` tuple. ``make_args`` takes the
        caller's parameters and returns them in the order the ``$n``
        placeholders expect: unchanged for ``qmark``, ``numeric`` and
        ``format``, and as a tuple looked up by placeholder name otherwise.

    :raises InterfaceError:
        If a ``%`` outside of quotes and comments in a ``format`` or
        ``pyformat`` query is followed by anything other than ``s`` or ``%``
        (or ``(`` for ``pyformat``).
    """
    # I don't see any way to avoid scanning the query string char by char,
    # so we might as well take that careful approach and create a
    # state-based scanner.  We'll use int variables for the state.
    OUTSIDE = 0    # outside quoted string
    INSIDE_SQ = 1  # inside single-quote string '...'
    INSIDE_QI = 2  # inside quoted identifier   "..."
    INSIDE_ES = 3  # inside escaped single-quote string, E'...'
    INSIDE_PN = 4  # inside parameter name eg. :name
    INSIDE_CO = 5  # inside inline comment eg. --

    in_quote_escape = False
    in_param_escape = False
    placeholders = []
    output_query = []
    param_idx = map(lambda x: "$" + str(x), count(1))
    state = OUTSIDE
    prev_c = None
    for i, c in enumerate(query):
        if i + 1 < len(query):
            next_c = query[i + 1]
        else:
            next_c = None

        if state == OUTSIDE:
            # Treat : as beginning of parameter name if and only
            # if it's the only : around
            # Needed to properly process type conversions
            # i.e. sum(x)::float
            # A ':' that is the last character of the query has no name or
            # number after it, so it is copied through as ordinary text
            # (testing ``None in ':='`` would raise a TypeError).
            starts_param = (
                c == ":" and next_c is not None and next_c not in ':=' and
                prev_c != ':')

            if c == "'":
                output_query.append(c)
                if prev_c == 'E':
                    state = INSIDE_ES
                else:
                    state = INSIDE_SQ
            elif c == '"':
                output_query.append(c)
                state = INSIDE_QI
            elif c == '-':
                output_query.append(c)
                if prev_c == '-':
                    state = INSIDE_CO
            elif style == "qmark" and c == "?":
                output_query.append(next(param_idx))
            elif style == "numeric" and starts_param:
                # The digits that follow are copied through as they are.
                output_query.append("$")
            elif style == "named" and starts_param:
                # Same logic for : as in numeric parameters
                state = INSIDE_PN
                placeholders.append('')
            elif style == "pyformat" and c == '%' and next_c == "(":
                state = INSIDE_PN
                placeholders.append('')
            elif style in ("format", "pyformat") and c == "%":
                # A bare % means positional parameters are in use.
                style = "format"
                if in_param_escape:
                    # Second character of '%%': emit one literal '%'.
                    in_param_escape = False
                    output_query.append(c)
                else:
                    if next_c == "%":
                        in_param_escape = True
                    elif next_c == "s":
                        state = INSIDE_PN
                        output_query.append(next(param_idx))
                    else:
                        raise InterfaceError(
                            "Only %s and %% are supported in the query.")
            else:
                output_query.append(c)

        elif state == INSIDE_SQ:
            if c == "'":
                if in_quote_escape:
                    in_quote_escape = False
                else:
                    # A doubled quote ('') is an escaped quote, not the end.
                    if next_c == "'":
                        in_quote_escape = True
                    else:
                        state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_QI:
            if c == '"':
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_ES:
            if c == "'" and prev_c != "\\":
                # check for escaped single-quote
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_PN:
            if style == 'named':
                placeholders[-1] += c
                if next_c is None or (not next_c.isalnum() and next_c != '_'):
                    state = OUTSIDE
                    try:
                        # A name seen before reuses its earlier index.
                        pidx = placeholders.index(placeholders[-1], 0, -1)
                        output_query.append("$" + str(pidx + 1))
                        del placeholders[-1]
                    except ValueError:
                        output_query.append("$" + str(len(placeholders)))
            elif style == 'pyformat':
                if prev_c == ')' and c == "s":
                    state = OUTSIDE
                    try:
                        # A name seen before reuses its earlier index.
                        pidx = placeholders.index(placeholders[-1], 0, -1)
                        output_query.append("$" + str(pidx + 1))
                        del placeholders[-1]
                    except ValueError:
                        output_query.append("$" + str(len(placeholders)))
                elif c in "()":
                    pass
                else:
                    placeholders[-1] += c
            elif style == 'format':
                # This character is the 's' of '%s'; drop it.
                state = OUTSIDE

        elif state == INSIDE_CO:
            output_query.append(c)
            if c == '\n':
                state = OUTSIDE

        prev_c = c

    if style in ('numeric', 'qmark', 'format'):
        def make_args(vals):
            return vals
    else:
        def make_args(vals):
            return tuple(vals[p] for p in placeholders)

    return ''.join(output_query), make_args
|
|
|
|
|
|
# Reference point for timestamps on the wire: 2000-01-01 00:00:00.
EPOCH = Datetime(year=2000, month=1, day=1)
# The same instant, marked as UTC.
EPOCH_TZ = EPOCH.replace(tzinfo=Timezone.utc)
# The reference point expressed as seconds since the Unix epoch.
EPOCH_SECONDS = timegm(EPOCH.timetuple())

# The largest and smallest signed 64-bit values stand for 'infinity' and
# '-infinity' in integer timestamps.
INFINITY_MICROSECONDS = 2 ** 63 - 1
MINUS_INFINITY_MICROSECONDS = -(2 ** 63)
|
|
|
|
|
|
def timestamp_recv_integer(data, offset, length):
    """Decode a timestamp stored as a 64-bit integer count of microseconds
    since 2000-01-01.

    :returns:
        A naive datetime. If the value is outside the range a datetime can
        hold, the result is ``'infinity'`` or ``'-infinity'`` for the two
        sentinel values, and the raw microsecond count otherwise.
    """
    (micros,) = q_unpack(data, offset)
    try:
        return EPOCH + Timedelta(microseconds=micros)
    except OverflowError:
        sentinels = {
            INFINITY_MICROSECONDS: 'infinity',
            MINUS_INFINITY_MICROSECONDS: '-infinity',
        }
        return sentinels.get(micros, micros)
|
|
|
|
|
|
def timestamp_recv_float(data, offset, length):
    """Decode a timestamp stored as a double-precision float count of
    seconds since 2000-01-01, returning a naive datetime.
    """
    (seconds_since_epoch,) = d_unpack(data, offset)
    return Datetime.utcfromtimestamp(EPOCH_SECONDS + seconds_since_epoch)
|
|
|
|
|
|
def timestamp_send_integer(v):
    """Encode datetime ``v`` as a 64-bit integer count of microseconds since
    2000-01-01.
    """
    whole_seconds = timegm(v.timetuple()) - EPOCH_SECONDS
    micros = int(whole_seconds * 1e6) + v.microsecond
    return q_pack(micros)
|
|
|
|
|
|
def timestamp_send_float(v):
    """Encode datetime ``v`` as a double-precision float count of seconds
    since 2000-01-01.
    """
    unix_seconds = timegm(v.timetuple()) + v.microsecond / 1e6
    return d_pack(unix_seconds - EPOCH_SECONDS)
|
|
|
|
|
|
def timestamptz_send_integer(v):
    """Encode datetime ``v`` in the integer timestamp format after
    converting it to UTC.
    """
    # timestamps should be sent as UTC.  If they have zone info,
    # convert them.
    as_utc = v.astimezone(Timezone.utc)
    naive_utc = as_utc.replace(tzinfo=None)
    return timestamp_send_integer(naive_utc)
|
|
|
|
|
|
def timestamptz_send_float(v):
    """Encode datetime ``v`` in the float timestamp format after converting
    it to UTC.
    """
    # timestamps should be sent as UTC.  If they have zone info,
    # convert them.
    as_utc = v.astimezone(Timezone.utc)
    naive_utc = as_utc.replace(tzinfo=None)
    return timestamp_send_float(naive_utc)
|
|
|
|
|
|
def timestamptz_recv_integer(data, offset, length):
    """Decode a "timestamp with timezone" value stored as a 64-bit integer
    count of microseconds since 2000-01-01.

    The datetime returned is timezone-aware and always in UTC; carrying the
    zone lets the caller convert it to local time.

    :returns:
        An aware datetime. If the value is outside the range a datetime can
        hold, the result is ``'infinity'`` or ``'-infinity'`` for the two
        sentinel values, and the raw microsecond count otherwise.
    """
    (micros,) = q_unpack(data, offset)
    try:
        return EPOCH_TZ + Timedelta(microseconds=micros)
    except OverflowError:
        sentinels = {
            INFINITY_MICROSECONDS: 'infinity',
            MINUS_INFINITY_MICROSECONDS: '-infinity',
        }
        return sentinels.get(micros, micros)
|
|
|
|
|
|
def timestamptz_recv_float(data, offset, length):
    """Decode a float-format timestamp and mark the result as UTC."""
    naive = timestamp_recv_float(data, offset, length)
    return naive.replace(tzinfo=Timezone.utc)
|
|
|
|
|
|
def interval_send_integer(v):
    """Encode an interval as (int64 microseconds, int32 days, int32 months).

    ``v`` must have ``microseconds`` and ``days`` attributes. A ``seconds``
    attribute, when present, is folded into the microseconds; a missing
    ``months`` attribute counts as zero.
    """
    microseconds = v.microseconds
    if hasattr(v, 'seconds'):
        microseconds += int(v.seconds * 1e6)

    months = getattr(v, 'months', 0)

    return qii_pack(microseconds, v.days, months)
|
|
|
|
|
|
def interval_send_float(v):
    """Encode an interval as (double seconds, int32 days, int32 months).

    ``v`` must have ``microseconds`` and ``days`` attributes. A ``seconds``
    attribute, when present, is added to the seconds; a missing ``months``
    attribute counts as zero.
    """
    seconds = v.microseconds / 1000.0 / 1000.0
    if hasattr(v, 'seconds'):
        seconds += v.seconds

    months = getattr(v, 'months', 0)

    return dii_pack(seconds, v.days, months)
|
|
|
|
|
|
def interval_recv_integer(data, offset, length):
    """Decode an interval sent as (int64 microseconds, int32 days, int32
    months).

    :returns:
        A :class:`datetime.timedelta` when the months component is zero,
        otherwise an :class:`Interval` (a timedelta cannot express months).
    """
    microseconds, days, months = qii_unpack(data, offset)
    if months == 0:
        # Split with integer arithmetic. Dividing by the float 1e6 would
        # first convert the 64-bit count to a float, which cannot hold every
        # integer above 2**53 and so could lose microseconds.
        seconds, micros = divmod(microseconds, 1000000)
        return Timedelta(days, seconds, micros)
    else:
        return Interval(microseconds, days, months)
|
|
|
|
|
|
def interval_recv_float(data, offset, length):
    """Decode an interval sent as (double seconds, int32 days, int32
    months).

    :returns:
        A :class:`datetime.timedelta` when the months component is zero,
        otherwise an :class:`Interval` (a timedelta cannot express months).
    """
    seconds, days, months = dii_unpack(data, offset)
    if months == 0:
        # The first field is already a count of seconds, fraction included,
        # so it goes to timedelta as seconds. It must not be divided by 1e6
        # with the remainder passed as microseconds.
        return Timedelta(days=days, seconds=seconds)
    else:
        return Interval(int(seconds * 1000 * 1000), days, months)
|
|
|
|
|
|
def int8_recv(data, offset, length):
    """Read a network-order signed 64-bit integer from ``data`` at
    ``offset``; ``length`` is not used.
    """
    (value,) = q_unpack(data, offset)
    return value
|
|
|
|
|
|
def int2_recv(data, offset, length):
    """Read a network-order signed 16-bit integer from ``data`` at
    ``offset``; ``length`` is not used.
    """
    (value,) = h_unpack(data, offset)
    return value
|
|
|
|
|
|
def int4_recv(data, offset, length):
    """Read a network-order signed 32-bit integer from ``data`` at
    ``offset``; ``length`` is not used.
    """
    (value,) = i_unpack(data, offset)
    return value
|
|
|
|
|
|
def float4_recv(data, offset, length):
    """Read a network-order single-precision float from ``data`` at
    ``offset``; ``length`` is not used.
    """
    (value,) = f_unpack(data, offset)
    return value
|
|
|
|
|
|
def float8_recv(data, offset, length):
    """Read a network-order double-precision float from ``data`` at
    ``offset``; ``length`` is not used.
    """
    (value,) = d_unpack(data, offset)
    return value
|
|
|
|
|
|
def bytea_send(v):
    """Return ``v`` unchanged; a bytea value needs no conversion to send."""
    payload = v
    return payload
|
|
|
|
|
|
# bytea
def bytea_recv(data, offset, length):
    """Return the ``length`` bytes of ``data`` that start at ``offset``."""
    end = offset + length
    return data[offset:end]
|
|
|
|
|
|
def uuid_send(v):
    """Return the 16-byte form of UUID ``v`` (its ``bytes`` attribute)."""
    raw = v.bytes
    return raw
|
|
|
|
|
|
def uuid_recv(data, offset, length):
    """Build a :class:`uuid.UUID` from the ``length`` bytes of ``data`` that
    start at ``offset``.
    """
    raw = data[offset:offset + length]
    return UUID(bytes=raw)
|
|
|
|
|
|
def bool_send(v):
    """Encode the truthiness of ``v`` as a single byte: 1 for true, 0 for
    false.
    """
    if v:
        return b"\x01"
    return b"\x00"
|
|
|
|
|
|
# A 32-bit -1, returned by null_send() as the encoding of a null value.
NULL = i_pack(-1)

# A single zero byte.
NULL_BYTE = b'\x00'
|
|
|
|
|
|
def null_send(v):
    """Return the :data:`NULL` marker; the argument ``v`` is ignored."""
    return NULL
|
|
|
|
|
|
def int_in(data, offset, length):
    """Parse the ``length`` bytes of ``data`` that start at ``offset`` as a
    decimal integer written as text.
    """
    digits = data[offset:offset + length]
    return int(digits)
|
|
|
|
|
|
class Cursor():
    """A cursor object is returned by the :meth:`~Connection.cursor` method of
    a connection. It has the following attributes and methods:

    .. attribute:: arraysize

        This read/write attribute specifies the number of rows to fetch at a
        time with :meth:`fetchmany`.  It defaults to 1.

    .. attribute:: connection

        This read-only attribute contains a reference to the connection object
        (an instance of :class:`Connection`) on which the cursor was
        created.

        This attribute is part of a DBAPI 2.0 extension.  Accessing this
        attribute will generate the following warning: ``DB-API extension
        cursor.connection used``.

    .. attribute:: rowcount

        This read-only attribute contains the number of rows that the last
        ``execute()`` or ``executemany()`` method produced (for query
        statements like ``SELECT``) or affected (for modification statements
        like ``UPDATE``).

        The value is -1 if:

        - No ``execute()`` or ``executemany()`` method has been performed yet
          on the cursor.
        - There was no rowcount associated with the last ``execute()``.
        - At least one of the statements executed as part of an
          ``executemany()`` had no row count associated with it.
        - Using a ``SELECT`` query statement on PostgreSQL server older than
          version 9.
        - Using a ``COPY`` query statement on PostgreSQL server version 8.1 or
          older.

        This attribute is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

    .. attribute:: description

        This read-only attribute is a sequence of 7-item sequences.  Each value
        contains information describing one result column.  The 7 items
        returned for each column are (name, type_code, display_size,
        internal_size, precision, scale, null_ok).  Only the first two values
        are provided by the current implementation.

        This attribute is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.
    """

    def __init__(self, connection):
        self._c = connection
        self.arraysize = 1
        # Details of the last statement, including its 'row_desc'; None
        # until a query has been issued.
        self.ps = None
        self._row_count = -1
        # Rows waiting to be handed out by the fetch methods.
        self._cached_rows = deque()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @property
    def connection(self):
        warn("DB-API extension cursor.connection used", stacklevel=3)
        return self._c

    @property
    def rowcount(self):
        return self._row_count

    @property
    def description(self):
        return self._getDescription()

    def _getDescription(self):
        """Build the value of :attr:`description`, or return ``None`` when no
        query has been issued or the statement returns no columns.
        """
        if self.ps is None:
            return None
        row_desc = self.ps['row_desc']
        if len(row_desc) == 0:
            return None
        # Only the name and type code are known; the other five items of
        # each entry are None.
        return [
            (col["name"], col["type_oid"], None, None, None, None, None)
            for col in row_desc]

    def execute(self, operation, args=None, stream=None):
        """Executes a database operation.  Parameters may be provided as a
        sequence, or as a mapping, depending upon the value of
        :data:`pg8000.paramstyle`.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        :param operation:
            The SQL statement to execute.

        :param args:
            If :data:`paramstyle` is ``qmark``, ``numeric``, or ``format``,
            this argument should be an array of parameters to bind into the
            statement.  If :data:`paramstyle` is ``named``, the argument should
            be a dict mapping of parameters.  If the :data:`paramstyle` is
            ``pyformat``, the argument value may be either an array or a
            mapping.

        :param stream: This is a pg8000 extension for use with the PostgreSQL
            `COPY
            <http://www.postgresql.org/docs/current/static/sql-copy.html>`_
            command. For a COPY FROM the parameter must be a readable file-like
            object, and for COPY TO it must be writable.

            .. versionadded:: 1.9.11

        :returns: This cursor.

        :raises InterfaceError:
            If the cursor has been closed or the connection has no socket.
        """
        try:
            self.stream = stream

            # Open a transaction first unless one is already open or the
            # connection is in autocommit mode.
            if not (self._c.in_transaction or self._c.autocommit):
                self._c.execute(self, "begin transaction", None)
            self._c.execute(self, operation, args)
        except AttributeError as e:
            # A closed cursor has no connection, and a closed connection has
            # no socket; report those two cases clearly.
            if self._c is None:
                raise InterfaceError("Cursor closed")
            if self._c._sock is None:
                raise InterfaceError("connection is closed")
            raise e
        return self

    def executemany(self, operation, param_sets):
        """Prepare a database operation, and then execute it against all
        parameter sequences or mappings provided.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        :param operation:
            The SQL statement to execute
        :param param_sets:
            A sequence of parameters to execute the statement with. The values
            in the sequence should be sequences or mappings of parameters, the
            same as the args argument of the :meth:`execute` method.

        :returns: This cursor.
        """
        rowcounts = []
        for parameters in param_sets:
            self.execute(operation, parameters)
            rowcounts.append(self._row_count)

        # One unknown count makes the total unknown.
        if -1 in rowcounts:
            self._row_count = -1
        else:
            self._row_count = sum(rowcounts)
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        :returns:
            A row as a sequence of field values, or ``None`` if no more rows
            are available.
        """
        try:
            return next(self)
        except StopIteration:
            return None
        except (TypeError, AttributeError):
            raise ProgrammingError("attempting to use unexecuted cursor")

    def fetchmany(self, num=None):
        """Fetches the next set of rows of a query result.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        :param num:

            The number of rows to fetch when called.  If not provided, the
            :attr:`arraysize` attribute value is used instead.

        :returns:

            A sequence, each entry of which is a sequence of field values
            making up a row.  If no more rows are available, an empty sequence
            will be returned.
        """
        try:
            limit = self.arraysize if num is None else num
            return tuple(islice(self, limit))
        except TypeError:
            raise ProgrammingError("attempting to use unexecuted cursor")

    def fetchall(self):
        """Fetches all remaining rows of a query result.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        :returns:

            A sequence, each entry of which is a sequence of field values
            making up a row.
        """
        try:
            return tuple(self)
        except TypeError:
            raise ProgrammingError("attempting to use unexecuted cursor")

    def close(self):
        """Closes the cursor.

        This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.
        """
        # Dropping the connection reference is what marks the cursor closed.
        self._c = None

    def __iter__(self):
        """A cursor object is iterable to retrieve the rows from a query.

        This is a DBAPI 2.0 extension.
        """
        return self

    def setinputsizes(self, sizes):
        """This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_, however, it is not
        implemented by pg8000.
        """

    def setoutputsize(self, size, column=None):
        """This method is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_, however, it is not
        implemented by pg8000.
        """

    def __next__(self):
        try:
            return self._cached_rows.popleft()
        except IndexError:
            # No rows left: work out whether that is an error or simply the
            # end of the result set.
            if self.ps is None:
                raise ProgrammingError("A query hasn't been issued.")
            if len(self.ps['row_desc']) == 0:
                raise ProgrammingError("no result set")
            raise StopIteration()
|
|
|
|
|
|
# Message codes: the one-byte type tag that starts each protocol message.

# NOTE(review): the names in this first group match messages sent by the
# server in the PostgreSQL protocol -- confirm against the handlers.
NOTICE_RESPONSE = b"N"
AUTHENTICATION_REQUEST = b"R"
PARAMETER_STATUS = b"S"
BACKEND_KEY_DATA = b"K"
READY_FOR_QUERY = b"Z"
ROW_DESCRIPTION = b"T"
ERROR_RESPONSE = b"E"
DATA_ROW = b"D"
COMMAND_COMPLETE = b"C"
PARSE_COMPLETE = b"1"
BIND_COMPLETE = b"2"
CLOSE_COMPLETE = b"3"
PORTAL_SUSPENDED = b"s"
NO_DATA = b"n"
PARAMETER_DESCRIPTION = b"t"
NOTIFICATION_RESPONSE = b"A"
COPY_DONE = b"c"
COPY_DATA = b"d"
COPY_IN_RESPONSE = b"G"
COPY_OUT_RESPONSE = b"H"
EMPTY_QUERY_RESPONSE = b"I"

# NOTE(review): the names in this second group match messages sent by the
# client; several letters are shared with the group above.
BIND = b"B"
PARSE = b"P"
EXECUTE = b"E"
FLUSH = b"H"
SYNC = b"S"
PASSWORD = b"p"
DESCRIBE = b"D"
TERMINATE = b"X"
CLOSE = b"C"
|
|
|
|
|
|
def _establish_ssl(_socket, ssl_params):
    """Ask the server to switch to SSL and return the wrapped socket.

    :param _socket:
        A connected socket to the server.

    :param ssl_params:
        A dict that may hold ``keyfile``, ``certfile`` and ``ca_certs``
        entries. Any value that is not a dict is treated as an empty dict.
        The server certificate is verified only when ``ca_certs`` is given.

    :returns:
        The SSL-wrapped socket.

    :raises InterfaceError:
        If the server refuses SSL, or if the ``ssl`` module is not available
        in this Python installation.
    """
    if not isinstance(ssl_params, dict):
        ssl_params = {}

    try:
        import ssl as sslmodule
    except ImportError:
        raise InterfaceError(
            "SSL required but ssl module not available in "
            "this python installation")

    keyfile = ssl_params.get('keyfile')
    certfile = ssl_params.get('certfile')
    ca_certs = ssl_params.get('ca_certs')
    if ca_certs is None:
        verify_mode = sslmodule.CERT_NONE
    else:
        verify_mode = sslmodule.CERT_REQUIRED

    # Int32(8) - Message length, including self.
    # Int32(80877103) - The SSL request code.
    _socket.sendall(ii_pack(8, 80877103))
    resp = _socket.recv(1)
    if resp != b'S':
        raise InterfaceError("Server refuses SSL")

    # The module-level ssl.wrap_socket() was deprecated in Python 3.7 and
    # removed in 3.12, so an explicit context is built instead.
    context = sslmodule.SSLContext(sslmodule.PROTOCOL_TLS_CLIENT)
    # Host name checking was never done by this function, and it has to be
    # off before verify_mode may be set to CERT_NONE.
    context.check_hostname = False
    context.verify_mode = verify_mode
    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile)
    return context.wrap_socket(_socket)
|
|
|
|
|
|
def create_message(code, data=b''):
    """Frame a message: type byte(s), packed length, then payload.

    The length field counts itself (4 bytes) plus the payload, but not
    the leading type code.
    """
    length_field = i_pack(len(data) + 4)
    return code + length_field + data
|
|
|
|
|
|
# Frames with a fixed payload are built once at import time and reused.
FLUSH_MSG, SYNC_MSG, TERMINATE_MSG, COPY_DONE_MSG = (
    create_message(code) for code in (FLUSH, SYNC, TERMINATE, COPY_DONE))
# Execute: a single NULL_BYTE followed by a packed 0 (see send_EXECUTE for
# the field layout).
EXECUTE_MSG = create_message(EXECUTE, data=NULL_BYTE + i_pack(0))
|
|
|
|
# DESCRIBE constants
STATEMENT = b'S'
PORTAL = b'P'

# ErrorResponse codes
# The original code assigned RESPONSE_SEVERITY twice ("S", then "V"), so
# the "S" code was unreachable by name.  The effective value ("V") is kept
# and "S" gets its own name.
RESPONSE_SEVERITY_LOCALIZED = "S"  # always present; text may be localized
# Non-localized severity.  Per the PostgreSQL protocol documentation this
# field is only sent by servers 9.6 and later, so it is NOT always present.
RESPONSE_SEVERITY = "V"
RESPONSE_CODE = "C"  # always present
RESPONSE_MSG = "M"  # always present
RESPONSE_DETAIL = "D"
RESPONSE_HINT = "H"
RESPONSE_POSITION = "P"
RESPONSE__POSITION = "p"
RESPONSE__QUERY = "q"
RESPONSE_WHERE = "W"
RESPONSE_FILE = "F"
RESPONSE_LINE = "L"
RESPONSE_ROUTINE = "R"

# Status byte carried by ReadyForQuery.
IDLE = b"I"
IDLE_IN_TRANSACTION = b"T"
IDLE_IN_FAILED_TRANSACTION = b"E"
|
|
|
|
|
|
# str.translate() table: square brackets become curly braces, and the
# space, single-quote and 'u' characters are deleted.
arr_trans = {
    ord('['): '{',
    ord(']'): '}',
    ord(' '): None,
    ord("'"): None,
    ord('u'): None,
}
|
|
|
|
|
|
class Connection():
|
|
|
|
# DBAPI Extension: supply exceptions as attributes on the connection.
# Each property hands back the module-level exception class of the same
# name via _getError(), which also emits a warning about the extension
# being used.
@property
def Warning(self):
    return self._getError(Warning)

@property
def Error(self):
    return self._getError(Error)

@property
def InterfaceError(self):
    return self._getError(InterfaceError)

@property
def DatabaseError(self):
    return self._getError(DatabaseError)

@property
def OperationalError(self):
    return self._getError(OperationalError)

@property
def IntegrityError(self):
    return self._getError(IntegrityError)

@property
def InternalError(self):
    return self._getError(InternalError)

@property
def ProgrammingError(self):
    return self._getError(ProgrammingError)

@property
def NotSupportedError(self):
    return self._getError(NotSupportedError)
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self.close()
|
|
|
|
def _getError(self, error):
|
|
warn(
|
|
"DB-API extension connection.%s used" %
|
|
error.__name__, stacklevel=3)
|
|
return error
|
|
|
|
def __init__(
        self, user, host, unix_sock, port, database, password, ssl,
        timeout, application_name, max_prepared_statements, tcp_keepalive):
    """Connect to the server, register the type converters and run the
    startup handshake.

    user, password, database, application_name -- str values are encoded
        to UTF-8; bytes are used as given.  ``user`` must not be None.
    host, port -- TCP endpoint, used when ``unix_sock`` is None.
    unix_sock -- path of a unix-domain socket; takes precedence over host.
    ssl -- a truthy value requests SSL; a dict is passed on as the SSL
        parameters ('keyfile', 'certfile', 'ca_certs').
    timeout -- socket timeout, or None to leave the socket default.
    max_prepared_statements -- size limit of the prepared-statement cache.
    tcp_keepalive -- when truthy, SO_KEEPALIVE is enabled on the socket.

    Raises InterfaceError for a missing user, an unsupported unix socket
    or a socket-level failure, ProgrammingError when neither host nor
    unix_sock is given, and whatever error the server reports during
    startup.
    """
    self._client_encoding = "utf8"
    self._commands_with_count = (
        b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY",
        b"SELECT")
    self.notifications = deque(maxlen=100)
    self.notices = deque(maxlen=100)
    self.parameter_statuses = deque(maxlen=100)
    self.max_prepared_statements = int(max_prepared_statements)

    if user is None:
        raise InterfaceError(
            "The 'user' connection parameter cannot be None")

    if isinstance(user, str):
        self.user = user.encode('utf8')
    else:
        self.user = user

    if isinstance(password, str):
        self.password = password.encode('utf8')
    else:
        self.password = password

    self.autocommit = False
    self._xid = None

    self._caches = {}

    # Pre-set so the error handler below can tell whether a socket was
    # ever created; previously a failure inside socket.socket() surfaced
    # as an AttributeError on self._usock instead of the real error.
    self._usock = None
    try:
        if unix_sock is None and host is not None:
            self._usock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        elif unix_sock is not None:
            if not hasattr(socket, "AF_UNIX"):
                raise InterfaceError(
                    "attempt to connect to unix socket on unsupported "
                    "platform")
            self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            raise ProgrammingError(
                "one of host or unix_sock must be provided")
        if timeout is not None:
            self._usock.settimeout(timeout)

        if unix_sock is None and host is not None:
            self._usock.connect((host, port))
        elif unix_sock is not None:
            self._usock.connect(unix_sock)

        if ssl:
            self._usock = _establish_ssl(self._usock, ssl)

        self._sock = self._usock.makefile(mode="rwb")
        if tcp_keepalive:
            self._usock.setsockopt(
                socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    except socket.error as e:
        if self._usock is not None:
            self._usock.close()
        raise InterfaceError("communication error", e) from e
    self._flush = self._sock.flush
    self._read = self._sock.read
    self._write = self._sock.write
    self._backend_key_data = None

    # The converters below are closures on purpose: they read
    # self._client_encoding at call time, and that attribute can change
    # when the server reports its client_encoding.

    def text_out(v):
        return v.encode(self._client_encoding)

    def enum_out(v):
        return str(v.value).encode(self._client_encoding)

    def time_out(v):
        return v.isoformat().encode(self._client_encoding)

    def date_out(v):
        return v.isoformat().encode(self._client_encoding)

    def unknown_out(v):
        return str(v).encode(self._client_encoding)

    trans_tab = dict(zip(map(ord, '{}'), '[]'))
    glbls = {'Decimal': Decimal}

    def array_in(data, idx, length):
        # Rewrites the text form of the array into a Python list
        # expression whose elements are wrapped in Decimal('...').
        # NOTE(review): the result is passed to eval(); the input is
        # server-supplied text, so a hostile or compromised server could
        # inject code here.  Flagged for a parser-based replacement.
        arr = []
        prev_c = None
        for c in data[idx:idx+length].decode(
                self._client_encoding).translate(
                trans_tab).replace('NULL', 'None'):
            if c not in ('[', ']', ',', 'N') and prev_c in ('[', ','):
                arr.extend("Decimal('")
            elif c in (']', ',') and prev_c not in ('[', ']', ',', 'e'):
                arr.extend("')")

            arr.append(c)
            prev_c = c
        return eval(''.join(arr), glbls)

    def array_recv(data, idx, length):
        final_idx = idx + length
        dim, hasnull, typeoid = iii_unpack(data, idx)
        idx += 12

        # get type conversion method for typeoid
        conversion = self.pg_types[typeoid][1]

        # Read dimension info
        dim_lengths = []
        for i in range(dim):
            dim_lengths.append(ii_unpack(data, idx)[0])
            idx += 8

        # Read all array values
        values = []
        while idx < final_idx:
            element_len, = i_unpack(data, idx)
            idx += 4
            if element_len == -1:
                values.append(None)
            else:
                values.append(conversion(data, idx, element_len))
                idx += element_len

        # at this point, {{1,2,3},{4,5,6}}::int[][] looks like
        # [1,2,3,4,5,6]. go through the dimensions and fix up the array
        # contents to match expected dimensions
        for length in reversed(dim_lengths[1:]):
            values = list(map(list, zip(*[iter(values)] * length)))
        return values

    def vector_in(data, idx, length):
        # NOTE(review): eval() on server-supplied text, same concern as
        # array_in above.
        return eval('[' + data[idx:idx+length].decode(
            self._client_encoding).replace(' ', ',') + ']')

    def text_recv(data, offset, length):
        return str(data[offset: offset + length], self._client_encoding)

    def bool_recv(data, offset, length):
        return data[offset] == 1

    def json_in(data, offset, length):
        return loads(
            str(data[offset: offset + length], self._client_encoding))

    def time_in(data, offset, length):
        # Text layout is HH:MM:SS[.ffffff]
        hour = int(data[offset:offset + 2])
        minute = int(data[offset + 3:offset + 5])
        sec = Decimal(
            data[offset + 6:offset + length].decode(self._client_encoding))
        return time(
            hour, minute, int(sec), int((sec - int(sec)) * 1000000))

    def date_in(data, offset, length):
        d = data[offset:offset+length].decode(self._client_encoding)
        try:
            return date(int(d[:4]), int(d[5:7]), int(d[8:10]))
        except ValueError:
            # Not representable as a Python date: hand back the raw text.
            return d

    def numeric_in(data, offset, length):
        return Decimal(
            data[offset: offset + length].decode(self._client_encoding))

    def numeric_out(d):
        return str(d).encode(self._client_encoding)

    # PostgreSQL type oid -> (format code requested, converter).  Oids not
    # listed fall back to text format decoded as a string.
    self.pg_types = defaultdict(
        lambda: (FC_TEXT, text_recv), {
            16: (FC_BINARY, bool_recv),  # boolean
            17: (FC_BINARY, bytea_recv),  # bytea
            19: (FC_BINARY, text_recv),  # name type
            20: (FC_BINARY, int8_recv),  # int8
            21: (FC_BINARY, int2_recv),  # int2
            22: (FC_TEXT, vector_in),  # int2vector
            23: (FC_BINARY, int4_recv),  # int4
            25: (FC_BINARY, text_recv),  # TEXT type
            26: (FC_TEXT, int_in),  # oid
            28: (FC_TEXT, int_in),  # xid
            114: (FC_TEXT, json_in),  # json
            700: (FC_BINARY, float4_recv),  # float4
            701: (FC_BINARY, float8_recv),  # float8
            705: (FC_BINARY, text_recv),  # unknown
            829: (FC_TEXT, text_recv),  # MACADDR type
            1000: (FC_BINARY, array_recv),  # BOOL[]
            1003: (FC_BINARY, array_recv),  # NAME[]
            1005: (FC_BINARY, array_recv),  # INT2[]
            1007: (FC_BINARY, array_recv),  # INT4[]
            1009: (FC_BINARY, array_recv),  # TEXT[]
            1014: (FC_BINARY, array_recv),  # CHAR[]
            1015: (FC_BINARY, array_recv),  # VARCHAR[]
            1016: (FC_BINARY, array_recv),  # INT8[]
            1021: (FC_BINARY, array_recv),  # FLOAT4[]
            1022: (FC_BINARY, array_recv),  # FLOAT8[]
            1042: (FC_BINARY, text_recv),  # CHAR type
            1043: (FC_BINARY, text_recv),  # VARCHAR type
            1082: (FC_TEXT, date_in),  # date
            1083: (FC_TEXT, time_in),  # time
            1114: (FC_BINARY, timestamp_recv_float),  # timestamp
            1184: (FC_BINARY, timestamptz_recv_float),  # timestamp w/ tz
            1186: (FC_BINARY, interval_recv_integer),  # interval
            1231: (FC_TEXT, array_in),  # NUMERIC[]
            1263: (FC_BINARY, array_recv),  # cstring[]
            1700: (FC_TEXT, numeric_in),  # NUMERIC
            2275: (FC_BINARY, text_recv),  # cstring
            2950: (FC_BINARY, uuid_recv),  # uuid
            3802: (FC_TEXT, json_in),  # jsonb
        })

    # Python type (or pseudo-key oid) -> (type oid, format code, sender).
    # Insertion order matters: make_params() walks this dict with
    # isinstance() when there is no exact type match.
    self.py_types = {
        type(None): (-1, FC_BINARY, null_send),  # null
        bool: (16, FC_BINARY, bool_send),
        bytearray: (17, FC_BINARY, bytea_send),  # bytea
        20: (20, FC_BINARY, q_pack),  # int8
        21: (21, FC_BINARY, h_pack),  # int2
        23: (23, FC_BINARY, i_pack),  # int4
        PGText: (25, FC_TEXT, text_out),  # text
        float: (701, FC_BINARY, d_pack),  # float8
        PGEnum: (705, FC_TEXT, enum_out),
        date: (1082, FC_TEXT, date_out),  # date
        time: (1083, FC_TEXT, time_out),  # time
        1114: (1114, FC_BINARY, timestamp_send_integer),  # timestamp
        PGVarchar: (1043, FC_TEXT, text_out),  # varchar
        # timestamp w/ tz
        1184: (1184, FC_BINARY, timestamptz_send_integer),
        PGJson: (114, FC_TEXT, text_out),
        PGJsonb: (3802, FC_TEXT, text_out),
        Timedelta: (1186, FC_BINARY, interval_send_integer),
        Interval: (1186, FC_BINARY, interval_send_integer),
        Decimal: (1700, FC_TEXT, numeric_out),  # Decimal
        PGTsvector: (3614, FC_TEXT, text_out),
        UUID: (2950, FC_BINARY, uuid_send)}  # uuid

    # Types whose mapping depends on the value, not just on the type.
    self.inspect_funcs = {
        Datetime: self.inspect_datetime,
        list: self.array_inspect,
        tuple: self.array_inspect,
        int: self.inspect_int}

    self.py_types[bytes] = (17, FC_BINARY, bytea_send)  # bytea
    self.py_types[str] = (705, FC_TEXT, text_out)  # unknown
    self.py_types[enum.Enum] = (705, FC_TEXT, enum_out)

    def inet_out(v):
        return str(v).encode(self._client_encoding)

    def inet_in(data, offset, length):
        inet_str = data[offset: offset + length].decode(
            self._client_encoding)
        if '/' in inet_str:
            return ip_network(inet_str, False)
        else:
            return ip_address(inet_str)

    self.py_types[IPv4Address] = (869, FC_TEXT, inet_out)  # inet
    self.py_types[IPv6Address] = (869, FC_TEXT, inet_out)  # inet
    self.py_types[IPv4Network] = (869, FC_TEXT, inet_out)  # inet
    self.py_types[IPv6Network] = (869, FC_TEXT, inet_out)  # inet
    self.pg_types[869] = (FC_TEXT, inet_in)  # inet

    # Message type byte -> handler(data, cursor_or_ps).
    self.message_types = {
        NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
        AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
        PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
        BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
        READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
        ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
        ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
        EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
        DATA_ROW: self.handle_DATA_ROW,
        COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
        PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
        BIND_COMPLETE: self.handle_BIND_COMPLETE,
        CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
        PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
        NO_DATA: self.handle_NO_DATA,
        PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
        NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
        COPY_DONE: self.handle_COPY_DONE,
        COPY_DATA: self.handle_COPY_DATA,
        COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
        COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}

    # Int32 - Message length, including self.
    # Int32(196608) - Protocol version number.  Version 3.0.
    # Any number of key/value pairs, terminated by a zero byte:
    #   String - A parameter name (user, database, or options)
    #   String - Parameter value
    protocol = 196608
    val = bytearray(
        i_pack(protocol) + b"user\x00" + self.user + NULL_BYTE)
    if database is not None:
        if isinstance(database, str):
            database = database.encode('utf8')
        val.extend(b"database\x00" + database + NULL_BYTE)
    if application_name is not None:
        if isinstance(application_name, str):
            application_name = application_name.encode('utf8')
        val.extend(
            b"application_name\x00" + application_name + NULL_BYTE)
    val.append(0)
    self._write(i_pack(len(val) + 4))
    self._write(val)
    self._flush()

    # Process server messages until the server is ready for queries or
    # reports an error; handlers record a failure in self.error.
    self._cursor = self.cursor()
    code = self.error = None
    while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
        code, data_len = ci_unpack(self._read(5))
        self.message_types[code](self._read(data_len - 4), None)
    if self.error is not None:
        raise self.error

    self.in_transaction = False
|
|
|
|
def handle_ERROR_RESPONSE(self, data, ps):
    """Decode an ErrorResponse payload and store the resulting exception
    in ``self.error`` (it is raised later by the message loop).

    The payload is a NULL-separated list of fields, each a one-character
    code followed by its value.  The exception class is chosen from the
    response code: '28000' -> InterfaceError, '23505' -> IntegrityError,
    anything else -> ProgrammingError.
    """
    fields = {
        chunk[:1].decode(self._client_encoding):
            chunk[1:].decode(self._client_encoding)
        for chunk in data.split(NULL_BYTE)
        if chunk != b''
    }

    by_code = {'28000': InterfaceError, '23505': IntegrityError}
    error_class = by_code.get(fields[RESPONSE_CODE], ProgrammingError)

    self.error = error_class(fields)
|
|
|
|
def handle_EMPTY_QUERY_RESPONSE(self, data, ps):
    """Record a ProgrammingError in ``self.error`` for an empty query."""
    empty_query_error = ProgrammingError("query was empty")
    self.error = empty_query_error
|
|
|
|
def handle_CLOSE_COMPLETE(self, data, ps):
    """CloseComplete needs no action; the message is simply consumed."""
    return None
|
|
|
|
def handle_PARSE_COMPLETE(self, data, ps):
    """ParseComplete needs no action; the message is simply consumed.

    Wire layout:
      Byte1('1') - Identifier.
      Int32(4) - Message length, including self.
    """
    return None
|
|
|
|
def handle_BIND_COMPLETE(self, data, ps):
    """BindComplete needs no action; the message is simply consumed."""
    return None
|
|
|
|
def handle_PORTAL_SUSPENDED(self, data, cursor):
    """PortalSuspended needs no action; the message is simply consumed."""
    return None
|
|
|
|
def handle_PARAMETER_DESCRIPTION(self, data, ps):
    """Ignore the server's description of the statement parameters.

    The parameter types the server reports are not used: values are sent
    with the types chosen on the client side and the database is left to
    deal with them.  If the payload were ever needed, it starts with a
    16-bit count followed by that many 32-bit type oids.
    """
    return None
|
|
|
|
def handle_COPY_DONE(self, data, ps):
    """Flag on the connection that a CopyDone message was received."""
    copy_finished = True
    self._copy_done = copy_finished
|
|
|
|
def handle_COPY_OUT_RESPONSE(self, data, ps):
    """Validate that a COPY ... TO has somewhere to write its data.

    Payload layout:
      Int8(1) - 0 textual, 1 binary
      Int16(2) - Number of columns
      Int16(N) - Format codes for each column (0 text, 1 binary)

    The header is unpacked but its values are not used.  Raises
    InterfaceError when ``ps.stream`` is None.
    """
    _is_binary, _num_cols = bh_unpack(data)
    if ps.stream is not None:
        return
    raise InterfaceError(
        "An output stream is required for the COPY OUT response.")
|
|
|
|
def handle_COPY_DATA(self, data, ps):
    """Append one CopyData payload to the caller-supplied stream."""
    sink = ps.stream
    sink.write(data)
|
|
|
|
def handle_COPY_IN_RESPONSE(self, data, ps):
    """Stream the contents of ``ps.stream`` to the server for COPY ... FROM.

    Payload layout (same header that handle_COPY_OUT_RESPONSE unpacks):
      Int8(1) - 0 textual, 1 binary
      Int16(2) - Number of columns
      Int16(N) - Format codes for each column (0 text, 1 binary)

    The header is unpacked but its values are not used.  The stream is
    read in chunks of up to 8192 bytes, each sent as one CopyData message
    and flushed; afterwards CopyDone and Sync are sent.  Raises
    InterfaceError when ``ps.stream`` is None.
    """
    _is_binary, _num_cols = bh_unpack(data)
    if ps.stream is None:
        raise InterfaceError(
            "An input stream is required for the COPY IN response.")

    chunk = bytearray(8192)
    filled = ps.stream.readinto(chunk)
    while filled != 0:
        # CopyData header: identifier plus length (which counts itself).
        self._write(COPY_DATA + i_pack(filled + 4))
        self._write(chunk[:filled])
        self._flush()
        filled = ps.stream.readinto(chunk)

    # Send CopyDone
    # Byte1('c') - Identifier.
    # Int32(4) - Message length, including self.
    self._write(COPY_DONE_MSG)
    self._write(SYNC_MSG)
    self._flush()
|
|
|
|
def handle_NOTIFICATION_RESPONSE(self, data, ps):
    """Queue a NOTIFY that this connection was LISTENing for.

    Appends ``(backend_pid, condition)`` to ``self.notifications``.  The
    payload holds a packed backend pid followed by NULL-terminated
    strings; only the first string (the condition name, decoded as ASCII)
    is kept, and whatever follows it is ignored.

    Stability: Added in pg8000 v1.03.  When limited to accessing
    properties from a notification event dispatch, stability is
    guaranteed for v1.xx.
    """
    backend_pid = i_unpack(data)[0]

    name_start = 4
    name_end = data.find(NULL_BYTE, name_start)
    condition = data[name_start:name_end].decode("ascii")

    self.notifications.append((backend_pid, condition))
|
|
|
|
def cursor(self):
    """Create and return a new :class:`Cursor` bound to this connection.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    new_cursor = Cursor(self)
    return new_cursor
|
|
|
|
def commit(self):
    """Commit the current database transaction.

    The statement is run on the connection's internal cursor.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    internal_cursor = self._cursor
    self.execute(internal_cursor, "commit", None)
|
|
|
|
def rollback(self):
    """Roll back the current database transaction.

    Does nothing when no transaction is open; otherwise the statement is
    run on the connection's internal cursor.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    if self.in_transaction:
        self.execute(self._cursor, "rollback", None)
|
|
|
|
def close(self):
    """Send Terminate and close the connection's socket.

    Raises InterfaceError("connection is closed") when writing fails with
    AttributeError or ValueError.  Socket errors while saying goodbye are
    ignored.  In every case the underlying socket is closed and
    ``self._sock`` is reset to None.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    try:
        # Byte1('X') - Identifies the message as a terminate message.
        # Int32(4) - Message length, including self.
        self._write(TERMINATE_MSG)
        self._flush()
        self._sock.close()
    except (AttributeError, ValueError):
        raise InterfaceError("connection is closed")
    except socket.error:
        # Best effort only: the socket is closed below regardless.
        pass
    finally:
        self._usock.close()
        self._sock = None
|
|
|
|
def handle_AUTHENTICATION_REQUEST(self, data, cursor):
    """Answer one Authentication* message from the server.

    ``data`` starts with a packed authentication code:
      0 = AuthenticationOk
      3 = Cleartext pwd
      5 = MD5 pwd
      10 / 11 / 12 = SASL, SASLContinue, SASLFinal (SCRAM-SHA-256)
      2 = Kerberos v5 (not supported by pg8000)
      4 = crypt() pwd (not supported by pg8000)
      6 = SCM credential (not supported by pg8000)
      7 = GSSAPI (not supported by pg8000)
      8 = GSSAPI data (not supported by pg8000)
      9 = SSPI (not supported by pg8000)
    Some codes carry additional data after the code; it is described in
    the branch that consumes it.

    Raises InterfaceError when a password is required but none was
    provided, and for unsupported or unrecognized authentication codes.
    """
    auth_code = i_unpack(data)[0]
    if auth_code == 0:
        pass
    elif auth_code == 3:
        if self.password is None:
            raise InterfaceError(
                "server requesting password authentication, but no "
                "password was provided")
        self._send_message(PASSWORD, self.password + NULL_BYTE)
        self._flush()
    elif auth_code == 5:
        # The backend requests an MD5 hashed password.  The response is
        # sent as "md5" + md5(md5(pwd + login) + salt).
        #
        # Additional message data:
        #  Byte4 - Hash salt.
        salt = b"".join(cccc_unpack(data, 4))
        if self.password is None:
            raise InterfaceError(
                "server requesting MD5 password authentication, but no "
                "password was provided")
        pwd = b"md5" + md5(
            md5(self.password + self.user).hexdigest().encode("ascii") +
            salt).hexdigest().encode("ascii")
        # Byte1('p') - Identifies the message as a password message.
        # Int32 - Message length including self.
        # String - The password.  Password may be encrypted.
        self._send_message(PASSWORD, pwd + NULL_BYTE)
        self._flush()

    elif auth_code == 10:
        # AuthenticationSASL
        # Same guard as the cleartext and MD5 branches; without it a
        # missing password surfaced as an AttributeError on None.
        if self.password is None:
            raise InterfaceError(
                "server requesting SCRAM password authentication, but no "
                "password was provided")
        mechanisms = [
            m.decode('ascii') for m in data[4:-1].split(NULL_BYTE)]

        self.auth = ScramClient(
            mechanisms, self.user.decode('utf8'),
            self.password.decode('utf8'))

        init = self.auth.get_client_first().encode('utf8')

        # SASLInitialResponse
        self._write(
            create_message(
                PASSWORD,
                b'SCRAM-SHA-256' + NULL_BYTE + i_pack(len(init)) + init))
        self._flush()

    elif auth_code == 11:
        # AuthenticationSASLContinue
        self.auth.set_server_first(data[4:].decode('utf8'))

        # SASLResponse
        msg = self.auth.get_client_final().encode('utf8')
        self._write(create_message(PASSWORD, msg))
        self._flush()

    elif auth_code == 12:
        # AuthenticationSASLFinal
        self.auth.set_server_final(data[4:].decode('utf8'))

    elif auth_code in (2, 4, 6, 7, 8, 9):
        raise InterfaceError(
            "Authentication method " + str(auth_code) +
            " not supported by pg8000.")
    else:
        raise InterfaceError(
            "Authentication method " + str(auth_code) +
            " not recognized by pg8000.")
|
|
|
|
def handle_READY_FOR_QUERY(self, data, ps):
    """Update ``self.in_transaction`` from the ReadyForQuery status byte.

    Any status other than IDLE counts as being inside a transaction.
    """
    # Byte1 - Status indicator.
    is_idle = data == IDLE
    self.in_transaction = not is_idle
|
|
|
|
def handle_BACKEND_KEY_DATA(self, data, ps):
    """Keep the raw BackendKeyData payload on the connection."""
    payload = data
    self._backend_key_data = payload
|
|
|
|
def inspect_datetime(self, value):
    """Pick the parameter mapping for a datetime value.

    Naive datetimes use the entry registered under 1114 (timestamp);
    timezone-aware ones use the entry under 1184 (sent as timestamptz).
    """
    type_key = 1114 if value.tzinfo is None else 1184
    return self.py_types[type_key]
|
|
|
|
def inspect_int(self, value):
    """Pick the narrowest integer parameter mapping that can hold *value*.

    The int2, int4 and int8 ranges are tried in that order (bounds are
    exclusive).  A value outside all of them is sent using the Decimal
    (NUMERIC) mapping; previously this method fell off the end and
    returned None for such values, which made parameter binding fail
    later with an unrelated TypeError.
    """
    if min_int2 < value < max_int2:
        return self.py_types[21]
    if min_int4 < value < max_int4:
        return self.py_types[23]
    if min_int8 < value < max_int8:
        return self.py_types[20]
    # Too large for int8: NUMERIC has no fixed width, and its sender
    # formats the value with str(), which works for ints as well.
    return self.py_types[Decimal]
|
|
|
|
def make_params(self, values):
    """Return a tuple with one parameter mapping per value.

    Lookup order for each value:
      1. exact type in ``self.py_types``;
      2. exact type in ``self.inspect_funcs`` (called with the value);
      3. first ``py_types`` key the value is an instance of;
      4. first ``inspect_funcs`` key the value is an instance of.
    Raises NotSupportedError when nothing matches.
    """
    resolved = []
    for value in values:
        exact_type = type(value)
        try:
            resolved.append(self.py_types[exact_type])
        except KeyError:
            try:
                resolved.append(self.inspect_funcs[exact_type](value))
            except KeyError as e:
                mapping = None

                # Some py_types keys are plain ints (pseudo oids), for
                # which isinstance() raises TypeError; those are skipped.
                for candidate, entry in self.py_types.items():
                    try:
                        is_match = isinstance(value, candidate)
                    except TypeError:
                        continue
                    if is_match:
                        mapping = entry
                        break

                if mapping is None:
                    for candidate, inspector in self.inspect_funcs.items():
                        try:
                            if isinstance(value, candidate):
                                mapping = inspector(value)
                                break
                        except (TypeError, KeyError):
                            pass

                if mapping is None:
                    raise NotSupportedError(
                        "type " + str(e) + " not mapped to pg type")
                resolved.append(mapping)

    return tuple(resolved)
|
|
|
|
def handle_ROW_DESCRIPTION(self, data, cursor):
    """Parse a RowDescription message into ``cursor.ps['row_desc']``.

    The payload starts with a 16-bit field count.  Each field is a
    NULL-terminated name followed by 18 bytes of fixed attributes.  One
    dict per field is appended, holding the attributes, the raw (bytes)
    name, and the format code / converter looked up in ``self.pg_types``
    under the keys 'pg8000_fc' and 'func'.
    """
    attribute_names = (
        "table_oid", "column_attrnum", "type_oid", "type_size",
        "type_modifier", "format")

    field_count = h_unpack(data)[0]
    pos = 2
    for _ in range(field_count):
        name = data[pos:data.find(NULL_BYTE, pos)]
        pos += len(name) + 1  # skip the name and its terminator

        field = dict(zip(attribute_names, ihihih_unpack(data, pos)))
        field['name'] = name
        pos += 18

        cursor.ps['row_desc'].append(field)
        fc_and_func = self.pg_types[field['type_oid']]
        field['pg8000_fc'], field['func'] = fc_and_func
|
|
|
|
def execute(self, cursor, operation, vals):
    """Run *operation* with parameters *vals* on behalf of *cursor*.

    Prepared statements are cached per paramstyle and per process id,
    keyed by the operation text plus the parameter mappings.  On a cache
    miss the statement is parsed and described on the server first; in
    both cases it is then bound, executed and the responses are processed
    into the cursor.
    """
    if vals is None:
        vals = ()

    paramstyle = pg8000.paramstyle
    pid = getpid()

    # Two-level cache: paramstyle -> pid -> {'statement': ..., 'ps': ...}
    cache = self._caches.setdefault(paramstyle, {}).setdefault(
        pid, {'statement': {}, 'ps': {}})

    converted_statements = cache['statement']
    if operation in converted_statements:
        statement, make_args = converted_statements[operation]
    else:
        converted = convert_paramstyle(paramstyle, operation)
        statement, make_args = converted
        converted_statements[operation] = converted

    args = make_args(vals)
    params = self.make_params(args)
    key = operation, params

    prepared = cache['ps']
    if key in prepared:
        ps = prepared[key]
        cursor.ps = ps
    else:
        # Choose a statement number one above the highest already in use
        # by this process, across all paramstyles.
        used_nums = [0]
        for style_cache in self._caches.values():
            try:
                used_nums.extend(
                    entry['statement_num']
                    for entry in style_cache[pid]['ps'].values())
            except KeyError:
                pass

        statement_num = max(used_nums) + 1
        statement_name = '_'.join(
            ("pg8000", "statement", str(pid), str(statement_num)))
        statement_name_bin = statement_name.encode('ascii') + NULL_BYTE
        ps = {
            'statement_name_bin': statement_name_bin,
            'pid': pid,
            'statement_num': statement_num,
            'row_desc': [],
            'param_funcs': tuple(p[2] for p in params)}
        # Must be set before the responses are handled: the row
        # description handler appends into cursor.ps['row_desc'].
        cursor.ps = ps

        param_fcs = tuple(p[1] for p in params)

        # Byte1('P') - Identifies the message as a Parse command.
        # Int32 -   Message length, including self.
        # String -  Prepared statement name. An empty string selects the
        #           unnamed prepared statement.
        # String -  The query string.
        # Int16 -   Number of parameter data types specified (can be zero).
        # For each parameter:
        #   Int32 - The OID of the parameter data type.
        parse_body = bytearray(statement_name_bin)
        parse_body.extend(
            statement.encode(self._client_encoding) + NULL_BYTE)
        parse_body.extend(h_pack(len(params)))
        for oid, _fc, _send_func in params:
            # Parse message doesn't seem to handle the -1 type_oid for NULL
            # values that other messages handle. So we'll provide type_oid
            # 705, the PG "unknown" type.
            parse_body.extend(i_pack(705 if oid == -1 else oid))

        # Byte1('D') - Identifies the message as a describe command.
        # Int32 - Message length, including self.
        # Byte1 - 'S' for prepared statement, 'P' for portal.
        # String - The name of the item to describe.
        self._send_message(PARSE, parse_body)
        self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
        self._write(SYNC_MSG)

        try:
            self._flush()
        except AttributeError as e:
            if self._sock is None:
                raise InterfaceError("connection is closed")
            else:
                raise e

        self.handle_messages(cursor)

        # We've got row_desc that allows us to identify what we're
        # going to get back from this statement.
        row_desc = ps['row_desc']
        output_fc = tuple(
            self.pg_types[field['type_oid']][0] for field in row_desc)
        ps['input_funcs'] = tuple(field['func'] for field in row_desc)

        # The constant parts of the Bind message are built once per
        # prepared statement:
        # Byte1('B') - Identifies the Bind command.
        # Int32 - Message length, including self.
        # String - Name of the destination portal.
        # String - Name of the source prepared statement.
        # Int16 - Number of parameter format codes.
        # For each parameter format code:
        #   Int16 - The parameter format code.
        # Int16 - Number of parameter values.
        # For each parameter value:
        #   Int32 - The length of the parameter value, in bytes, not
        #           including this length.  -1 indicates a NULL parameter
        #           value, in which no value bytes follow.
        #   Byte[n] - Value of the parameter.
        # Int16 - The number of result-column format codes.
        # For each result-column format code:
        #   Int16 - The format code.
        param_count = h_pack(len(params))
        ps['bind_1'] = NULL_BYTE + statement_name_bin + \
            param_count + \
            pack("!" + "h" * len(param_fcs), *param_fcs) + \
            h_pack(len(params))

        ps['bind_2'] = h_pack(len(output_fc)) + \
            pack("!" + "h" * len(output_fc), *output_fc)

        # Cache over its limit: close every cached statement on the
        # server and start afresh.
        if len(prepared) > self.max_prepared_statements:
            for stale in prepared.values():
                self.close_prepared_statement(stale['statement_name_bin'])
            prepared.clear()

        prepared[key] = ps

    cursor._cached_rows.clear()
    cursor._row_count = -1

    # Assemble the Bind message (layout described above): fixed prefix,
    # the parameter values, fixed suffix.
    bind_body = bytearray(ps['bind_1'])
    for value, send_func in zip(args, ps['param_funcs']):
        if value is None:
            bind_body.extend(NULL)
        else:
            encoded = send_func(value)
            bind_body.extend(i_pack(len(encoded)))
            bind_body.extend(encoded)
    bind_body.extend(ps['bind_2'])

    self._send_message(BIND, bind_body)
    self.send_EXECUTE(cursor)
    self._write(SYNC_MSG)
    self._flush()
    self.handle_messages(cursor)
|
|
|
|
def _send_message(self, code, data):
    """Write one framed message followed by a Flush message.

    The frame is the type code, a packed length (payload plus the four
    bytes of the length itself) and the payload.  Nothing is flushed
    locally here.  Raises InterfaceError("connection is closed") when the
    write fails because the connection has been closed.
    """
    try:
        self._write(code)
        self._write(i_pack(len(data) + 4))
        self._write(data)
        self._write(FLUSH_MSG)
    except ValueError as e:
        if str(e) != "write to closed file":
            raise e
        raise InterfaceError("connection is closed")
    except AttributeError:
        raise InterfaceError("connection is closed")
|
|
|
|
def send_EXECUTE(self, cursor):
    """Write the prebuilt Execute message followed by a Flush message.

    Layout of the Execute message:
      Byte1('E') - Identifies the message as an execute message.
      Int32 -   Message length, including self.
      String -  The name of the portal to execute.
      Int32 -   Maximum number of rows to return, if portal
                contains a query that returns rows.
                0 = no limit.
    """
    for frame in (EXECUTE_MSG, FLUSH_MSG):
        self._write(frame)
|
|
|
|
def handle_NO_DATA(self, msg, ps):
    """NoData needs no action; the message is simply consumed."""
    return None
|
|
|
|
def handle_COMMAND_COMPLETE(self, data, cursor):
    """Process a CommandComplete tag such as ``b"INSERT 0 3\\x00"``.

    For commands listed in ``self._commands_with_count`` the trailing
    number is added to ``cursor._row_count`` (a count of -1 is treated as
    "not set yet" and replaced).  After ALTER or CREATE every cached
    prepared statement is closed on the server and the caches emptied.
    """
    words = data[:-1].split(b' ')  # drop the trailing NULL first
    command = words[0]

    if command in self._commands_with_count:
        affected = int(words[-1])
        if cursor._row_count == -1:
            cursor._row_count = affected
        else:
            cursor._row_count += affected

    if command not in (b"ALTER", b"CREATE"):
        return

    for style_cache in self._caches.values():
        for pid_cache in style_cache.values():
            statements = pid_cache['ps']
            for entry in statements.values():
                self.close_prepared_statement(entry['statement_name_bin'])
            statements.clear()
|
|
|
|
def handle_DATA_ROW(self, data, cursor):
    """Decode one DataRow message and append it to ``cursor._cached_rows``.

    Parsing starts at offset 2.  Each column is a packed length followed
    by that many bytes; a length of -1 means NULL and has no bytes.
    Columns are converted with the matching ``cursor.ps['input_funcs']``
    entry.
    """
    pos = 2
    row = []
    for convert in cursor.ps['input_funcs']:
        (size,) = i_unpack(data, pos)
        pos += 4
        if size == -1:
            row.append(None)
            continue
        row.append(convert(data, pos, size))
        pos += size
    cursor._cached_rows.append(row)
|
|
|
|
def handle_messages(self, cursor):
    """Read and dispatch server messages until ReadyForQuery arrives.

    Each message has a 5-byte header (type code and length); the
    remaining ``length - 4`` bytes are passed to the handler registered
    in ``self.message_types``.  ``self.error`` is cleared first, and any
    error recorded by a handler is raised once the loop has finished.
    """
    self.error = None
    code = None

    while code != READY_FOR_QUERY:
        code, data_len = ci_unpack(self._read(5))
        handler = self.message_types[code]
        handler(self._read(data_len - 4), cursor)

    if self.error is not None:
        raise self.error
|
|
|
|
def close_prepared_statement(self, statement_name_bin):
    """Close a prepared statement on the server and process the replies.

    Layout of the Close message:
      Byte1('C') - Identifies the message as a close command.
      Int32 - Message length, including self.
      Byte1 - 'S' for prepared statement, 'P' for portal.
      String - The name of the item to close.

    Replies are handled against the connection's internal cursor.
    """
    close_target = STATEMENT + statement_name_bin
    self._send_message(CLOSE, close_target)
    self._write(SYNC_MSG)
    self._flush()
    self.handle_messages(self._cursor)
|
|
|
|
def handle_NOTICE_RESPONSE(self, data, ps):
    """Store a NoticeResponse as a dict in ``self.notices``.

    Layout of the message:
      Byte1('N') - Identifier
      Int32 - Message length
      Any number of these, followed by a zero byte:
        Byte1 - code identifying the field type (see responseKeys)
        String - field value

    Keys and values are left as bytes.  Empty segments produced by the
    split are not filtered out, so the dict can contain a ``b''`` key.
    """
    notice = {
        segment[0:1]: segment[1:] for segment in data.split(NULL_BYTE)}
    self.notices.append(notice)
|
|
|
|
def handle_PARAMETER_STATUS(self, data, ps):
    """Record a ParameterStatus message and react to the settings that
    affect this connection.

    The payload is ``key NULL value NULL``.  Every pair is appended to
    ``self.parameter_statuses``; in addition:
      client_encoding   -- updates the encoding used by the converters;
      integer_datetimes -- selects the integer or float flavour of the
                           timestamp / timestamptz / interval converters;
      server_version    -- stored, and used to narrow the set of commands
                           whose completion tag carries a row count.
    """
    separator = data.find(NULL_BYTE)
    key = data[:separator]
    value = data[separator + 1:-1]
    self.parameter_statuses.append((key, value))

    if key == b"client_encoding":
        pg_name = value.decode("ascii").lower()
        self._client_encoding = pg_to_py_encodings.get(pg_name, pg_name)

    elif key == b"integer_datetimes":
        # Pick one family of converters, then register it in the same
        # order whichever family was chosen.
        if value == b'on':
            ts_send, ts_recv = (
                timestamp_send_integer, timestamp_recv_integer)
            tstz_send, tstz_recv = (
                timestamptz_send_integer, timestamptz_recv_integer)
            iv_send, iv_recv = (
                interval_send_integer, interval_recv_integer)
        else:
            ts_send, ts_recv = (
                timestamp_send_float, timestamp_recv_float)
            tstz_send, tstz_recv = (
                timestamptz_send_float, timestamptz_recv_float)
            iv_send, iv_recv = (
                interval_send_float, interval_recv_float)

        self.py_types[1114] = (1114, FC_BINARY, ts_send)
        self.pg_types[1114] = (FC_BINARY, ts_recv)

        self.py_types[1184] = (1184, FC_BINARY, tstz_send)
        self.pg_types[1184] = (FC_BINARY, tstz_recv)

        self.py_types[Interval] = (1186, FC_BINARY, iv_send)
        self.py_types[Timedelta] = (1186, FC_BINARY, iv_send)
        self.pg_types[1186] = (FC_BINARY, iv_recv)

    elif key == b"server_version":
        version = LooseVersion(value.decode('ascii'))
        self._server_version = version
        if version < LooseVersion('8.2.0'):
            self._commands_with_count = (
                b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH")
        elif version < LooseVersion('9.0.0'):
            self._commands_with_count = (
                b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH",
                b"COPY")
|
|
|
|
def array_inspect(self, value):
    """Decide how the (possibly nested) list ``value`` is sent as an array.

    The first non-None element determines the element type:

    * no such element - the array is sent as a binary text array (oid 25);
    * an ``int`` (but not a ``bool``) - every element is range-checked and
      the narrowest of INT2[], INT4[] and INT8[] that holds all of them is
      chosen;
    * anything else - ``self.make_params`` supplies the element oid, format
      code and send function, and ``pg_array_types`` the array oid. Unknown
      and string oids (705, 1043, 25) are sent as binary text arrays.

    Returns a tuple ``(array_oid, fc, send_array)`` where ``send_array(arr)``
    serialises a list in the chosen format.

    Raises ArrayContentNotSupportedError when the element type has no array
    mapping or an integer does not pass the INT8 range check. The returned
    ``send_array`` raises ArrayContentNotHomogenousError when an element is
    not an instance of the detected type.
    """
    # Check if array has any values. If empty, we can just assume it's an
    # array of strings
    first_element = array_find_first_element(value)
    if first_element is None:
        oid = 25
        # Use binary ARRAY format to avoid having to properly
        # escape text in the array literals
        fc = FC_BINARY
        array_oid = pg_array_types[oid]
        # typ and send_func stay unset on purpose: send_array only touches
        # them for non-None elements, and this array has none.
    else:
        # supported array output
        typ = type(first_element)

        # bool is a subclass of int, so it has to be excluded explicitly;
        # otherwise lists of booleans are sent as INT2[] and the BOOL[]
        # entry of pg_array_types can never be reached.
        if issubclass(typ, int) and not issubclass(typ, bool):
            # special int array support -- send as smallest possible array
            # type
            typ = int
            int2_ok, int4_ok, int8_ok = True, True, True
            for v in array_flatten(value):
                if v is None:
                    continue
                # NOTE(review): the bounds are compared strictly; whether
                # the limits themselves are representable depends on how
                # the min_/max_ constants are defined elsewhere - confirm.
                if min_int2 < v < max_int2:
                    continue
                int2_ok = False
                if min_int4 < v < max_int4:
                    continue
                int4_ok = False
                if min_int8 < v < max_int8:
                    continue
                int8_ok = False
            if int2_ok:
                array_oid = 1005  # INT2[]
                oid, fc, send_func = (21, FC_BINARY, h_pack)
            elif int4_ok:
                array_oid = 1007  # INT4[]
                oid, fc, send_func = (23, FC_BINARY, i_pack)
            elif int8_ok:
                array_oid = 1016  # INT8[]
                oid, fc, send_func = (20, FC_BINARY, q_pack)
            else:
                raise ArrayContentNotSupportedError(
                    "numeric not supported as array contents")
        else:
            try:
                oid, fc, send_func = self.make_params((first_element,))[0]

                # If unknown or string, assume it's a string array
                if oid in (705, 1043, 25):
                    oid = 25
                    # Use binary ARRAY format to avoid having to properly
                    # escape text in the array literals
                    fc = FC_BINARY
                array_oid = pg_array_types[oid]
            except KeyError:
                raise ArrayContentNotSupportedError(
                    "oid " + str(oid) + " not supported as array contents")
            except NotSupportedError:
                raise ArrayContentNotSupportedError(
                    "type " + str(typ) +
                    " not supported as array contents")

    if fc == FC_BINARY:
        def send_array(arr):
            # check that all array dimensions are consistent
            array_check_dimensions(arr)

            has_null = array_has_null(arr)
            dim_lengths = array_dim_lengths(arr)
            # Header: number of dimensions, null flag, element oid.
            data = bytearray(iii_pack(len(dim_lengths), has_null, oid))
            # One (length, 1) pair per dimension.
            for i in dim_lengths:
                data.extend(ii_pack(i, 1))
            for v in array_flatten(arr):
                if v is None:
                    # None is encoded as a length of -1 with no payload.
                    data += i_pack(-1)
                elif isinstance(v, typ):
                    inner_data = send_func(v)
                    data += i_pack(len(inner_data))
                    data += inner_data
                else:
                    raise ArrayContentNotHomogenousError(
                        "not all array elements are of type " + str(typ))
            return data
    else:
        def send_array(arr):
            array_check_dimensions(arr)
            # Work on a copy so the caller's list is not modified.
            ar = deepcopy(arr)
            for a, i, v in walk_array(ar):
                if v is None:
                    a[i] = 'NULL'
                elif isinstance(v, typ):
                    a[i] = send_func(v).decode('ascii')
                else:
                    raise ArrayContentNotHomogenousError(
                        "not all array elements are of type " + str(typ))
            return str(ar).translate(arr_trans).encode('ascii')

    return (array_oid, fc, send_array)
|
|
|
|
def xid(self, format_id, global_transaction_id, branch_qualifier):
    """Create a transaction ID for the two-phase commit (TPC) methods.

    The arguments are returned unchanged as the tuple
    ``(format_id, global_transaction_id, branch_qualifier)``.

    PostgreSQL only uses ``global_transaction_id``, which may be any string
    identifier the server accepts; ``format_id`` and ``branch_qualifier``
    are carried along but not used.
    """
    transaction_id = (format_id, global_transaction_id, branch_qualifier)
    return transaction_id
|
|
|
|
def tpc_begin(self, xid):
    """Begin a TPC transaction with the given transaction ID ``xid``.

    ``xid`` is remembered in ``self._xid``. When the connection is in
    autocommit mode a ``begin transaction`` statement is executed
    explicitly; otherwise nothing is sent.

    According to the DBAPI 2.0 specification this method should be called
    outside of a transaction (i.e. nothing may have executed since the last
    .commit() or .rollback()), and calling .commit() or .rollback() during
    an active TPC transaction is an error.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    self._xid = xid
    if not self.autocommit:
        return
    self.execute(self._cursor, "begin transaction", None)
|
|
|
|
def tpc_prepare(self):
    """Perform the first phase of a transaction started with .tpc_begin().

    Executes ``PREPARE TRANSACTION`` using the global transaction id (the
    second item) of the xid stored by .tpc_begin().

    A ProgrammingError is raised if this method is called outside of a TPC
    transaction, i.e. when no xid is stored on the connection.

    After calling .tpc_prepare(), no statements can be executed until
    .tpc_commit() or .tpc_rollback() have been called.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    if self._xid is None:
        raise ProgrammingError(
            "Cannot tpc_prepare() without a TPC transaction!")

    # The transaction id has to be written as a string literal, so embedded
    # single quotes are doubled to keep the literal well formed.
    # NOTE(review): this assumes standard_conforming_strings is on, so that
    # backslashes need no escaping - confirm for very old servers.
    gid = str(self._xid[1]).replace("'", "''")
    q = "PREPARE TRANSACTION '%s';" % (gid,)
    self.execute(self._cursor, q, None)
|
|
|
|
def tpc_commit(self, xid=None):
    """Commit a TPC transaction.

    When called with no arguments, the xid stored by .tpc_begin() is used.
    When called with a transaction ID ``xid``, that transaction is
    committed; this form should be called outside of a transaction and is
    intended for use in recovery.

    If the transaction's global transaction id is among the prepared
    transactions reported by .tpc_recover(), ``COMMIT PREPARED`` is
    executed. Otherwise (for example when .tpc_commit() is called prior to
    .tpc_prepare()) a single phase commit is performed through .commit().

    Only the global transaction id (the second item of the xid) is
    compared, because .tpc_recover() cannot recover the format id and
    branch qualifier the xid was created with.

    A ProgrammingError is raised if no xid is given and none is stored on
    the connection. The autocommit mode is forced on for the duration of
    the call and restored afterwards, even on error. On successful return
    the TPC transaction is ended.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    if xid is None:
        xid = self._xid

    if xid is None:
        raise ProgrammingError(
            "Cannot tpc_commit() without a TPC transaction!")

    # Saved before the try block so the finally clause can always restore.
    previous_autocommit_mode = self.autocommit
    try:
        self.autocommit = True
        gid = xid[1]
        if any(pending[1] == gid for pending in self.tpc_recover()):
            # The transaction id has to be a string literal, so embedded
            # single quotes are doubled to keep the literal well formed.
            escaped_gid = str(gid).replace("'", "''")
            self.execute(
                self._cursor, "COMMIT PREPARED '%s';" % (escaped_gid,),
                None)
        else:
            # a single-phase commit
            self.commit()
    finally:
        self.autocommit = previous_autocommit_mode
    self._xid = None
|
|
|
|
def tpc_rollback(self, xid=None):
    """Roll back a TPC transaction.

    When called with no arguments, the xid stored by .tpc_begin() is used;
    the method may be called before or after .tpc_prepare(). When called
    with a transaction ID ``xid``, that transaction is rolled back; this
    form should be called outside of a transaction and is intended for use
    in recovery.

    If the transaction's global transaction id is among the prepared
    transactions reported by .tpc_recover(), ``ROLLBACK PREPARED`` is
    executed. Otherwise a single phase rollback is performed through
    .rollback().

    Only the global transaction id (the second item of the xid) is
    compared, because .tpc_recover() cannot recover the format id and
    branch qualifier the xid was created with.

    A ProgrammingError is raised if no xid is given and none is stored on
    the connection. The autocommit mode is forced on for the duration of
    the call and restored afterwards, even on error. On successful return
    the TPC transaction is ended.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    if xid is None:
        xid = self._xid

    if xid is None:
        raise ProgrammingError(
            "Cannot tpc_rollback() without a TPC prepared transaction!")

    # Saved before the try block so the finally clause can always restore.
    previous_autocommit_mode = self.autocommit
    try:
        self.autocommit = True
        gid = xid[1]
        if any(pending[1] == gid for pending in self.tpc_recover()):
            # a two-phase rollback
            # The transaction id has to be a string literal, so embedded
            # single quotes are doubled to keep the literal well formed.
            escaped_gid = str(gid).replace("'", "''")
            self.execute(
                self._cursor, "ROLLBACK PREPARED '%s';" % (escaped_gid,),
                None)
        else:
            # a single-phase rollback
            self.rollback()
    finally:
        self.autocommit = previous_autocommit_mode
    self._xid = None
|
|
|
|
def tpc_recover(self):
    """Return a list of pending transaction IDs.

    The ids are read from the ``gid`` column of ``pg_prepared_xacts`` and
    wrapped with ``self.xid(0, gid, '')``, making them suitable for use
    with .tpc_commit(xid) or .tpc_rollback(xid). Autocommit is switched on
    while the query runs and restored to its previous value afterwards.

    This function is part of the `DBAPI 2.0 specification
    <http://www.python.org/dev/peps/pep-0249/>`_.
    """
    try:
        saved_mode = self.autocommit
        self.autocommit = True
        cursor = self.cursor()
        cursor.execute("select gid FROM pg_prepared_xacts")
        pending = []
        for row in cursor:
            # Only the gid is known; format id and branch qualifier are
            # filled with fixed placeholders.
            pending.append(self.xid(0, row[0], ''))
        return pending
    finally:
        self.autocommit = saved_mode
|
|
|
|
|
|
# Maps the oid of a PostgreSQL element type to the oid of the array type
# that is used when a list of such elements is sent to the server.
pg_array_types = {
    16: 1000,  # BOOL -> BOOL[]
    25: 1009,  # TEXT -> TEXT[]
    701: 1022,  # FLOAT8 -> FLOAT8[]
    1043: 1009,  # VARCHAR elements are sent as TEXT[]
    1700: 1231,  # NUMERIC -> NUMERIC[]
}
|
|
|
|
|
|
# PostgreSQL encodings:
#   https://www.postgresql.org/docs/current/multibyte.html
# Python encodings:
#   https://docs.python.org/3/library/codecs.html#standard-encodings
#
# Keys are lower-cased PostgreSQL encoding names, values are the matching
# Python codec names. Commented out encodings don't require a name change
# between PostgreSQL and Python. If the py side is None, then the encoding
# isn't supported.
pg_to_py_encodings = {
    # Not supported:
    "mule_internal": None,
    "euc_tw": None,

    # Name fine as-is:
    # "euc_jp",
    # "euc_jis_2004",
    # "euc_kr",
    # "gb18030",
    # "gbk",
    # "johab",
    # "sjis",
    # "shift_jis_2004",
    # "uhc",
    # "utf8",

    # Different name:
    "euc_cn": "gb2312",
    "iso_8859_5": "iso8859_5",
    "iso_8859_6": "iso8859_6",
    "iso_8859_7": "iso8859_7",
    "iso_8859_8": "iso8859_8",
    "koi8": "koi8_r",
    "koi8r": "koi8_r",
    "koi8u": "koi8_u",
    "latin1": "iso8859-1",
    "latin2": "iso8859_2",
    "latin3": "iso8859_3",
    "latin4": "iso8859_4",
    "latin5": "iso8859_9",
    "latin6": "iso8859_10",
    "latin7": "iso8859_13",
    "latin8": "iso8859_14",
    "latin9": "iso8859_15",
    "sql_ascii": "ascii",
    "win866": "cp866",
    "win874": "cp874",
    "win1250": "cp1250",
    "win1251": "cp1251",
    "win1252": "cp1252",
    "win1253": "cp1253",
    "win1254": "cp1254",
    "win1255": "cp1255",
    "win1256": "cp1256",
    "win1257": "cp1257",
    "win1258": "cp1258",
    "unicode": "utf-8",  # Needed for Amazon Redshift
}
|
|
|
|
|
|
def walk_array(arr):
    """Yield ``(container, index, value)`` for each non-list item in ``arr``.

    Nested lists are descended into depth-first. ``container`` is the list
    that directly holds ``value`` at position ``index``, so a caller can
    replace the item in place with ``container[index] = ...``.
    """
    for index, item in enumerate(arr):
        if not isinstance(item, list):
            yield arr, index, item
            continue
        yield from walk_array(item)
|
|
|
|
|
|
def array_find_first_element(arr):
    """Return the first item of the flattened ``arr`` that is not None.

    Returns None when the array is empty or contains nothing but None.
    """
    return next((v for v in array_flatten(arr) if v is not None), None)
|
|
|
|
|
|
def array_flatten(arr):
    """Yield the non-list items of ``arr`` depth-first, in order.

    Only ``list`` instances are descended into; any other item (including
    None and tuples) is yielded as is.
    """
    for item in arr:
        if not isinstance(item, list):
            yield item
            continue
        yield from array_flatten(item)
|
|
|
|
|
|
def array_check_dimensions(arr):
    """Verify that the nested list ``arr`` is rectangular.

    Returns the list of lengths of the dimensions below the outermost one
    (an empty list for an empty or one-dimensional array).

    Raises ArrayDimensionsNotConsistentError when sub-lists at the same
    depth differ in length or nesting, or when lists and non-list values
    are mixed at the same level.
    """
    if len(arr) == 0:
        return []

    v0 = arr[0]
    if isinstance(v0, list):
        req_len = len(v0)
        req_inner_lengths = array_check_dimensions(v0)
        for v in arr:
            # A scalar next to a sub-list is inconsistent. Without this
            # check len(v) either fails with a TypeError or, for strings,
            # silently treats the string as a row of the array.
            if not isinstance(v, list):
                raise ArrayDimensionsNotConsistentError(
                    "array dimensions not consistent")
            inner_lengths = array_check_dimensions(v)
            if len(v) != req_len or inner_lengths != req_inner_lengths:
                raise ArrayDimensionsNotConsistentError(
                    "array dimensions not consistent")
        retval = [req_len]
        retval.extend(req_inner_lengths)
        return retval

    # make sure nothing else at this level is a list
    for v in arr:
        if isinstance(v, list):
            raise ArrayDimensionsNotConsistentError(
                "array dimensions not consistent")
    return []
|
|
|
|
|
|
def array_has_null(arr):
    """Return True if any item of the flattened ``arr`` is None."""
    return any(v is None for v in array_flatten(arr))
|
|
|
|
|
|
def array_dim_lengths(arr):
    """Return the length of each dimension of the nested list ``arr``.

    The lengths are taken by following the first item at every level, so
    the array is expected to be rectangular. Descent stops at the first
    level that is empty or whose first item is not a list.
    """
    lengths = []
    level = arr
    while True:
        size = len(level)
        lengths.append(size)
        if size == 0 or not isinstance(level[0], list):
            return lengths
        level = level[0]
|