zaz/source/libraries/pg8000/scramp/core.py

355 lines
11 KiB
Python

import hmac
from uuid import uuid4
from base64 import b64encode, b64decode
import hashlib
from stringprep import (
in_table_a1, in_table_b1, in_table_c21_c22, in_table_c3, in_table_c4,
in_table_c5, in_table_c6, in_table_c7, in_table_c8, in_table_c9,
in_table_c12, in_table_d1, in_table_d2)
import unicodedata
from os import urandom
from enum import IntEnum, unique
# https://tools.ietf.org/html/rfc5802
# https://www.rfc-editor.org/rfc/rfc7677.txt
@unique
class ClientStage(IntEnum):
    """Ordered steps of the client side of a SCRAM exchange.

    Each member is named after the ``ScramClient`` method that performs the
    step, and the values increase by exactly one per step, which is what
    ``_check_stage`` uses to enforce the calling order.
    """

    # Produce the client-first message.
    get_client_first = 1
    # Consume the server-first message.
    set_server_first = 2
    # Produce the client-final message (carries the client proof).
    get_client_final = 3
    # Consume and verify the server-final message.
    set_server_final = 4
@unique
class ServerStage(IntEnum):
    """Ordered steps of the server side of a SCRAM exchange.

    Each member is named after the ``ScramServer`` method that performs the
    step, and the values increase by exactly one per step, which is what
    ``_check_stage`` uses to enforce the calling order.
    """

    # Consume the client-first message.
    set_client_first = 1
    # Produce the server-first message (nonce, salt, iteration count).
    get_server_first = 2
    # Consume and verify the client-final message.
    set_client_final = 3
    # Produce the server-final message (server signature).
    get_server_final = 4
def _check_stage(Stages, current_stage, next_stage):
    """Check that ``next_stage`` is allowed to follow ``current_stage``.

    Stage values are assumed to run contiguously from 1 to ``len(Stages)``,
    as they do in ``ClientStage`` and ``ServerStage``.

    :param Stages: the stage enum; its member names are the method names
        reported in error messages.
    :param current_stage: the last completed stage, or None if no stage
        method has been called yet.
    :param next_stage: the stage the caller is about to enter.
    :raises ScramException: if the first call is not stage 1, if the final
        stage has already been reached, or if ``next_stage`` is not exactly
        one after ``current_stage``.
    """
    if current_stage is None:
        if next_stage != 1:
            raise ScramException(
                "The method " + Stages(1).name + " must be called first.")
    elif current_stage == len(Stages):
        raise ScramException(
            "The authentication sequence has already finished.")
    elif next_stage != current_stage + 1:
        # ``.name`` is required: concatenating the enum member itself to a
        # str would raise TypeError instead of this ScramException.
        raise ScramException(
            "The next method to be called is " +
            Stages(current_stage + 1).name + ", not this method.")
class ScramException(Exception):
    """Raised when a SCRAM exchange fails.

    Used both for protocol misuse (stage methods called out of order) and
    for authentication failures such as a nonce, proof or server-signature
    mismatch.
    """
