|
- """passlib.handlers.scram - hash for SCRAM credential storage"""
- #=============================================================================
- # imports
- #=============================================================================
- # core
- import logging; log = logging.getLogger(__name__)
- # site
- # pkg
- from passlib.utils import consteq, saslprep, to_native_str, splitcomma
- from passlib.utils.binary import ab64_decode, ab64_encode
- from passlib.utils.compat import bascii_to_str, iteritems, u, native_string_types
- from passlib.crypto.digest import pbkdf2_hmac, norm_hash_name
- import passlib.utils.handlers as uh
- # local
- __all__ = [
- "scram",
- ]
-
- #=============================================================================
- # scram credentials hash
- #=============================================================================
- class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler):
- """This class provides a format for storing SCRAM passwords, and follows
- the :ref:`password-hash-api`.
-
- It supports a variable-length salt, and a variable number of rounds.
-
- The :meth:`~passlib.ifc.PasswordHash.using` method accepts the following optional keywords:
-
- :type salt: bytes
- :param salt:
- Optional salt bytes.
- If specified, the length must be between 0-1024 bytes.
- If not specified, a 12 byte salt will be autogenerated
- (this is recommended).
-
- :type salt_size: int
- :param salt_size:
- Optional number of bytes to use when autogenerating new salts.
- Defaults to 12 bytes, but can be any value between 0 and 1024.
-
- :type rounds: int
- :param rounds:
- Optional number of rounds to use.
- Defaults to 100000, but must be within ``range(1,1<<32)``.
-
- :type algs: list of strings
- :param algs:
- Specify list of digest algorithms to use.
-
- By default each scram hash will contain digests for SHA-1,
- SHA-256, and SHA-512. This can be overridden by specify either be a
- list such as ``["sha-1", "sha-256"]``, or a comma-separated string
- such as ``"sha-1, sha-256"``. Names are case insensitive, and may
- use :mod:`!hashlib` or `IANA <http://www.iana.org/assignments/hash-function-text-names>`_
- hash names.
-
- :type relaxed: bool
- :param relaxed:
- By default, providing an invalid value for one of the other
- keywords will result in a :exc:`ValueError`. If ``relaxed=True``,
- and the error can be corrected, a :exc:`~passlib.exc.PasslibHashWarning`
- will be issued instead. Correctable errors include ``rounds``
- that are too small or too large, and ``salt`` strings that are too long.
-
- .. versionadded:: 1.6
-
- In addition to the standard :ref:`password-hash-api` methods,
- this class also provides the following methods for manipulating Passlib
- scram hashes in ways useful for pluging into a SCRAM protocol stack:
-
- .. automethod:: extract_digest_info
- .. automethod:: extract_digest_algs
- .. automethod:: derive_digest
- """
- #===================================================================
- # class attrs
- #===================================================================
-
- # NOTE: unlike most GenericHandler classes, the 'checksum' attr of
- # ScramHandler is actually a map from digest_name -> digest, so
- # many of the standard methods have been overridden.
-
- # NOTE: max_salt_size and max_rounds are arbitrarily chosen to provide
- # a sanity check; the underlying pbkdf2 specifies no bounds for either.
-
- #--GenericHandler--
- name = "scram"
- setting_kwds = ("salt", "salt_size", "rounds", "algs")
- ident = u("$scram$")
-
- #--HasSalt--
- default_salt_size = 12
- max_salt_size = 1024
-
- #--HasRounds--
- default_rounds = 100000
- min_rounds = 1
- max_rounds = 2**32-1
- rounds_cost = "linear"
-
- #--custom--
-
- # default algorithms when creating new hashes.
- default_algs = ["sha-1", "sha-256", "sha-512"]
-
- # list of algs verify prefers to use, in order.
- _verify_algs = ["sha-256", "sha-512", "sha-224", "sha-384", "sha-1"]
-
- #===================================================================
- # instance attrs
- #===================================================================
-
- # 'checksum' is different from most GenericHandler subclasses,
- # in that it contains a dict mapping from alg -> digest,
- # or None if no checksum present.
-
- # list of algorithms to create/compare digests for.
- algs = None
-
- #===================================================================
- # scram frontend helpers
- #===================================================================
- @classmethod
- def extract_digest_info(cls, hash, alg):
- """return (salt, rounds, digest) for specific hash algorithm.
-
- :type hash: str
- :arg hash:
- :class:`!scram` hash stored for desired user
-
- :type alg: str
- :arg alg:
- Name of digest algorithm (e.g. ``"sha-1"``) requested by client.
-
- This value is run through :func:`~passlib.crypto.digest.norm_hash_name`,
- so it is case-insensitive, and can be the raw SCRAM
- mechanism name (e.g. ``"SCRAM-SHA-1"``), the IANA name,
- or the hashlib name.
-
- :raises KeyError:
- If the hash does not contain an entry for the requested digest
- algorithm.
-
- :returns:
- A tuple containing ``(salt, rounds, digest)``,
- where *digest* matches the raw bytes returned by
- SCRAM's :func:`Hi` function for the stored password,
- the provided *salt*, and the iteration count (*rounds*).
- *salt* and *digest* are both raw (unencoded) bytes.
- """
- # XXX: this could be sped up by writing custom parsing routine
- # that just picks out relevant digest, and doesn't bother
- # with full structure validation each time it's called.
- alg = norm_hash_name(alg, 'iana')
- self = cls.from_string(hash)
- chkmap = self.checksum
- if not chkmap:
- raise ValueError("scram hash contains no digests")
- return self.salt, self.rounds, chkmap[alg]
-
- @classmethod
- def extract_digest_algs(cls, hash, format="iana"):
- """Return names of all algorithms stored in a given hash.
-
- :type hash: str
- :arg hash:
- The :class:`!scram` hash to parse
-
- :type format: str
- :param format:
- This changes the naming convention used by the
- returned algorithm names. By default the names
- are IANA-compatible; possible values are ``"iana"`` or ``"hashlib"``.
-
- :returns:
- Returns a list of digest algorithms; e.g. ``["sha-1"]``
- """
- # XXX: this could be sped up by writing custom parsing routine
- # that just picks out relevant names, and doesn't bother
- # with full structure validation each time it's called.
- algs = cls.from_string(hash).algs
- if format == "iana":
- return algs
- else:
- return [norm_hash_name(alg, format) for alg in algs]
-
- @classmethod
- def derive_digest(cls, password, salt, rounds, alg):
- """helper to create SaltedPassword digest for SCRAM.
-
- This performs the step in the SCRAM protocol described as::
-
- SaltedPassword := Hi(Normalize(password), salt, i)
-
- :type password: unicode or utf-8 bytes
- :arg password: password to run through digest
-
- :type salt: bytes
- :arg salt: raw salt data
-
- :type rounds: int
- :arg rounds: number of iterations.
-
- :type alg: str
- :arg alg: name of digest to use (e.g. ``"sha-1"``).
-
- :returns:
- raw bytes of ``SaltedPassword``
- """
- if isinstance(password, bytes):
- password = password.decode("utf-8")
- # NOTE: pbkdf2_hmac() will encode secret & salt using utf-8,
- # and handle normalizing alg name.
- return pbkdf2_hmac(alg, saslprep(password), salt, rounds)
-
- #===================================================================
- # serialization
- #===================================================================
-
- @classmethod
- def from_string(cls, hash):
- hash = to_native_str(hash, "ascii", "hash")
- if not hash.startswith("$scram$"):
- raise uh.exc.InvalidHashError(cls)
- parts = hash[7:].split("$")
- if len(parts) != 3:
- raise uh.exc.MalformedHashError(cls)
- rounds_str, salt_str, chk_str = parts
-
- # decode rounds
- rounds = int(rounds_str)
- if rounds_str != str(rounds): # forbid zero padding, etc.
- raise uh.exc.MalformedHashError(cls)
-
- # decode salt
- try:
- salt = ab64_decode(salt_str.encode("ascii"))
- except TypeError:
- raise uh.exc.MalformedHashError(cls)
-
- # decode algs/digest list
- if not chk_str:
- # scram hashes MUST have something here.
- raise uh.exc.MalformedHashError(cls)
- elif "=" in chk_str:
- # comma-separated list of 'alg=digest' pairs
- algs = None
- chkmap = {}
- for pair in chk_str.split(","):
- alg, digest = pair.split("=")
- try:
- chkmap[alg] = ab64_decode(digest.encode("ascii"))
- except TypeError:
- raise uh.exc.MalformedHashError(cls)
- else:
- # comma-separated list of alg names, no digests
- algs = chk_str
- chkmap = None
-
- # return new object
- return cls(
- rounds=rounds,
- salt=salt,
- checksum=chkmap,
- algs=algs,
- )
-
- def to_string(self):
- salt = bascii_to_str(ab64_encode(self.salt))
- chkmap = self.checksum
- chk_str = ",".join(
- "%s=%s" % (alg, bascii_to_str(ab64_encode(chkmap[alg])))
- for alg in self.algs
- )
- return '$scram$%d$%s$%s' % (self.rounds, salt, chk_str)
-
- #===================================================================
- # variant constructor
- #===================================================================
- @classmethod
- def using(cls, default_algs=None, algs=None, **kwds):
- # parse aliases
- if algs is not None:
- assert default_algs is None
- default_algs = algs
-
- # create subclass
- subcls = super(scram, cls).using(**kwds)
-
- # fill in algs
- if default_algs is not None:
- subcls.default_algs = cls._norm_algs(default_algs)
- return subcls
-
- #===================================================================
- # init
- #===================================================================
- def __init__(self, algs=None, **kwds):
- super(scram, self).__init__(**kwds)
-
- # init algs
- digest_map = self.checksum
- if algs is not None:
- if digest_map is not None:
- raise RuntimeError("checksum & algs kwds are mutually exclusive")
- algs = self._norm_algs(algs)
- elif digest_map is not None:
- # derive algs list from digest map (if present).
- algs = self._norm_algs(digest_map.keys())
- elif self.use_defaults:
- algs = list(self.default_algs)
- assert self._norm_algs(algs) == algs, "invalid default algs: %r" % (algs,)
- else:
- raise TypeError("no algs list specified")
- self.algs = algs
-
- def _norm_checksum(self, checksum, relaxed=False):
- if not isinstance(checksum, dict):
- raise uh.exc.ExpectedTypeError(checksum, "dict", "checksum")
- for alg, digest in iteritems(checksum):
- if alg != norm_hash_name(alg, 'iana'):
- raise ValueError("malformed algorithm name in scram hash: %r" %
- (alg,))
- if len(alg) > 9:
- raise ValueError("SCRAM limits algorithm names to "
- "9 characters: %r" % (alg,))
- if not isinstance(digest, bytes):
- raise uh.exc.ExpectedTypeError(digest, "raw bytes", "digests")
- # TODO: verify digest size (if digest is known)
- if 'sha-1' not in checksum:
- # NOTE: required because of SCRAM spec.
- raise ValueError("sha-1 must be in algorithm list of scram hash")
- return checksum
-
- @classmethod
- def _norm_algs(cls, algs):
- """normalize algs parameter"""
- if isinstance(algs, native_string_types):
- algs = splitcomma(algs)
- algs = sorted(norm_hash_name(alg, 'iana') for alg in algs)
- if any(len(alg)>9 for alg in algs):
- raise ValueError("SCRAM limits alg names to max of 9 characters")
- if 'sha-1' not in algs:
- # NOTE: required because of SCRAM spec (rfc 5802)
- raise ValueError("sha-1 must be in algorithm list of scram hash")
- return algs
-
- #===================================================================
- # migration
- #===================================================================
- def _calc_needs_update(self, **kwds):
- # marks hashes as deprecated if they don't include at least all default_algs.
- # XXX: should we deprecate if they aren't exactly the same,
- # to permit removing legacy hashes?
- if not set(self.algs).issuperset(self.default_algs):
- return True
-
- # hand off to base implementation
- return super(scram, self)._calc_needs_update(**kwds)
-
- #===================================================================
- # digest methods
- #===================================================================
- def _calc_checksum(self, secret, alg=None):
- rounds = self.rounds
- salt = self.salt
- hash = self.derive_digest
- if alg:
- # if requested, generate digest for specific alg
- return hash(secret, salt, rounds, alg)
- else:
- # by default, return dict containing digests for all algs
- return dict(
- (alg, hash(secret, salt, rounds, alg))
- for alg in self.algs
- )
-
- @classmethod
- def verify(cls, secret, hash, full=False):
- uh.validate_secret(secret)
- self = cls.from_string(hash)
- chkmap = self.checksum
- if not chkmap:
- raise ValueError("expected %s hash, got %s config string instead" %
- (cls.name, cls.name))
-
- # NOTE: to make the verify method efficient, we just calculate hash
- # of shortest digest by default. apps can pass in "full=True" to
- # check entire hash for consistency.
- if full:
- correct = failed = False
- for alg, digest in iteritems(chkmap):
- other = self._calc_checksum(secret, alg)
- # NOTE: could do this length check in norm_algs(),
- # but don't need to be that strict, and want to be able
- # to parse hashes containing algs not supported by platform.
- # it's fine if we fail here though.
- if len(digest) != len(other):
- raise ValueError("mis-sized %s digest in scram hash: %r != %r"
- % (alg, len(digest), len(other)))
- if consteq(other, digest):
- correct = True
- else:
- failed = True
- if correct and failed:
- raise ValueError("scram hash verified inconsistently, "
- "may be corrupted")
- else:
- return correct
- else:
- # XXX: should this just always use sha1 hash? would be faster.
- # otherwise only verify against one hash, pick one w/ best security.
- for alg in self._verify_algs:
- if alg in chkmap:
- other = self._calc_checksum(secret, alg)
- return consteq(other, chkmap[alg])
- # there should always be sha-1 at the very least,
- # or something went wrong inside _norm_algs()
- raise AssertionError("sha-1 digest not found!")
-
- #===================================================================
- #
- #===================================================================
-
- #=============================================================================
- # code used for testing scram against protocol examples during development.
- #=============================================================================
- ##def _test_reference_scram():
- ## "quick hack testing scram reference vectors"
- ## # NOTE: "n,," is GS2 header - see https://tools.ietf.org/html/rfc5801
- ## from passlib.utils.compat import print_
- ##
- ## engine = _scram_engine(
- ## alg="sha-1",
- ## salt='QSXCR+Q6sek8bf92'.decode("base64"),
- ## rounds=4096,
- ## password=u("pencil"),
- ## )
- ## print_(engine.digest.encode("base64").rstrip())
- ##
- ## msg = engine.format_auth_msg(
- ## username="user",
- ## client_nonce = "fyko+d2lbbFgONRv9qkxdawL",
- ## server_nonce = "3rfcNHYJY1ZVvWVs7j",
- ## header='c=biws',
- ## )
- ##
- ## cp = engine.get_encoded_client_proof(msg)
- ## assert cp == "v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=", cp
- ##
- ## ss = engine.get_encoded_server_sig(msg)
- ## assert ss == "rmF9pqV8S7suAoZWja4dJRkFsKQ=", ss
- ##
- ##class _scram_engine(object):
- ## """helper class for verifying scram hash behavior
- ## against SCRAM protocol examples. not officially part of Passlib.
- ##
- ## takes in alg, salt, rounds, and a digest or password.
- ##
- ## can calculate the various keys & messages of the scram protocol.
- ##
- ## """
- ## #=========================================================
- ## # init
- ## #=========================================================
- ##
- ## @classmethod
- ## def from_string(cls, hash, alg):
- ## "create record from scram hash, for given alg"
- ## return cls(alg, *scram.extract_digest_info(hash, alg))
- ##
- ## def __init__(self, alg, salt, rounds, digest=None, password=None):
- ## self.alg = norm_hash_name(alg)
- ## self.salt = salt
- ## self.rounds = rounds
- ## self.password = password
- ## if password:
- ## data = scram.derive_digest(password, salt, rounds, alg)
- ## if digest and data != digest:
- ## raise ValueError("password doesn't match digest")
- ## else:
- ## digest = data
- ## elif not digest:
- ## raise TypeError("must provide password or digest")
- ## self.digest = digest
- ##
- ## #=========================================================
- ## # frontend methods
- ## #=========================================================
- ## def get_hash(self, data):
- ## "return hash of raw data"
- ## return hashlib.new(iana_to_hashlib(self.alg), data).digest()
- ##
- ## def get_client_proof(self, msg):
- ## "return client proof of specified auth msg text"
- ## return xor_bytes(self.client_key, self.get_client_sig(msg))
- ##
- ## def get_encoded_client_proof(self, msg):
- ## return self.get_client_proof(msg).encode("base64").rstrip()
- ##
- ## def get_client_sig(self, msg):
- ## "return client signature of specified auth msg text"
- ## return self.get_hmac(self.stored_key, msg)
- ##
- ## def get_server_sig(self, msg):
- ## "return server signature of specified auth msg text"
- ## return self.get_hmac(self.server_key, msg)
- ##
- ## def get_encoded_server_sig(self, msg):
- ## return self.get_server_sig(msg).encode("base64").rstrip()
- ##
- ## def format_server_response(self, client_nonce, server_nonce):
- ## return 'r={client_nonce}{server_nonce},s={salt},i={rounds}'.format(
- ## client_nonce=client_nonce,
- ## server_nonce=server_nonce,
- ## rounds=self.rounds,
- ## salt=self.encoded_salt,
- ## )
- ##
- ## def format_auth_msg(self, username, client_nonce, server_nonce,
- ## header='c=biws'):
- ## return (
- ## 'n={username},r={client_nonce}'
- ## ','
- ## 'r={client_nonce}{server_nonce},s={salt},i={rounds}'
- ## ','
- ## '{header},r={client_nonce}{server_nonce}'
- ## ).format(
- ## username=username,
- ## client_nonce=client_nonce,
- ## server_nonce=server_nonce,
- ## salt=self.encoded_salt,
- ## rounds=self.rounds,
- ## header=header,
- ## )
- ##
- ## #=========================================================
- ## # helpers to calculate & cache constant data
- ## #=========================================================
- ## def _calc_get_hmac(self):
- ## return get_prf("hmac-" + iana_to_hashlib(self.alg))[0]
- ##
- ## def _calc_client_key(self):
- ## return self.get_hmac(self.digest, b("Client Key"))
- ##
- ## def _calc_stored_key(self):
- ## return self.get_hash(self.client_key)
- ##
- ## def _calc_server_key(self):
- ## return self.get_hmac(self.digest, b("Server Key"))
- ##
- ## def _calc_encoded_salt(self):
- ## return self.salt.encode("base64").rstrip()
- ##
- ## #=========================================================
- ## # hacks for calculated attributes
- ## #=========================================================
- ##
- ## def __getattr__(self, attr):
- ## if not attr.startswith("_"):
- ## f = getattr(self, "_calc_" + attr, None)
- ## if f:
- ## value = f()
- ## setattr(self, attr, value)
- ## return value
- ## raise AttributeError("attribute not found")
- ##
- ## def __dir__(self):
- ## cdir = dir(self.__class__)
- ## attrs = set(cdir)
- ## attrs.update(self.__dict__)
- ## attrs.update(attr[6:] for attr in cdir
- ## if attr.startswith("_calc_"))
- ## return sorted(attrs)
- ## #=========================================================
- ## # eoc
- ## #=========================================================
-
- #=============================================================================
- # eof
- #=============================================================================
|