# -*- coding: utf-8 -*- # # Copyright 2017 Gehirn Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import hmac from typing import ( Callable, Union, ) from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.rsa import ( rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp, rsa_recover_prime_factors, RSAPrivateKey, RSAPrivateNumbers, RSAPublicKey, RSAPublicNumbers, ) from cryptography.hazmat.primitives.serialization import ( load_pem_private_key, load_pem_public_key, ) from .exceptions import ( MalformedJWKError, UnsupportedKeyTypeError, ) from .utils import ( b64encode, b64decode, uint_b64encode, uint_b64decode, ) class AbstractJWKBase: def get_kty(self): raise NotImplementedError() # pragma: no cover def get_kid(self): raise NotImplementedError() # pragma: no cover def is_sign_key(self) -> bool: raise NotImplementedError() # pragma: no cover def sign(self, message: bytes, **options) -> bytes: raise NotImplementedError() # pragma: no cover def verify(self, message: bytes, signature: bytes, **options) -> bool: raise NotImplementedError() # pragma: no cover def to_dict(self, public_only=True): raise NotImplementedError() # pragma: no cover @classmethod def from_dict(cls, dct): raise NotImplementedError() # pragma: no cover class OctetJWK(AbstractJWKBase): def __init__(self, key: bytes, kid=None, **options) -> None: super(AbstractJWKBase, self).__init__() self.key = key self.kid = kid optnames = {'use', 'key_ops', 'alg', 'x5u', 'x5c', 'x5t', 'x5t#s256'} self.options = {k: v for k, v in options.items() if k in optnames} def get_kty(self): return 'oct' def get_kid(self): return self.kid def is_sign_key(self) -> bool: return True def sign(self, message: bytes, signer: Callable[[bytes, bytes], bytes] = None, **options) -> bytes: return signer(message, self.key) def verify(self, message: bytes, signature: bytes, signer: Callable[[bytes, bytes], bytes] = None, **options) -> bool: return hmac.compare_digest(signature, signer(message, self.key)) def to_dict(self, public_only=True): dct = { 'kty': 'oct', 'k': b64encode(self.key), } dct.update(self.options) if self.kid: dct['kid'] = self.kid return dct @classmethod def from_dict(cls, dct): try: return cls(b64decode(dct['k']), **dct) except KeyError as why: raise MalformedJWKError('k is required') class RSAJWK(AbstractJWKBase): """ https://tools.ietf.org/html/rfc7518.html#section-6.3.1 """ def __init__(self, keyobj: Union[RSAPrivateKey, RSAPublicKey], **options) -> None: super(AbstractJWKBase, self).__init__() self.keyobj = keyobj optnames = {'use', 'key_ops', 'alg', 'kid', 'x5u', 'x5c', 'x5t', 'x5t#s256'} self.options = {k: v for k, v in options.items() if k in optnames} def is_sign_key(self) -> bool: return isinstance(self.keyobj, RSAPrivateKey) def sign(self, message: bytes, hash_fun: Callable = None, **options) -> bytes: return self.keyobj.sign(message, padding.PKCS1v15(), hash_fun()) def verify(self, message: bytes, signature: bytes, hash_fun: Callable = None, **options) -> bool: if self.is_sign_key(): pubkey = self.keyobj.public_key() else: pubkey = self.keyobj try: pubkey.verify(signature, message, padding.PKCS1v15(), hash_fun()) return True except InvalidSignature: return False def get_kty(self): return 'RSA' def get_kid(self): return self.options.get('kid') def to_dict(self, public_only=True): dct = { 'kty': 'RSA', } dct.update(self.options) if isinstance(self.keyobj, RSAPrivateKey): priv_numbers = self.keyobj.private_numbers() pub_numbers = priv_numbers.public_numbers dct.update({ 'e': uint_b64encode(pub_numbers.e), 'n': uint_b64encode(pub_numbers.n), }) if not public_only: dct.update({ 'e': uint_b64encode(pub_numbers.e), 'n': uint_b64encode(pub_numbers.n), 'd': uint_b64encode(priv_numbers.d), 'p': uint_b64encode(priv_numbers.p), 'q': uint_b64encode(priv_numbers.q), 'dp': uint_b64encode(priv_numbers.dmp1), 'dq': uint_b64encode(priv_numbers.dmq1), 'qi': uint_b64encode(priv_numbers.iqmp), }) return dct pub_numbers = self.keyobj.public_numbers() dct.update({ 'e': uint_b64encode(pub_numbers.e), 'n': uint_b64encode(pub_numbers.n), }) return dct @classmethod def from_dict(cls, dct): if 'oth' in dct: raise UnsupportedKeyTypeError( 'RSA keys with multiples primes are not supported') try: e = uint_b64decode(dct['e']) n = uint_b64decode(dct['n']) except KeyError as why: raise MalformedJWKError('e and n are required') pub_numbers = RSAPublicNumbers(e, n) if 'd' not in dct: return cls( pub_numbers.public_key(backend=default_backend()), **dct) d = uint_b64decode(dct['d']) privparams = {'p', 'q', 'dp', 'dq', 'qi'} product = set(dct.keys()) & privparams if len(product) == 0: p, q = rsa_recover_prime_factors(n, e, d) priv_numbers = RSAPrivateNumbers( d=d, p=p, q=q, dmp1=rsa_crt_dmp1(d, p), dmq1=rsa_crt_dmq1(d, q), iqmp=rsa_crt_iqmp(p, q), public_numbers=pub_numbers) elif product == privparams: priv_numbers = RSAPrivateNumbers( d=d, p=uint_b64decode(dct['p']), q=uint_b64decode(dct['q']), dmp1=uint_b64decode(dct['dp']), dmq1=uint_b64decode(dct['dq']), iqmp=uint_b64decode(dct['qi']), public_numbers=pub_numbers) else: # If the producer includes any of the other private key parameters, # then all of the others MUST be present, with the exception of # "oth", which MUST only be present when more than two prime # factors were used. raise MalformedJWKError( 'p, q, dp, dq, qi MUST be present or' 'all of them MUST be absent') return cls(priv_numbers.private_key(backend=default_backend()), **dct) def supported_key_types(): return { 'oct': OctetJWK, 'RSA': RSAJWK, } def jwk_from_dict(dct): if 'kty' not in dct: raise MalformedJWKError('kty MUST be present') supported = supported_key_types() kty = dct['kty'] if kty not in supported: raise UnsupportedKeyTypeError('unsupported key type: {}'.format(kty)) return supported[kty].from_dict(dct) def jwk_from_pem(pem_content: bytes) -> AbstractJWKBase: try: privkey = load_pem_private_key( pem_content, password=None, backend=default_backend()) if isinstance(privkey, RSAPrivateKey): return RSAJWK(privkey) raise UnsupportedKeyTypeError( 'unsupported key type') # pragma: no cover except ValueError: pass try: pubkey = load_pem_public_key(pem_content, backend=default_backend()) if isinstance(pubkey, RSAPublicKey): return RSAJWK(pubkey) raise UnsupportedKeyTypeError( 'unsupported key type') # pragma: no cover except ValueError as why: raise UnsupportedKeyTypeError('could not deserialize PEM') from why