# SCRAM mechanism names this module understands, weakest first; ScramClient
# keeps the last one the server also offers.
MECHANISMS = (
    'SCRAM-SHA-1',
    'SCRAM-SHA-256',
)
# Hash constructor for each mechanism (RFC 5802 for SHA-1, RFC 7677 for
# SHA-256), keyed in the same order as MECHANISMS.
HASHES = dict(zip(MECHANISMS, (hashlib.sha1, hashlib.sha256)))
class ScramClient():
    """Client side of a SCRAM authentication exchange (RFC 5802 / RFC 7677).

    The public methods must be called in this order, alternating with the
    server: ``get_client_first``, ``set_server_first``,
    ``get_client_final``, ``set_server_final``.  Calling them out of order
    raises ``ScramException``.
    """

    def __init__(self, mechanisms, username, password, c_nonce=None):
        """Pick a mechanism and set up the client state.

        :param mechanisms: mechanism names offered by the server.  The last
            entry of ``MECHANISMS`` found in it is used, so SCRAM-SHA-256 is
            preferred over SCRAM-SHA-1.
        :param username: user name sent in the client-first message.
        :param password: password used to compute the client proof.
        :param c_nonce: client nonce; a random one is generated when None.
        :raises ScramException: if none of ``MECHANISMS`` is offered.
        """
        self.mech = None
        # MECHANISMS is ordered weakest-first, so the last match wins.
        for mech in MECHANISMS:
            if mech in mechanisms:
                self.mech = mech
        if self.mech is None:
            # str() so a list/tuple of offered names can be reported instead
            # of raising TypeError during concatenation.
            raise ScramException(
                "The only recognized mechanisms are " + str(MECHANISMS) +
                " and none of those can be found in " + str(mechanisms) +
                ".")
        self.hf = HASHES[self.mech]
        if c_nonce is None:
            self.c_nonce = _make_nonce()
        else:
            self.c_nonce = c_nonce
        self.username = username
        self.password = password
        # None until the first stage method has been called.
        self.stage = None

    def _set_stage(self, next_stage):
        """Advance to ``next_stage``; raise ScramException if out of order."""
        _check_stage(ClientStage, self.stage, next_stage)
        self.stage = next_stage

    def get_client_first(self):
        """Return the client-first message (GS2 header plus bare message)."""
        self._set_stage(ClientStage.get_client_first)
        self.client_first_bare, client_first = _get_client_first(
            self.username, self.c_nonce)
        return client_first

    def set_server_first(self, message):
        """Process the server-first message.

        :raises ScramException: if the server's nonce does not start with
            this client's nonce.
        """
        self._set_stage(ClientStage.set_server_first)
        self.server_first = message
        (self.auth_message, self.nonce, self.salt,
         self.iterations) = _set_server_first(
            message, self.c_nonce, self.client_first_bare)

    def get_client_final(self):
        """Return the client-final message carrying the client proof."""
        self._set_stage(ClientStage.get_client_final)
        # Keep the expected server signature to check in set_server_final.
        self.server_signature, cfinal = _get_client_final(
            self.hf, self.password, self.salt, self.iterations, self.nonce,
            self.auth_message)
        return cfinal

    def set_server_final(self, message):
        """Verify the server-final message.

        :raises ScramException: if the server signature doesn't match.
        """
        self._set_stage(ClientStage.set_server_final)
        _set_server_final(message, self.server_signature)
class ScramServer():
    """Server side of a SCRAM authentication exchange.

    The public methods must be called in the order ``set_client_first``,
    ``get_server_first``, ``set_client_final``, ``get_server_final``;
    otherwise ``ScramException`` is raised.

    :param password_fn: callable taking the user name from the client-first
        message and returning that user's password.
    :param s_nonce: server nonce; a random one is generated when None.
    :param iterations: iteration count sent to the client as ``i=``.
    :param salt: base64-encoded salt; 16 random bytes are used when None.
    :param mechanism: one of ``MECHANISMS``.
    :raises ScramException: if ``mechanism`` is not recognized.
    """

    def __init__(
            self, password_fn, s_nonce=None, iterations=4096, salt=None,
            mechanism='SCRAM-SHA-256'):
        if mechanism not in MECHANISMS:
            raise ScramException(
                "The only recognized mechanisms are " + str(MECHANISMS) +
                ".")
        self.mechanism = mechanism
        self.hf = HASHES[mechanism]
        self.s_nonce = _make_nonce() if s_nonce is None else s_nonce
        self.salt = _b64enc(urandom(16)) if salt is None else salt
        self.password_fn = password_fn
        self.iterations = iterations
        # None until the first stage method has been called.
        self.stage = None

    def _set_stage(self, next_stage):
        """Advance to ``next_stage``; raise ScramException if out of order."""
        _check_stage(ServerStage, self.stage, next_stage)
        self.stage = next_stage

    def set_client_first(self, client_first):
        """Parse the client-first message and fetch the user's password."""
        self._set_stage(ServerStage.set_client_first)
        parsed = _set_client_first(client_first, self.s_nonce)
        self.nonce, self.user, self.client_first_bare = parsed
        self.password = self.password_fn(self.user)

    def get_server_first(self):
        """Return the server-first message (nonce, salt, iterations)."""
        self._set_stage(ServerStage.get_server_first)
        auth_message, server_first = _get_server_first(
            self.nonce, self.salt, self.iterations, self.client_first_bare)
        self.auth_message = auth_message
        return server_first

    def set_client_final(self, client_final):
        """Verify the client-final message and remember our signature."""
        self._set_stage(ServerStage.set_client_final)
        signature = _set_client_final(
            self.hf, client_final, self.s_nonce, self.password, self.salt,
            self.iterations, self.auth_message)
        self.server_signature = signature

    def get_server_final(self):
        """Return the server-final message carrying the server signature."""
        self._set_stage(ServerStage.get_server_final)
        return _get_server_final(self.server_signature)
def _make_nonce():
    """Return a fresh random nonce as 32 lowercase hex characters.

    This is the hex form of a random UUID, i.e. the canonical UUID string
    without its hyphens.
    """
    return uuid4().hex
