355 lines
11 KiB
Python
355 lines
11 KiB
Python
import hmac
|
|
from uuid import uuid4
|
|
from base64 import b64encode, b64decode
|
|
import hashlib
|
|
from stringprep import (
|
|
in_table_a1, in_table_b1, in_table_c21_c22, in_table_c3, in_table_c4,
|
|
in_table_c5, in_table_c6, in_table_c7, in_table_c8, in_table_c9,
|
|
in_table_c12, in_table_d1, in_table_d2)
|
|
import unicodedata
|
|
from os import urandom
|
|
from enum import IntEnum, unique
|
|
|
|
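# SCRAM (Salted Challenge Response Authentication Mechanism) is specified
# in RFC 5802: https://tools.ietf.org/html/rfc5802
# SCRAM-SHA-256 is specified in RFC 7677:
# https://www.rfc-editor.org/rfc/rfc7677.txt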
|
|
|
|
|
|
@unique
class ClientStage(IntEnum):
    """Steps of a SCRAM exchange as seen by the client, in call order."""

    # Step 1: produce the client-first message.
    get_client_first = 1
    # Step 2: consume the server-first message.
    set_server_first = 2
    # Step 3: produce the client-final message (with the proof).
    get_client_final = 3
    # Step 4: consume and verify the server-final message.
    set_server_final = 4
|
|
|
|
|
|
@unique
class ServerStage(IntEnum):
    """Steps of a SCRAM exchange as seen by the server, in call order."""

    # Step 1: consume the client-first message.
    set_client_first = 1
    # Step 2: produce the server-first message.
    get_server_first = 2
    # Step 3: consume and verify the client-final message.
    set_client_final = 3
    # Step 4: produce the server-final message.
    get_server_final = 4
|
|
|
|
|
|
def _check_stage(Stages, current_stage, next_stage):
    """Check that a SCRAM step is being invoked in the correct order.

    :param Stages: the ``IntEnum`` class listing the steps of the exchange,
        numbered consecutively from 1.
    :param current_stage: the step performed last, or ``None`` if no step
        has been performed yet.
    :param next_stage: the step about to be performed.
    :raises ScramException: if the exchange has not started and
        ``next_stage`` is not step 1, if the exchange has already reached
        its final step, or if ``next_stage`` does not directly follow
        ``current_stage``.
    """
    # The last member is the final step; derived rather than hard-coded so
    # the check stays right if a stage enum ever gains or loses a step.
    final_stage = max(Stages)

    if current_stage is None:
        if next_stage != 1:
            raise ScramException(
                "The method " + Stages(1).name + " must be called first.")
    elif current_stage == final_stage:
        raise ScramException(
            "The authentication sequence has already finished.")
    elif next_stage != current_stage + 1:
        # Use .name: concatenating the enum member itself raised TypeError
        # instead of the intended ScramException.
        raise ScramException(
            "The next method to be called is " +
            Stages(current_stage + 1).name + ", not this method.")
|
|
|
|
|
|
class ScramException(Exception):
    """Error raised for any failure of a SCRAM exchange.

    Covers protocol misuse (methods called out of order, unsupported
    mechanism) as well as verification failures (nonce, proof or server
    signature mismatch).
    """
|
|
|
|
|
|
# Hash constructor used by each supported SCRAM mechanism.
HASHES = {
    'SCRAM-SHA-1': hashlib.sha1,
    'SCRAM-SHA-256': hashlib.sha256,
}

# Supported mechanism names, weakest first. ScramClient depends on this
# order: it selects the last entry that the server also offers.
MECHANISMS = tuple(HASHES)
|
|
|
|
|
|
class ScramClient:
    """Client side of a SCRAM exchange (RFC 5802, RFC 7677).

    Call the public methods once each, in this order:
    ``get_client_first``, ``set_server_first``, ``get_client_final``,
    ``set_server_final``. Any other order raises ``ScramException``.
    """

    def __init__(self, mechanisms, username, password, c_nonce=None):
        """Choose a mechanism and remember the credentials.

        :param mechanisms: the mechanism names offered by the server; any
            container (or string) that supports ``in``.
        :param username: the user name (SASLprep-normalized when sent).
        :param password: the password, as a ``str``.
        :param c_nonce: the client nonce; a random one is generated when
            ``None``.
        :raises ScramException: if no entry of ``MECHANISMS`` is found in
            ``mechanisms``.
        """
        # MECHANISMS lists SHA-1 before SHA-256; scanning it backwards
        # picks SHA-256 whenever the server offers both.
        self.mech = None
        for mech in reversed(MECHANISMS):
            if mech in mechanisms:
                self.mech = mech
                break

        if self.mech is None:
            # str() so a list/tuple of offered mechanisms can be reported;
            # concatenating it directly raised TypeError.
            raise ScramException(
                "The only recognized mechanisms are " + str(MECHANISMS) +
                " and none of those can be found in " + str(mechanisms) +
                ".")

        self.hf = HASHES[self.mech]

        if c_nonce is None:
            self.c_nonce = _make_nonce()
        else:
            self.c_nonce = c_nonce

        self.username = username
        self.password = password
        # Last step performed; None until get_client_first is called.
        self.stage = None

    def _set_stage(self, next_stage):
        """Record ``next_stage`` after checking it is the expected step."""
        _check_stage(ClientStage, self.stage, next_stage)
        self.stage = next_stage

    def get_client_first(self):
        """Return the client-first message (step 1)."""
        self._set_stage(ClientStage.get_client_first)
        self.client_first_bare, client_first = _get_client_first(
            self.username, self.c_nonce)
        return client_first

    def set_server_first(self, message):
        """Process the server-first message (step 2).

        :raises ScramException: if the server's nonce does not start with
            this client's nonce.
        """
        self._set_stage(ClientStage.set_server_first)
        self.server_first = message
        self.auth_message, self.nonce, self.salt, self.iterations = \
            _set_server_first(message, self.c_nonce, self.client_first_bare)

    def get_client_final(self):
        """Return the client-final message carrying the proof (step 3).

        The expected server signature is kept for ``set_server_final``.
        """
        self._set_stage(ClientStage.get_client_final)
        self.server_signature, cfinal = _get_client_final(
            self.hf, self.password, self.salt, self.iterations, self.nonce,
            self.auth_message)
        return cfinal

    def set_server_final(self, message):
        """Verify the server-final message (step 4).

        :raises ScramException: if the server signature does not match.
        """
        self._set_stage(ClientStage.set_server_final)
        _set_server_final(message, self.server_signature)
|
|
|
|
|
|
class ScramServer:
    """Server side of a SCRAM exchange.

    The public methods are called once each, in the order
    ``set_client_first``, ``get_server_first``, ``set_client_final``,
    ``get_server_final``; any other order raises ``ScramException``.
    """

    def __init__(
            self, password_fn, s_nonce=None, iterations=4096, salt=None,
            mechanism='SCRAM-SHA-256'):
        """Configure the server for one authentication exchange.

        :param password_fn: called with the client's user name; returns
            that user's password as a ``str``.
        :param s_nonce: the server nonce; generated randomly when ``None``.
        :param iterations: the iteration count advertised to the client.
        :param salt: the base64-encoded salt; 16 random bytes, base64
            encoded, are used when ``None``.
        :param mechanism: the SCRAM mechanism; must be in ``MECHANISMS``.
        :raises ScramException: if ``mechanism`` is not supported.
        """
        if mechanism not in MECHANISMS:
            raise ScramException(
                "The only recognized mechanisms are " + str(MECHANISMS) +
                ".")
        self.mechanism = mechanism
        self.hf = HASHES[mechanism]

        self.s_nonce = _make_nonce() if s_nonce is None else s_nonce
        self.salt = _b64enc(urandom(16)) if salt is None else salt

        self.password_fn = password_fn
        self.iterations = iterations
        # Last step performed; None until set_client_first is called.
        self.stage = None

    def _set_stage(self, next_stage):
        """Record ``next_stage`` after checking it is the expected step."""
        _check_stage(ServerStage, self.stage, next_stage)
        self.stage = next_stage

    def set_client_first(self, client_first):
        """Consume the client-first message and look up the password."""
        self._set_stage(ServerStage.set_client_first)
        parsed = _set_client_first(client_first, self.s_nonce)
        self.nonce, self.user, self.client_first_bare = parsed
        self.password = self.password_fn(self.user)

    def get_server_first(self):
        """Return the server-first message (nonce, salt, iterations)."""
        self._set_stage(ServerStage.get_server_first)
        auth_message, server_first = _get_server_first(
            self.nonce, self.salt, self.iterations, self.client_first_bare)
        self.auth_message = auth_message
        return server_first

    def set_client_final(self, client_final):
        """Verify the client-final message.

        :raises ScramException: if the nonce or the client proof does not
            match.
        """
        self._set_stage(ServerStage.set_client_final)
        signature = _set_client_final(
            self.hf, client_final, self.s_nonce, self.password, self.salt,
            self.iterations, self.auth_message)
        self.server_signature = signature

    def get_server_final(self):
        """Return the server-final message carrying the server signature."""
        self._set_stage(ServerStage.get_server_final)
        signature = self.server_signature
        return _get_server_final(signature)
|
|
|
|
|
|
def _make_nonce():
    """Return a fresh random nonce: 32 lowercase hex digits from a UUID4."""
    # UUID.hex is the canonical string form with the dashes removed.
    return uuid4().hex
|
|
|
|
|
|
def _make_auth_message(nonce, client_first_bare, server_first):
    """Assemble the SCRAM AuthMessage that both sides sign.

    It is client-first-bare, server-first and the client-final message
    without its proof, joined with commas. The channel-binding attribute
    is the base64 of the ``n,,`` GS2 header.
    """
    channel_binding = b64encode(b'n,,').decode('utf8')
    client_final_without_proof = 'c=' + channel_binding + ',r=' + nonce
    return (client_first_bare + ',' + server_first + ',' +
            client_final_without_proof)
|
|
|
|
|
|
def _proof_signature(hf, password, salt, iterations, auth_msg):
    """Derive the ClientProof and ServerSignature for an exchange.

    :param hf: hash constructor (e.g. ``hashlib.sha256``).
    :param password: the password as a ``str``; SASLprep-normalized here.
    :param salt: the salt, base64-encoded.
    :param iterations: the ``Hi()`` iteration count.
    :param auth_msg: the AuthMessage string signed by both sides.
    :returns: ``(client_proof, server_signature)``, both base64 ``str``.
    """
    salted = _hi(hf, _uenc(saslprep(password)), _b64dec(salt), iterations)
    auth_bytes = _uenc(auth_msg)

    # ClientProof = ClientKey XOR HMAC(H(ClientKey), AuthMessage)
    ckey = _hmac(hf, salted, b"Client Key")
    proof = _xor(ckey, _hmac(hf, _h(hf, ckey), auth_bytes))

    # ServerSignature = HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)
    skey = _hmac(hf, salted, b"Server Key")
    sig = _hmac(hf, skey, auth_bytes)

    return _b64enc(proof), _b64enc(sig)
|
|
|
|
|
|
def _hmac(hf, key, msg):
    """Return the raw HMAC of ``msg`` keyed with ``key``, using ``hf``."""
    mac = hmac.new(key, digestmod=hf)
    if msg is not None:
        mac.update(msg)
    return mac.digest()
|
|
|
|
|
|
def _h(hf, msg):
    """Return the raw digest of ``msg`` under hash constructor ``hf``."""
    hasher = hf()
    hasher.update(msg)
    return hasher.digest()
|
|
|
|
|
|
def _hi(hf, password, salt, iterations):
    """Compute SCRAM's ``Hi(password, salt, iterations)`` (RFC 5802 §2.2).

    ``Hi`` is PBKDF2 with HMAC-``hf`` as the PRF and an output length of
    one hash block, so this delegates to ``hashlib.pbkdf2_hmac``. That
    runs the iteration loop in C instead of one Python HMAC and byte-XOR
    per round.

    :param hf: a hashlib constructor (e.g. ``hashlib.sha256``); its
        ``name`` selects the PRF.
    :param password: the (normalized, encoded) password, as bytes.
    :param salt: the raw salt, as bytes.
    :param iterations: the iteration count. Values below 1 are treated as
        1 (the result is just U1), matching the previous loop-based code.
    :returns: the salted password, ``hf().digest_size`` bytes long.
    """
    # pbkdf2_hmac rejects counts < 1, whereas the former hand-rolled loop
    # returned U1 for them; clamp to keep that behaviour.
    return hashlib.pbkdf2_hmac(
        hf().name, password, salt, max(iterations, 1))
|
|
|
|
|
|
def _hi_iter(password, mac, iterations, hf=hashlib.sha256):
    """Continue a ``Hi()`` computation from its first block.

    Given ``mac`` = U1, return U1 XOR U2 XOR ... XOR U(iterations + 1),
    where U(k+1) = HMAC(password, U(k)). Therefore
    ``Hi(password, salt, i) == _hi_iter(password, U1, i - 1, hf)``.

    :param password: the HMAC key, as bytes.
    :param mac: the first block U1, as bytes.
    :param iterations: the number of further HMAC rounds to fold in; zero
        or negative returns ``mac`` unchanged.
    :param hf: hash constructor for the HMAC (optional, defaults to
        SHA-256). The original code never passed one, so every call with
        ``iterations > 0`` failed with TypeError.
    :returns: the XOR of all blocks, as bytes.

    NOTE(review): no caller is visible in this chunk; ``_proof_signature``
    uses ``_hi``.
    """
    # Iterative: the former recursion went one frame deep per round and
    # overflowed the interpreter's recursion limit at counts like 4096.
    result = mac
    for _ in range(iterations):
        mac = _hmac(hf, password, mac)
        result = _xor(result, mac)
    return result
|
|
|
|
|
|
def _parse_message(msg):
    """Split a SCRAM message into a ``{attribute: value}`` dict.

    Fields are comma-separated ``a=value`` pairs. Fields shorter than two
    characters (such as the bare ``n`` and empty fields of a GS2 header)
    are skipped. A repeated attribute keeps its last value.
    """
    return {field[0]: field[2:]
            for field in msg.split(',')
            if len(field) > 1}
|
|
|
|
|
|
def _b64enc(binary):
    """Return ``binary`` base64-encoded, as a ``str``."""
    encoded = b64encode(binary)
    return str(encoded, 'utf8')
|
|
|
|
|
|
def _b64dec(string):
    """Decode base64 text (``str`` or ``bytes``) into raw bytes."""
    raw = b64decode(string)
    return raw
|
|
|
|
|
|
def _uenc(string):
    """Return ``string`` encoded as UTF-8 bytes."""
    return string.encode(encoding='utf-8')
|
|
|
|
|
|
def _xor(bytes1, bytes2):
    """Return the bytewise XOR of two byte strings.

    The result is truncated to the length of the shorter input.
    """
    return bytes(map(lambda x, y: x ^ y, bytes1, bytes2))
|
|
|
|
|
|
def _get_client_first(username, c_nonce):
    """Build the SCRAM client-first message.

    :param username: the user name. It is SASLprep-normalized, then ``=``
        and ``,`` are escaped as ``=3D`` and ``=2C`` as RFC 5802 requires,
        so a comma in the name cannot split the attribute list.
    :param c_nonce: the client nonce.
    :returns: ``(client_first_bare, client_first)``. ``client_first`` is
        the bare message prefixed with the ``n,,`` GS2 header (no channel
        binding, no authorization identity).
    """
    # Escape '=' before ',' so the '=' introduced by '=2C' is not escaped
    # a second time.
    name = saslprep(username).replace('=', '=3D').replace(',', '=2C')
    bare = 'n=' + name + ',r=' + c_nonce
    return bare, 'n,,' + bare
|
|
|
|
|
|
def _set_client_first(client_first, s_nonce):
    """Parse a client-first message on the server side.

    :param client_first: the full client-first message, GS2 header
        included (e.g. ``"n,,n=user,r=abc"``).
    :param s_nonce: the server nonce, appended to the client's nonce.
    :returns: ``(nonce, user, client_first_bare)``. ``nonce`` is the
        combined client+server nonce, ``user`` is the user name with the
        ``=2C`` / ``=3D`` escapes decoded, and ``client_first_bare`` is the
        message without its GS2 header.
    :raises KeyError: if the ``r`` or ``n`` attribute is missing.
    """
    # The GS2 header is two comma-terminated fields (flag, optional
    # authzid), so the bare message starts after the second comma. A fixed
    # [3:] slice only worked for the "n,," / "y,," headers.
    _, _, rest = client_first.partition(',')
    _, _, client_first_bare = rest.partition(',')

    msg = _parse_message(client_first_bare)
    c_nonce = msg['r']
    nonce = c_nonce + s_nonce
    # Undo RFC 5802 name escaping. '=2C' is decoded first so an escaped
    # literal "=2C" (sent as "=3D2C") comes back as "=2C", not ",".
    user = msg['n'].replace('=2C', ',').replace('=3D', '=')

    return nonce, user, client_first_bare
|
|
|
|
|
|
def _get_server_first(nonce, salt, iterations, client_first_bare):
    """Build the server-first message and the resulting AuthMessage.

    :param nonce: the combined client+server nonce.
    :param salt: the base64-encoded salt.
    :param iterations: the iteration count to advertise.
    :param client_first_bare: the client-first message without its header.
    :returns: ``(auth_message, server_first)``.
    """
    server_first = 'r=' + nonce + ',s=' + salt + ',i=' + str(iterations)
    auth_message = _make_auth_message(nonce, client_first_bare, server_first)
    return auth_message, server_first
|
|
|
|
|
|
def _set_server_first(server_first, c_nonce, client_first_bare):
    """Parse a server-first message on the client side.

    :param server_first: the server-first message (``r=...,s=...,i=...``).
    :param c_nonce: this client's nonce.
    :param client_first_bare: the client-first message without its header.
    :returns: ``(auth_message, nonce, salt, iterations)``. ``salt`` is
        still base64-encoded and ``iterations`` is an ``int``.
    :raises KeyError: if ``r``, ``s`` or ``i`` is missing.
    :raises ScramException: if the server's nonce does not start with the
        client nonce.
    """
    fields = _parse_message(server_first)
    nonce, salt = fields['r'], fields['s']
    iterations = int(fields['i'])

    # The server must extend the client's nonce, never replace it.
    if not nonce.startswith(c_nonce):
        raise ScramException("Client nonce doesn't match.")

    auth_message = _make_auth_message(nonce, client_first_bare, server_first)
    return auth_message, nonce, salt, iterations
|
|
|
|
|
|
def _get_client_final(hf, password, salt, iterations, nonce, auth_msg):
    """Build the client-final message and the expected server signature.

    :returns: ``(server_signature, client_final)``. ``client_final`` is
        ``c=<base64 "n,,">,r=<nonce>,p=<client proof>``.
    """
    proof, expected_signature = _proof_signature(
        hf, password, salt, iterations, auth_msg)

    without_proof = ','.join(('c=' + _b64enc(b'n,,'), 'r=' + nonce))
    return expected_signature, without_proof + ',p=' + proof
|
|
|
|
|
|
def _set_client_final(
        hf, client_final, s_nonce, password, salt, iterations, auth_msg):
    """Verify a client-final message on the server side.

    :param hf: hash constructor of the negotiated mechanism.
    :param client_final: the client-final message (``c=...,r=...,p=...``).
    :param s_nonce: the server nonce, which must end the client's nonce.
    :param password: the user's password, as a ``str``.
    :param salt: the base64-encoded salt.
    :param iterations: the iteration count.
    :param auth_msg: the AuthMessage computed by the server.
    :returns: the base64 server signature to send in server-final.
    :raises KeyError: if ``r`` or ``p`` is missing.
    :raises ScramException: if the nonce or the client proof does not
        match.
    """
    msg = _parse_message(client_final)
    nonce = msg['r']
    proof = msg['p']

    if not nonce.endswith(s_nonce):
        raise ScramException("Server nonce doesn't match.")

    client_proof, server_signature = _proof_signature(
        hf, password, salt, iterations, auth_msg)

    # Constant-time comparison: a plain != can return sooner the earlier
    # the first mismatch, leaking how much of a forged proof is correct.
    # Compared as bytes because compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(_uenc(client_proof), _uenc(proof)):
        raise ScramException("The proofs don't match")

    return server_signature
|
|
|
|
|
|
def _get_server_final(server_signature):
    """Return the server-final message carrying ``server_signature``."""
    return ''.join(('v=', server_signature))
|
|
|
|
|
|
def _set_server_final(message, server_signature):
    """Verify the server-final message on the client side.

    :param message: the server-final message, either ``v=<signature>`` or
        ``e=<error>``.
    :param server_signature: the base64 signature the client expects.
    :raises ScramException: if the server reported an error or the
        signature does not match.
    :raises KeyError: if neither ``e`` nor ``v`` is present.
    """
    msg = _parse_message(message)

    # RFC 5802 lets the server answer with e=<error> instead of a
    # verifier; surface that rather than failing with KeyError on 'v'.
    if 'e' in msg:
        raise ScramException("The server returned an error: " + msg['e'])

    # Compared as bytes in constant time; compare_digest rejects
    # non-ASCII str, and the 'v' value comes from the peer.
    if not hmac.compare_digest(_uenc(server_signature), _uenc(msg['v'])):
        raise ScramException("The server signature doesn't match.")
|
|
|
|
|
|
def saslprep(source):
    """Prepare a string with the SASLprep profile of stringprep.

    Applies, in order: the mapping stage (tables B.1 and C.1.2), NFKC
    normalization, the bidirectional-text rule and the prohibited-output
    checks, all using the tables from Python's ``stringprep`` module. In
    this module it is applied to user names and passwords.

    :param source: the string to prepare.
    :returns: the prepared string, or ``''`` if nothing remains after
        mapping and normalization.
    :raises ValueError: if the result contains a prohibited character or
        a malformed bidirectional sequence.
    """
    # mapping stage
    # - map non-ascii spaces to U+0020 (stringprep C.1.2)
    # - strip 'commonly mapped to nothing' chars (stringprep B.1)
    data = ''.join(
        ' ' if in_table_c12(c) else c for c in source if not in_table_b1(c))

    # normalize to KC form
    data = unicodedata.normalize('NFKC', data)
    if not data:
        return ''

    # check for invalid bi-directional strings.
    # stringprep requires the following:
    # - chars in C.8 must be prohibited.
    # - if any R/AL chars in string:
    #   - no L chars allowed in string
    #   - first and last must be R/AL chars
    # this checks if start/end are R/AL chars. if so, prohibited loop
    # will forbid all L chars. if not, prohibited loop will forbid all
    # R/AL chars instead. in both cases, prohibited loop takes care of C.8.
    is_ral_char = in_table_d1
    if is_ral_char(data[0]):
        if not is_ral_char(data[-1]):
            raise ValueError("malformed bidi sequence")
        # forbid L chars within R/AL sequence.
        is_forbidden_bidi_char = in_table_d2
    else:
        # forbid R/AL chars if start not setup correctly; L chars allowed.
        is_forbidden_bidi_char = is_ral_char

    # check for prohibited output
    # stringprep tables A.1, B.1, C.1.2, C.2 - C.9
    for c in data:
        # check for chars mapping stage should have removed
        # NOTE: these asserts are internal sanity checks on the mapping
        # stage above, not input validation; they are skipped under -O.
        assert not in_table_b1(c), "failed to strip B.1 in mapping stage"
        assert not in_table_c12(c), "failed to replace C.1.2 in mapping stage"

        # check for forbidden chars
        # Unassigned code points (A.1) are always rejected here, i.e. the
        # stricter treatment stringprep prescribes for stored strings.
        # The first matching table determines the error message.
        for f, msg in (
                (in_table_a1, "unassigned code points forbidden"),
                (in_table_c21_c22, "control characters forbidden"),
                (in_table_c3, "private use characters forbidden"),
                (in_table_c4, "non-char code points forbidden"),
                (in_table_c5, "surrogate codes forbidden"),
                (in_table_c6, "non-plaintext chars forbidden"),
                (in_table_c7, "non-canonical chars forbidden"),
                (in_table_c8, "display-modifying/deprecated chars forbidden"),
                (in_table_c9, "tagged characters forbidden"),
                (is_forbidden_bidi_char, "forbidden bidi character")):
            if f(c):
                raise ValueError(msg)

    return data
|