# Copyright 2017 Gehirn Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hmac
from abc import (
    ABC,
    abstractmethod,
)
from collections.abc import Mapping
from functools import wraps
from typing import (
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)
from warnings import warn

import cryptography.hazmat.primitives.serialization as serialization_module
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import (
    RSAPrivateKey,
    RSAPrivateNumbers,
    RSAPublicKey,
    RSAPublicNumbers,
    rsa_crt_dmp1,
    rsa_crt_dmq1,
    rsa_crt_iqmp,
    rsa_recover_prime_factors,
)
from cryptography.hazmat.primitives.hashes import HashAlgorithm

from .exceptions import (
    MalformedJWKError,
    UnsupportedKeyTypeError,
)
from .utils import (
    b64decode,
    b64encode,
    uint_b64decode,
    uint_b64encode,
)

_AJWK = TypeVar("_AJWK", bound="AbstractJWKBase")
_T = TypeVar("_T")


class AbstractJWKBase(ABC):
    """Interface implemented by every JSON Web Key class in this module."""

    @abstractmethod
    def get_kty(self) -> str:
        """Return the JWK key type (the ``kty`` member)."""
        pass  # pragma: no cover

    @abstractmethod
    def get_kid(self) -> str:
        """Return the key ID (the ``kid`` member)."""
        pass  # pragma: no cover

    @abstractmethod
    def is_sign_key(self) -> bool:
        """Return True if this key can be used to produce signatures."""
        pass  # pragma: no cover

    @abstractmethod
    def sign(self, message: bytes, **options) -> bytes:
        """Sign ``message``; algorithm parameters arrive via ``options``."""
        pass  # pragma: no cover

    @abstractmethod
    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return whether ``signature`` is valid for ``message``."""
        pass  # pragma: no cover

    @abstractmethod
    def to_dict(self, public_only: bool = True) -> dict[str, str]:
        """Serialise the key as a JWK dictionary.

        ``public_only`` asks implementations to omit private members.
        """
        pass  # pragma: no cover

    @classmethod
    @abstractmethod
    def from_dict(cls: type[_AJWK], dct: dict[str, object]) -> _AJWK:
        """Build a key instance from its JWK dictionary representation."""
        pass  # pragma: no cover


class OctetJWK(AbstractJWKBase):
    """Symmetric (``kty`` = ``"oct"``) JWK wrapping raw key bytes.

    Signing and verification are delegated to a ``signer(message, key)``
    callable that the caller supplies through the ``signer`` option.
    """

    def __init__(self, key: bytes, kid=None, **options) -> None:
        """Store ``key`` and ``kid``.

        Only the recognised optional JWK members are kept from ``options``;
        everything else is dropped.
        """
        super().__init__()
        self.key = key
        self.kid = kid

        optnames = {
            "use",
            "key_ops",
            "alg",
            "x5u",
            "x5c",
            "x5t",
            # RFC 7517 spells the SHA-256 thumbprint member "x5t#S256"; the
            # lower-case spelling is kept for backward compatibility.
            "x5t#S256",
            "x5t#s256",
        }
        self.options = {k: v for k, v in options.items() if k in optnames}

    def get_kty(self):
        return "oct"

    def get_kid(self):
        return self.kid

    def is_sign_key(self) -> bool:
        # A symmetric key both signs and verifies.
        return True

    def _get_signer(self, options) -> Callable[[bytes, bytes], bytes]:
        # Raises KeyError when the caller did not pass ``signer=``.
        return options["signer"]

    def sign(self, message: bytes, **options) -> bytes:
        """Return ``signer(message, key)`` using the ``signer`` option."""
        signer = self._get_signer(options)
        return signer(message, self.key)

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Recompute the signature and compare it with ``signature``."""
        signer = self._get_signer(options)
        # Constant-time comparison so timing does not leak how many leading
        # bytes of the signature matched.
        return hmac.compare_digest(signature, signer(message, self.key))

    def to_dict(self, public_only=True):
        """Return the JWK dictionary.

        ``public_only`` is ignored: a symmetric key has no public part, so
        ``k`` is always included.
        """
        dct = {
            "kty": "oct",
            "k": b64encode(self.key),
        }
        dct.update(self.options)
        if self.kid:
            dct["kid"] = self.kid
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an OctetJWK from a JWK dictionary.

        ``k`` is required and is decoded with ``b64decode``.  ``kid`` and the
        recognised optional members are kept; other members are ignored.

        Raises:
            MalformedJWKError: if ``k`` is missing.
        """
        try:
            key = b64decode(dct["k"])
        except KeyError as why:
            raise MalformedJWKError("k is required") from why
        # ``key`` is __init__'s positional parameter.  A JWK member with that
        # name would collide with it and raise TypeError, but RFC 7517
        # requires unknown members to be ignored, so drop it.
        params = {name: value for name, value in dct.items() if name != "key"}
        return cls(key, **params)


class RSAJWK(AbstractJWKBase):
    """
    https://tools.ietf.org/html/rfc7518.html#section-6.3.1
    """

    def __init__(
        self, keyobj: Union[RSAPrivateKey, RSAPublicKey], **options
    ) -> None:
        """Wrap a ``cryptography`` RSA private or public key object.

        Only the recognised optional JWK members (including ``kid``) are kept
        from ``options``; everything else is dropped.
        """
        super().__init__()
        self.keyobj = keyobj

        optnames = {
            "use",
            "key_ops",
            "alg",
            "kid",
            "x5u",
            "x5c",
            "x5t",
            # RFC 7517 spells the SHA-256 thumbprint member "x5t#S256"; the
            # lower-case spelling is kept for backward compatibility.
            "x5t#S256",
            "x5t#s256",
        }
        self.options = {k: v for k, v in options.items() if k in optnames}

    def is_sign_key(self) -> bool:
        return isinstance(self.keyobj, RSAPrivateKey)

    def _get_hash_fun(self, options) -> Callable[[], HashAlgorithm]:
        # Raises KeyError when the caller did not pass ``hash_fun=``.
        return options["hash_fun"]

    def _get_padding(self, options) -> padding.AsymmetricPadding:
        try:
            return options["padding"]
        except KeyError:
            # Backward-compatible default for callers that bypass the JWA
            # layer and do not pass ``padding=``.
            warn(
                "you should not use RSAJWK.verify/sign without jwa "
                "intermediary, used legacy padding"
            )
            return padding.PKCS1v15()

    def sign(self, message: bytes, **options) -> bytes:
        """Sign ``message`` with ``hash_fun()`` and the selected padding.

        Raises:
            ValueError: if this JWK only holds a public key.
        """
        if isinstance(self.keyobj, RSAPublicKey):
            raise ValueError("Requires a private key.")
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        return self.keyobj.sign(message, _padding, hash_fun())

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return True iff ``signature`` verifies for ``message``.

        A private key is verified through its derived public key.
        """
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        if isinstance(self.keyobj, RSAPrivateKey):
            pubkey = self.keyobj.public_key()
        else:
            pubkey = self.keyobj
        try:
            pubkey.verify(signature, message, _padding, hash_fun())
        except InvalidSignature:
            return False
        return True

    def get_kty(self):
        return "RSA"

    def get_kid(self):
        return self.options.get("kid")

    def to_dict(self, public_only=True):
        """Serialise the key as a JWK dictionary.

        The public members ``e`` and ``n`` are always emitted.  For a private
        key with ``public_only`` false, ``d``, ``p``, ``q``, ``dp``, ``dq``
        and ``qi`` are emitted as well.
        """
        dct = {
            "kty": "RSA",
        }
        dct.update(self.options)

        if isinstance(self.keyobj, RSAPrivateKey):
            priv_numbers = self.keyobj.private_numbers()
            pub_numbers = priv_numbers.public_numbers
        else:
            priv_numbers = None
            pub_numbers = self.keyobj.public_numbers()

        dct.update(
            {
                "e": uint_b64encode(pub_numbers.e),
                "n": uint_b64encode(pub_numbers.n),
            }
        )
        if priv_numbers is not None and not public_only:
            dct.update(
                {
                    "d": uint_b64encode(priv_numbers.d),
                    "p": uint_b64encode(priv_numbers.p),
                    "q": uint_b64encode(priv_numbers.q),
                    "dp": uint_b64encode(priv_numbers.dmp1),
                    "dq": uint_b64encode(priv_numbers.dmq1),
                    "qi": uint_b64encode(priv_numbers.iqmp),
                }
            )
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an RSAJWK from a JWK dictionary.

        ``e`` and ``n`` are required.  If ``d`` is present a private key is
        built.  In that case either all of ``p``, ``q``, ``dp``, ``dq`` and
        ``qi`` are given, or none are and the primes are recovered from
        ``n``, ``e`` and ``d``.

        Raises:
            UnsupportedKeyTypeError: if ``oth`` (more than two primes) is set.
            MalformedJWKError: if ``e``/``n`` are missing or only some of the
                CRT parameters are present.
        """
        if "oth" in dct:
            raise UnsupportedKeyTypeError(
                "RSA keys with multiple primes are not supported"
            )

        try:
            e = uint_b64decode(dct["e"])
            n = uint_b64decode(dct["n"])
        except KeyError as why:
            raise MalformedJWKError("e and n are required") from why

        # ``keyobj`` is __init__'s positional parameter.  A JWK member with
        # that name would collide with it and raise TypeError instead of
        # being ignored as RFC 7517 requires, so drop it.
        params = {
            name: value for name, value in dct.items() if name != "keyobj"
        }

        pub_numbers = RSAPublicNumbers(e, n)
        if "d" not in dct:
            return cls(
                pub_numbers.public_key(backend=default_backend()), **params
            )
        d = uint_b64decode(dct["d"])

        privparams = {"p", "q", "dp", "dq", "qi"}
        product = set(dct.keys()) & privparams
        if not product:
            # Only d was supplied: recover the primes and CRT values.
            p, q = rsa_recover_prime_factors(n, e, d)
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=pub_numbers,
            )
        elif product == privparams:
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=uint_b64decode(dct["p"]),
                q=uint_b64decode(dct["q"]),
                dmp1=uint_b64decode(dct["dp"]),
                dmq1=uint_b64decode(dct["dq"]),
                iqmp=uint_b64decode(dct["qi"]),
                public_numbers=pub_numbers,
            )
        else:
            # If the producer includes any of the other private key parameters,
            # then all of the others MUST be present, with the exception of
            # "oth", which MUST only be present when more than two prime
            # factors were used.
            raise MalformedJWKError(
                "p, q, dp, dq, qi MUST be present or "
                "all of them MUST be absent"
            )
        return cls(
            priv_numbers.private_key(backend=default_backend()), **params
        )


def supported_key_types() -> dict[str, type[AbstractJWKBase]]:
    """Return the mapping from JWK ``kty`` values to implementing classes."""
    return {
        "oct": OctetJWK,
        "RSA": RSAJWK,
    }


def jwk_from_dict(dct: Mapping[str, Any]) -> AbstractJWKBase:
    """Build a JWK object from its dictionary representation.

    The ``kty`` member selects the class from :func:`supported_key_types`,
    whose ``from_dict`` parses the remaining members.

    Raises:
        TypeError: if ``dct`` is not a mapping.
        MalformedJWKError: if ``kty`` is missing (the per-type parser may
            raise it as well).
        UnsupportedKeyTypeError: if ``kty`` is not a supported key type.
    """
    # Any Mapping works: the parsers only use ``in``, indexing, ``keys()``,
    # ``items()`` and ``**`` expansion.
    if not isinstance(dct, Mapping):  # pragma: no cover
        raise TypeError("dct must be a mapping")
    if "kty" not in dct:
        raise MalformedJWKError("kty MUST be present")

    supported = supported_key_types()
    kty = dct["kty"]
    if kty not in supported:
        raise UnsupportedKeyTypeError(f"unsupported key type: {kty}")
    return supported[kty].from_dict(dct)


# A loader is either the name of a function in
# ``cryptography.hazmat.primitives.serialization`` or a callable.
PublicKeyLoaderT = Union[str, Callable[[bytes, object], object]]
PrivateKeyLoaderT = Union[
    str, Callable[[bytes, Optional[str], object], object]
]

_Loader = TypeVar("_Loader", PublicKeyLoaderT, PrivateKeyLoaderT)
_C = TypeVar("_C", bound=Callable[..., Any])

# The above LoaderTs should actually not be Union, and this function should be
# typed something like this. But, this will lose any kwargs from the typing
# information.
# Probably needs: https://github.com/python/mypy/issues/3157
# (func: Callable[[bytes, _Loader], _T])
# -> Callable[[bytes, Union[str, _Loader]], _T]
def jwk_from_bytes_argument_conversion(func: _C) -> _C:
    """Let a ``jwk_from_*_bytes`` function take its loader by name.

    The wrapped function must take ``(content, loader, **kwargs)``.

    - If the loader is a ``str``, it is replaced by the attribute of that
      name in ``cryptography.hazmat.primitives.serialization``, e.g.
      ``"load_pem_private_key"``.
    - An ``options`` argument that is missing or ``None`` becomes ``{}``.
    - The loader may be passed positionally, as ``loader=``, or under the
      wrapped function's own parameter name, e.g. ``private_loader=``.

    Raises:
        Exception: at decoration time, if the wrapped function's name
            contains neither "private" nor "public".
    """
    import inspect

    if not ("private" in func.__name__ or "public" in func.__name__):
        raise Exception(
            "the wrapped function must have either public"
            " or private in its name"
        )

    # ``@wraps`` makes the wrapper advertise ``func``'s signature, so callers
    # can reasonably pass the loader by its declared name (``func``'s second
    # parameter, e.g. ``private_loader``).  Remember that name so the
    # wrapper accepts it as a keyword, not only positionally.
    loader_param = list(inspect.signature(func).parameters)[1]
    missing = object()

    @wraps(func)
    def wrapper(content, loader=missing, **kwargs):
        if loader is missing:
            loader = kwargs.pop(loader_param, missing)

        # now convert it to a Callable if it's a string
        if isinstance(loader, str):
            loader = getattr(serialization_module, loader)

        if kwargs.get("options") is None:
            kwargs["options"] = {}

        if loader is missing:
            # No loader given at all: let ``func`` raise its own
            # "missing required argument" TypeError.
            return func(content, **kwargs)
        return func(content, loader, **kwargs)

    return wrapper  # type: ignore[return-value]


@jwk_from_bytes_argument_conversion
def jwk_from_private_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    *,
    password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """This function is meant to be called from jwk_from_bytes

    Calls ``private_loader(content, password, backend)`` and wraps an RSA
    private key in an :class:`RSAJWK` built with ``options``.

    NOTE(review): cryptography's ``load_*_private_key`` functions expect the
    password as bytes; the ``Optional[str]`` annotation looks wrong.  Confirm
    before relying on it.

    Raises:
        UnsupportedKeyTypeError: if the loader raises ValueError, or if it
            returns something other than an RSA private key.
    """
    if options is None:
        options = {}
    # Only the loader call is guarded, so the "unsupported key type" error
    # below is never re-wrapped as "probably a public key".
    try:
        privkey = private_loader(content, password, backend)  # type: ignore[operator]  # noqa: E501
    except ValueError as ex:
        raise UnsupportedKeyTypeError("this is probably a public key") from ex
    if isinstance(privkey, RSAPrivateKey):
        return RSAJWK(privkey, **options)
    raise UnsupportedKeyTypeError("unsupported key type")


@jwk_from_bytes_argument_conversion
def jwk_from_public_bytes(
    content: bytes,
    public_loader: PublicKeyLoaderT,
    *,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """This function is meant to be called from jwk_from_bytes

    Calls ``public_loader(content, backend)`` and wraps an RSA public key in
    an :class:`RSAJWK` built with ``options``.

    Raises:
        UnsupportedKeyTypeError: if the loader raises ValueError, or if it
            returns something other than an RSA public key.
    """
    if options is None:
        options = {}
    try:
        pubkey = public_loader(content, backend)  # type: ignore[operator]
    except ValueError as why:
        raise UnsupportedKeyTypeError("could not deserialize") from why
    if isinstance(pubkey, RSAPublicKey):
        return RSAJWK(pubkey, **options)
    raise UnsupportedKeyTypeError("unsupported key type")  # pragma: no cover


def jwk_from_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    public_loader: PublicKeyLoaderT,
    *,
    private_password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from serialized key bytes.

    The content is first tried as a private key.  If that raises
    UnsupportedKeyTypeError, it is loaded as a public key instead.
    """
    try:
        return jwk_from_private_bytes(
            content,
            private_loader,
            password=private_password,
            backend=backend,
            options=options,
        )
    except UnsupportedKeyTypeError:
        return jwk_from_public_bytes(
            content,
            public_loader,
            backend=backend,
            options=options,
        )


def jwk_from_pem(
    pem_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from PEM bytes (private key first, then public key)."""
    return jwk_from_bytes(
        pem_content,
        private_loader="load_pem_private_key",
        public_loader="load_pem_public_key",
        private_password=private_password,
        backend=None,
        options=options,
    )


def jwk_from_der(
    der_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from DER bytes (private key first, then public key)."""
    return jwk_from_bytes(
        der_content,
        private_loader="load_der_private_key",
        public_loader="load_der_public_key",
        private_password=private_password,
        backend=None,
        options=options,
    )