def _make_auth_message(nonce, client_first_bare, server_first):
    """Build the SCRAM AuthMessage that both sides sign.

    It is the client-first-message-bare, the server-first-message and the
    client-final-message-without-proof joined by commas.  The channel
    binding attribute is always the base64 of the GS2 header ``n,,``.
    """
    channel_binding = b64encode(b'n,,').decode('utf8')
    final_without_proof = 'c=' + channel_binding + ',r=' + nonce
    return (
        client_first_bare + ',' + server_first + ',' + final_without_proof)
def _proof_signature(hf, password, salt, iterations, auth_msg):
    """Compute the client proof and server signature for ``auth_msg``.

    :param hf: hash constructor for the mechanism.
    :param password: password str; it is SASLprep'd and UTF-8 encoded.
    :param salt: base64-encoded salt.
    :param iterations: iteration count passed to ``_hi``.
    :param auth_msg: the AuthMessage str.
    :returns: ``(client_proof, server_signature)``, both base64 str.
    """
    prepared = _uenc(saslprep(password))
    salted = _hi(hf, prepared, _b64dec(salt), iterations)

    # ClientProof = ClientKey XOR HMAC(H(ClientKey), AuthMessage)
    client_key = _hmac(hf, salted, b"Client Key")
    auth_bytes = _uenc(auth_msg)
    client_signature = _hmac(hf, _h(hf, client_key), auth_bytes)
    proof = _xor(client_key, client_signature)

    # ServerSignature = HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)
    server_key = _hmac(hf, salted, b"Server Key")
    signature = _hmac(hf, server_key, auth_bytes)

    return _b64enc(proof), _b64enc(signature)
def _hmac(hf, key, msg):
    """Return the raw HMAC of ``msg`` keyed with ``key`` using hash ``hf``."""
    mac = hmac.new(key, msg, hf)
    return mac.digest()
def _h(hf, msg):
    """Return the raw digest of ``msg`` computed with hash constructor ``hf``."""
    hasher = hf(msg)
    return hasher.digest()
def _hi(hf, password, salt, iterations):
    """Compute SCRAM's ``Hi(password, salt, iterations)`` (RFC 5802 2.2).

    ``Hi`` is PBKDF2 with HMAC-``hf`` as the PRF and an output length of one
    digest, so the C implementation ``hashlib.pbkdf2_hmac`` is used when
    ``hf`` is a hashlib constructor.  A pure-Python loop is kept as a
    fallback for hash callables hashlib cannot name or does not support.

    :param hf: hash constructor, e.g. ``hashlib.sha256``.
    :param password: password as bytes.
    :param salt: raw salt bytes.
    :param iterations: iteration count; values below 1 behave like 1
        (only the first block U1 is computed).
    :returns: the salted password, one digest long.
    """
    # With iterations < 1 no extra rounds run, i.e. PBKDF2 with c = 1.
    rounds = max(iterations, 1)

    try:
        name = hf().name
    except (AttributeError, TypeError):
        name = None
    if name is not None:
        try:
            return hashlib.pbkdf2_hmac(name, password, salt, rounds)
        except ValueError:
            # Hash not supported by hashlib's PBKDF2; use the slow path.
            pass

    u = ui = hmac.new(password, salt + b'\x00\x00\x00\x01', hf).digest()
    for _ in range(rounds - 1):
        ui = hmac.new(password, ui, hf).digest()
        u = bytes(a ^ b for a, b in zip(u, ui))
    return u
def _hi_iter(password, mac, iterations, hf=hashlib.sha256):
    """Continue the ``Hi`` chain from an already computed first block.

    Returns ``mac XOR U2 XOR ... XOR U(iterations + 1)`` where ``mac`` is
    ``U1`` and ``U(k+1) = HMAC(password, U(k))``, so
    ``_hi_iter(password, U1, c - 1)`` equals ``Hi(password, salt, c)``.

    Implemented as a loop: the earlier recursive form called ``_hmac``
    without its hash argument (TypeError) and would also exceed Python's
    recursion limit for typical iteration counts.

    :param password: HMAC key as bytes.
    :param mac: the first block ``U1`` as bytes.
    :param iterations: number of further HMAC rounds; values <= 0 return
        ``mac`` unchanged.
    :param hf: hash constructor for the HMAC; defaults to SHA-256.
    :returns: the XOR of all blocks, truncated to the shortest block.
    """
    result = ui = mac
    for _ in range(iterations):
        ui = hmac.new(password, ui, hf).digest()
        result = bytes(a ^ b for a, b in zip(result, ui))
    return result
def _parse_message(msg):
    """Split a SCRAM message into an ``{attribute: value}`` dict.

    Each comma-separated item maps its first character to everything from
    its third character on (the second is assumed to be ``=``).  Items
    shorter than two characters, such as the ``n`` of a GS2 header or empty
    fields, are skipped.  Later duplicates overwrite earlier ones.
    """
    return {item[0]: item[2:] for item in msg.split(',') if len(item) > 1}
def _b64enc(binary):
    """Base64-encode ``binary`` and return the result as a str."""
    encoded = b64encode(binary)
    return encoded.decode('utf8')
def _b64dec(string):
    """Decode base64 ``string`` (str or bytes) into raw bytes."""
    decoded = b64decode(string)
    return decoded
def _uenc(string):
    """Encode the str ``string`` as UTF-8 bytes."""
    return string.encode(encoding='utf-8')
def _xor(bytes1, bytes2):
    """XOR two byte strings position by position.

    The result is as long as the shorter input, because ``zip`` stops at
    the end of the shorter sequence.
    """
    return bytes([left ^ right for left, right in zip(bytes1, bytes2)])
def _get_client_first(username, c_nonce):
    """Build the client-first message.

    :param username: user name; it is SASLprep'd, then ``=`` and ``,`` are
        escaped as ``=3D`` and ``=2C`` as RFC 5802 section 5.1 requires, so
        such characters cannot break the comma-separated framing.
    :param c_nonce: client nonce.
    :returns: ``(client_first_bare, client_first)`` where ``client_first``
        is the bare message prefixed with the GS2 header ``n,,``.
    """
    # Escape '=' first so the '=' introduced by '=2C' is not escaped again.
    safe_user = saslprep(username).replace('=', '=3D').replace(',', '=2C')
    bare = 'n=' + safe_user + ',r=' + c_nonce
    return bare, 'n,,' + bare
def _set_client_first(client_first, s_nonce):
    """Parse the client-first message on the server side.

    :param client_first: the full client-first message including its GS2
        header, e.g. ``n,,n=user,r=nonce``.
    :param s_nonce: server nonce appended to the client nonce.
    :returns: ``(nonce, user, client_first_bare)`` where ``nonce`` is the
        client nonce followed by ``s_nonce`` and ``user`` has the RFC 5802
        ``=2C``/``=3D`` escapes decoded.
    :raises KeyError: if the ``r`` or ``n`` attribute is missing.
    """
    # The GS2 header is "<cbind-flag>,[authzid]," so the bare message starts
    # after the second comma; a fixed 3-character slice breaks as soon as an
    # authzid is present.
    client_first_bare = client_first.split(',', 2)[-1]
    msg = _parse_message(client_first_bare)
    nonce = msg['r'] + s_nonce
    # Undo the escaping in reverse order: '=2C' first, then '=3D'.
    user = msg['n'].replace('=2C', ',').replace('=3D', '=')
    return nonce, user, client_first_bare
def _get_server_first(nonce, salt, iterations, client_first_bare):
    """Build the server-first message and the AuthMessage derived from it.

    :returns: ``(auth_msg, server_first)`` where ``server_first`` is
        ``r=<nonce>,s=<salt>,i=<iterations>``.
    """
    server_first = 'r=' + nonce + ',s=' + salt + ',i=' + str(iterations)
    auth_msg = _make_auth_message(nonce, client_first_bare, server_first)
    return auth_msg, server_first
def _set_server_first(server_first, c_nonce, client_first_bare):
    """Parse the server-first message on the client side.

    :param server_first: message of the form ``r=<nonce>,s=<salt>,i=<n>``.
    :param c_nonce: the nonce this client sent.
    :param client_first_bare: the bare client-first message, needed to
        build the AuthMessage.
    :returns: ``(auth_msg, nonce, salt, iterations)`` with ``iterations``
        converted to int.
    :raises ScramException: if the combined nonce does not start with
        ``c_nonce``.
    """
    attrs = _parse_message(server_first)
    nonce, salt = attrs['r'], attrs['s']
    iterations = int(attrs['i'])
    if nonce.startswith(c_nonce):
        auth_msg = _make_auth_message(nonce, client_first_bare, server_first)
        return auth_msg, nonce, salt, iterations
    raise ScramException("Client nonce doesn't match.")
def _get_client_final(hf, password, salt, iterations, nonce, auth_msg):
    """Build the client-final message.

    :returns: ``(server_signature, client_final)``; the expected server
        signature is returned so the client can check the server-final
        message later.
    """
    proof, expected_signature = _proof_signature(
        hf, password, salt, iterations, auth_msg)
    client_final = 'c=' + _b64enc(b'n,,') + ',r=' + nonce + ',p=' + proof
    return expected_signature, client_final
def _set_client_final(
        hf, client_final, s_nonce, password, salt, iterations, auth_msg):
    """Verify the client-final message on the server side.

    :param hf: hash constructor for the mechanism.
    :param client_final: message of the form ``c=...,r=<nonce>,p=<proof>``.
    :param s_nonce: this server's nonce, which must end the client's nonce.
    :param password: the user's password.
    :param salt: base64-encoded salt.
    :param iterations: iteration count sent to the client.
    :param auth_msg: the AuthMessage computed by the server.
    :returns: the base64 server signature for the server-final message.
    :raises ScramException: on a nonce or proof mismatch.
    :raises KeyError: if the ``r`` or ``p`` attribute is missing.
    """
    msg = _parse_message(client_final)
    nonce = msg['r']
    proof = msg['p']
    if not nonce.endswith(s_nonce):
        raise ScramException("Server nonce doesn't match.")
    client_proof, server_signature = _proof_signature(
        hf, password, salt, iterations, auth_msg)
    # Constant-time comparison so response timing does not reveal how much
    # of an attacker-supplied proof is correct.  Compare bytes because
    # compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(
            client_proof.encode('utf-8'), proof.encode('utf-8')):
        raise ScramException("The proofs don't match")
    return server_signature
def _get_server_final(server_signature):
    """Return the server-final message ``v=<server_signature>``."""
    return ''.join(('v=', server_signature))
def _set_server_final(message, server_signature):
    """Verify the server-final message on the client side.

    :param message: the server-final message, either ``v=<signature>`` or
        ``e=<server-error-value>``.
    :param server_signature: the base64 signature the client expects.
    :raises ScramException: if the server reports an error, sends no
        signature, or sends a signature that doesn't match.
    """
    msg = _parse_message(message)
    if 'e' in msg:
        # RFC 5802 lets the server answer with an error instead of a
        # signature; report it rather than failing with a bare KeyError.
        raise ScramException("The server reported an error: " + msg['e'])
    if 'v' not in msg:
        raise ScramException(
            "The server final message doesn't contain a signature.")
    if not hmac.compare_digest(
            server_signature.encode('utf-8'), msg['v'].encode('utf-8')):
        raise ScramException("The server signature doesn't match.")
def saslprep(source):
    """Prepare ``source`` with the SASLprep stringprep profile (RFC 4013).

    Characters in table B.1 are dropped, non-ASCII spaces (table C.1.2)
    become U+0020, and the result is NFKC-normalized.  The normalized
    string is then checked for prohibited characters and bidi rules.

    :param source: the str to prepare (a user name or password).
    :returns: the prepared str, or ``''`` if nothing is left.
    :raises ValueError: on a prohibited character or a malformed
        bidirectional string.
    """
    # Mapping stage.
    kept = []
    for char in source:
        if in_table_b1(char):
            continue
        kept.append(' ' if in_table_c12(char) else char)
    prepared = unicodedata.normalize('NFKC', ''.join(kept))
    if not prepared:
        return ''

    # Bidi rules: a string holding any R/AL char (table D.1) must start and
    # end with one and hold no L char (table D.2).  If the first char is not
    # R/AL, no R/AL char is allowed anywhere; the per-char loop below
    # enforces whichever prohibition applies.
    if in_table_d1(prepared[0]):
        if not in_table_d1(prepared[-1]):
            raise ValueError("malformed bidi sequence")
        bidi_forbidden = in_table_d2
    else:
        bidi_forbidden = in_table_d1

    # Prohibited output: tables A.1, C.2 - C.9, then the bidi rule above.
    prohibited = (
        (in_table_a1, "unassigned code points forbidden"),
        (in_table_c21_c22, "control characters forbidden"),
        (in_table_c3, "private use characters forbidden"),
        (in_table_c4, "non-char code points forbidden"),
        (in_table_c5, "surrogate codes forbidden"),
        (in_table_c6, "non-plaintext chars forbidden"),
        (in_table_c7, "non-canonical chars forbidden"),
        (in_table_c8, "display-modifying/deprecated chars forbidden"),
        (in_table_c9, "tagged characters forbidden"),
        (bidi_forbidden, "forbidden bidi character"),
    )
    for char in prepared:
        # Sanity checks: the mapping stage should already have removed these.
        assert not in_table_b1(char), "failed to strip B.1 in mapping stage"
        assert not in_table_c12(char), \
            "failed to replace C.1.2 in mapping stage"
        for is_prohibited, message in prohibited:
            if is_prohibited(char):
                raise ValueError(message)
    return prepared