|
- # Copyright 2017 Gehirn Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import hmac
- from abc import (
- ABC,
- abstractmethod,
- )
- from collections.abc import Mapping
- from functools import wraps
- from typing import (
- Any,
- Callable,
- Optional,
- TypeVar,
- Union,
- )
- from warnings import warn
-
- import cryptography.hazmat.primitives.serialization as serialization_module
- from cryptography.exceptions import InvalidSignature
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives.asymmetric import padding
- from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey,
- RSAPrivateNumbers,
- RSAPublicKey,
- RSAPublicNumbers,
- rsa_crt_dmp1,
- rsa_crt_dmq1,
- rsa_crt_iqmp,
- rsa_recover_prime_factors,
- )
- from cryptography.hazmat.primitives.hashes import HashAlgorithm
-
- from .exceptions import (
- MalformedJWKError,
- UnsupportedKeyTypeError,
- )
- from .utils import (
- b64decode,
- b64encode,
- uint_b64decode,
- uint_b64encode,
- )
-
# Type variable bound to AbstractJWKBase; used to annotate
# AbstractJWKBase.from_dict so that it is typed as returning an instance of
# the class it is called on.
_AJWK = TypeVar("_AJWK", bound="AbstractJWKBase")
# Unconstrained type variable; within this section it only appears in the
# typing sketch written as a comment above the loader decorator.
_T = TypeVar("_T")
-
-
class AbstractJWKBase(ABC):
    """Interface that every JSON Web Key (JWK) implementation provides.

    All methods are abstract, so the class itself cannot be instantiated;
    concrete key types override every one of them.
    """

    @abstractmethod
    def get_kty(self) -> str:
        """Return the value of the key's "kty" (key type) parameter."""
        pass  # pragma: no cover

    @abstractmethod
    def get_kid(self) -> Optional[str]:
        """Return the key's "kid" (key ID), or None when the key has none."""
        pass  # pragma: no cover

    @abstractmethod
    def is_sign_key(self) -> bool:
        """Return True when this key can be used to produce signatures."""
        pass  # pragma: no cover

    @abstractmethod
    def sign(self, message: bytes, **options) -> bytes:
        """Sign ``message`` and return the signature bytes.

        ``options`` carries implementation-specific settings (for example
        the signing primitive to use).
        """
        pass  # pragma: no cover

    @abstractmethod
    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return True when ``signature`` is valid for ``message``."""
        pass  # pragma: no cover

    @abstractmethod
    def to_dict(self, public_only: bool = True) -> dict[str, str]:
        """Serialize the key as a JWK dictionary.

        ``public_only`` lets implementations leave private key material out
        of the result; it defaults to True.
        """
        pass  # pragma: no cover

    @classmethod
    @abstractmethod
    def from_dict(cls: type[_AJWK], dct: dict[str, object]) -> _AJWK:
        """Build a key of this class from a JWK dictionary."""
        pass  # pragma: no cover
-
-
class OctetJWK(AbstractJWKBase):
    """Symmetric ("oct") JWK whose key material is a raw byte string."""

    def __init__(self, key: bytes, kid=None, **options) -> None:
        """Store the key bytes, the optional key ID and the known options.

        Only the JWK parameter names listed in ``optnames`` are kept from
        ``options``; every other keyword is silently discarded.
        """
        # Zero-argument super() walks the full MRO, including
        # AbstractJWKBase, instead of starting the lookup after it.
        super().__init__()
        self.key = key
        self.kid = kid

        optnames = {"use", "key_ops", "alg", "x5u", "x5c", "x5t", "x5t#s256"}
        self.options = {k: v for k, v in options.items() if k in optnames}

    def get_kty(self) -> str:
        """Return the key type, which is always "oct"."""
        return "oct"

    def get_kid(self):
        """Return the key ID given at construction (None when absent)."""
        return self.kid

    def is_sign_key(self) -> bool:
        """Return True: a symmetric key can always be used for signing."""
        return True

    def _get_signer(self, options) -> Callable[[bytes, bytes], bytes]:
        """Return the mandatory ``signer`` option.

        Raises KeyError when the caller did not supply one.
        """
        return options["signer"]

    def sign(self, message: bytes, **options) -> bytes:
        """Return ``signer(message, key)`` using the ``signer`` option."""
        signer = self._get_signer(options)
        return signer(message, self.key)

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Recompute the signature and compare it with ``signature``.

        The comparison goes through ``hmac.compare_digest``.
        """
        signer = self._get_signer(options)
        return hmac.compare_digest(signature, signer(message, self.key))

    def to_dict(self, public_only: bool = True):
        """Serialize the key as a JWK dictionary.

        ``public_only`` is accepted for interface compatibility but has no
        effect here: the "k" member is always included.
        """
        dct = {
            "kty": "oct",
            "k": b64encode(self.key),
        }
        dct.update(self.options)
        if self.kid:
            dct["kid"] = self.kid
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an OctetJWK from a JWK dictionary.

        Raises MalformedJWKError when the required "k" member is missing.
        """
        # Keep the try body down to the lookup so that only a missing "k"
        # is reported as such.
        try:
            encoded_key = dct["k"]
        except KeyError as why:
            raise MalformedJWKError("k is required") from why

        # Unknown JWK members must be ignored; a member literally named
        # "key" would otherwise collide with the positional ``key``
        # argument and raise TypeError.
        params = {
            name: value for name, value in dct.items() if name != "key"
        }
        return cls(b64decode(encoded_key), **params)
-
-
class RSAJWK(AbstractJWKBase):
    """
    RSA public or private key exposed as a JWK.

    https://tools.ietf.org/html/rfc7518.html#section-6.3.1
    """

    def __init__(
        self, keyobj: Union[RSAPrivateKey, RSAPublicKey], **options
    ) -> None:
        """Wrap ``keyobj`` and keep the recognised JWK parameters.

        Only the names listed in ``optnames`` are kept from ``options``;
        every other keyword is silently discarded.
        """
        # Zero-argument super() walks the full MRO, including
        # AbstractJWKBase, instead of starting the lookup after it.
        super().__init__()
        self.keyobj = keyobj

        optnames = {
            "use",
            "key_ops",
            "alg",
            "kid",
            "x5u",
            "x5c",
            "x5t",
            "x5t#s256",
        }
        self.options = {k: v for k, v in options.items() if k in optnames}

    def is_sign_key(self) -> bool:
        """Return True when the wrapped key object is a private key."""
        return isinstance(self.keyobj, RSAPrivateKey)

    def _get_hash_fun(self, options) -> Callable[[], HashAlgorithm]:
        """Return the mandatory ``hash_fun`` option.

        Raises KeyError when the caller did not supply one.
        """
        return options["hash_fun"]

    def _get_padding(self, options) -> padding.AsymmetricPadding:
        """Return the ``padding`` option, or PKCS1v15 with a warning."""
        try:
            return options["padding"]
        except KeyError:
            warn(
                "you should not use RSAJWK.verify/sign without jwa "
                "intermiediary, used legacy padding"
            )
            return padding.PKCS1v15()

    def sign(self, message: bytes, **options) -> bytes:
        """Sign ``message`` with the private key.

        Raises ValueError when only a public key is wrapped.
        """
        if isinstance(self.keyobj, RSAPublicKey):
            raise ValueError("Requires a private key.")
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        return self.keyobj.sign(message, _padding, hash_fun())

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return True when ``signature`` is valid for ``message``.

        A wrapped private key is verified through its public half.
        """
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        if isinstance(self.keyobj, RSAPrivateKey):
            pubkey = self.keyobj.public_key()
        else:
            pubkey = self.keyobj
        try:
            pubkey.verify(signature, message, _padding, hash_fun())
            return True
        except InvalidSignature:
            return False

    def get_kty(self) -> str:
        """Return the key type, which is always "RSA"."""
        return "RSA"

    def get_kid(self):
        """Return the "kid" option, or None when it was not provided."""
        return self.options.get("kid")

    def to_dict(self, public_only: bool = True):
        """Serialize the key as a JWK dictionary.

        The public members "e" and "n" are always emitted. The private
        members are added only when a private key is wrapped and
        ``public_only`` is false.
        """
        dct = {
            "kty": "RSA",
        }
        dct.update(self.options)

        if isinstance(self.keyobj, RSAPrivateKey):
            priv_numbers = self.keyobj.private_numbers()
            pub_numbers = priv_numbers.public_numbers
        else:
            priv_numbers = None
            pub_numbers = self.keyobj.public_numbers()

        # The public part is encoded once, whatever kind of key is wrapped.
        dct["e"] = uint_b64encode(pub_numbers.e)
        dct["n"] = uint_b64encode(pub_numbers.n)

        if priv_numbers is not None and not public_only:
            dct.update(
                {
                    "d": uint_b64encode(priv_numbers.d),
                    "p": uint_b64encode(priv_numbers.p),
                    "q": uint_b64encode(priv_numbers.q),
                    "dp": uint_b64encode(priv_numbers.dmp1),
                    "dq": uint_b64encode(priv_numbers.dmq1),
                    "qi": uint_b64encode(priv_numbers.iqmp),
                }
            )
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an RSAJWK from a JWK dictionary.

        Without a "d" member a public key is built. With "d", the CRT
        members p, q, dp, dq and qi must be either all present or all
        absent; when absent they are recovered from n, e and d.

        Raises UnsupportedKeyTypeError for multi-prime keys ("oth") and
        MalformedJWKError when required members are missing or the CRT
        members are only partially present.
        """
        if "oth" in dct:
            raise UnsupportedKeyTypeError(
                "RSA keys with multiples primes are not supported"
            )

        try:
            e = uint_b64decode(dct["e"])
            n = uint_b64decode(dct["n"])
        except KeyError as why:
            raise MalformedJWKError("e and n are required") from why

        # Unknown JWK members must be ignored; a member literally named
        # "keyobj" would otherwise collide with the positional ``keyobj``
        # argument and raise TypeError.
        params = {
            name: value for name, value in dct.items() if name != "keyobj"
        }

        pub_numbers = RSAPublicNumbers(e, n)
        if "d" not in dct:
            return cls(
                pub_numbers.public_key(backend=default_backend()), **params
            )
        d = uint_b64decode(dct["d"])

        privparams = {"p", "q", "dp", "dq", "qi"}
        present = set(dct.keys()) & privparams
        if not present:
            p, q = rsa_recover_prime_factors(n, e, d)
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=pub_numbers,
            )
        elif present == privparams:
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=uint_b64decode(dct["p"]),
                q=uint_b64decode(dct["q"]),
                dmp1=uint_b64decode(dct["dp"]),
                dmq1=uint_b64decode(dct["dq"]),
                iqmp=uint_b64decode(dct["qi"]),
                public_numbers=pub_numbers,
            )
        else:
            # If the producer includes any of the other private key parameters,
            # then all of the others MUST be present, with the exception of
            # "oth", which MUST only be present when more than two prime
            # factors were used.
            raise MalformedJWKError(
                "p, q, dp, dq, qi MUST be present or "
                "all of them MUST be absent"
            )
        return cls(
            priv_numbers.private_key(backend=default_backend()), **params
        )
-
-
def supported_key_types() -> dict[str, type[AbstractJWKBase]]:
    """Return a fresh mapping from "kty" value to the class implementing it."""
    registry: dict[str, type[AbstractJWKBase]] = {}
    registry["oct"] = OctetJWK
    registry["RSA"] = RSAJWK
    return registry
-
-
def jwk_from_dict(dct: Mapping[str, Any]) -> AbstractJWKBase:
    """Build the JWK object described by the mapping ``dct``.

    The "kty" member selects the implementation from
    ``supported_key_types()``, and that class's ``from_dict`` does the
    actual parsing.

    Raises TypeError when ``dct`` is not a mapping, MalformedJWKError when
    "kty" is missing, and UnsupportedKeyTypeError when "kty" names no
    supported key type.
    """
    # The annotation promises any Mapping, so accept one at runtime too.
    if not isinstance(dct, Mapping):  # pragma: no cover
        raise TypeError("dct must be a mapping")
    if not isinstance(dct, dict):
        # from_dict implementations are written against plain dicts.
        dct = dict(dct)
    if "kty" not in dct:
        raise MalformedJWKError("kty MUST be present")

    supported = supported_key_types()
    kty = dct["kty"]
    # The isinstance check keeps an unhashable "kty" (e.g. a list) from
    # surfacing as a TypeError out of the dict lookup.
    if not isinstance(kty, str) or kty not in supported:
        raise UnsupportedKeyTypeError(f"unsupported key type: {kty}")
    return supported[kty].from_dict(dct)
-
-
# A public-key loader: either a string naming an attribute of the
# cryptography serialization module, or a callable invoked as
# loader(content, backend).
PublicKeyLoaderT = Union[str, Callable[[bytes, object], object]]
# A private-key loader: either a string naming an attribute of the
# cryptography serialization module, or a callable invoked as
# loader(content, password, backend).
PrivateKeyLoaderT = Union[
    str, Callable[[bytes, Optional[str], object], object]
]
# Type variable constrained to the two loader alias types above.
_Loader = TypeVar("_Loader", PublicKeyLoaderT, PrivateKeyLoaderT)
# Type variable for an arbitrary callable, so a decorator can be annotated as
# returning the same callable type it receives.
_C = TypeVar("_C", bound=Callable[..., Any])
-
-
# The above LoaderTs should actually not be Union, and this function should be
# typed something like this. But, this will lose any kwargs from the typing
# information. Probably needs: https://github.com/python/mypy/issues/3157
# (func: Callable[[bytes, _Loader], _T])
#   -> Callable[[bytes, Union[str, _Loader]], _T]
def jwk_from_bytes_argument_conversion(func: _C) -> _C:
    """Decorate a key-loading function so its loader may be given by name.

    The returned wrapper is called as ``wrapper(content, loader, **kwargs)``:

    * a ``loader`` given as a string is replaced by the attribute of that
      name on the cryptography serialization module;
    * a missing or ``None`` ``options`` keyword is replaced by a new empty
      dict.

    ``func`` is then called with the converted arguments.

    Raises ValueError at decoration time when the name of ``func``
    contains neither "public" nor "private".
    """
    if not ("private" in func.__name__ or "public" in func.__name__):
        # ValueError rather than a bare Exception: it is more specific, and
        # callers catching Exception still see it.
        raise ValueError(
            "the wrapped function must have either public"
            " or private in it's name"
        )

    @wraps(func)
    def wrapper(content, loader, **kwargs):
        # now convert it to a Callable if it's a string
        if isinstance(loader, str):
            loader = getattr(serialization_module, loader)

        if kwargs.get("options") is None:
            kwargs["options"] = {}

        return func(content, loader, **kwargs)

    return wrapper  # type: ignore[return-value]
-
-
@jwk_from_bytes_argument_conversion
def jwk_from_private_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    *,
    password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """This function is meant to be called from jwk_from_bytes

    Loads ``content`` with ``private_loader(content, password, backend)``
    and wraps an RSA private key in an RSAJWK built with ``options``.

    Raises UnsupportedKeyTypeError when the loader raises ValueError or
    returns something other than an RSA private key.
    """
    jwk_options = {} if options is None else options
    try:
        # NOTE(review): ``password`` is annotated as str, while cryptography
        # documents the password of its load_*_private_key functions as
        # bytes -- confirm what callers actually pass.
        loaded = private_loader(content, password, backend)  # type: ignore[operator]  # noqa: E501
        if not isinstance(loaded, RSAPrivateKey):
            raise UnsupportedKeyTypeError("unsupported key type")
        return RSAJWK(loaded, **jwk_options)
    except ValueError as ex:
        raise UnsupportedKeyTypeError("this is probably a public key") from ex
-
-
@jwk_from_bytes_argument_conversion
def jwk_from_public_bytes(
    content: bytes,
    public_loader: PublicKeyLoaderT,
    *,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """This function is meant to be called from jwk_from_bytes

    Loads ``content`` with ``public_loader(content, backend)`` and wraps an
    RSA public key in an RSAJWK built with ``options``.

    Raises UnsupportedKeyTypeError when the loader raises ValueError or
    returns something other than an RSA public key.
    """
    jwk_options = {} if options is None else options
    try:
        loaded = public_loader(content, backend)  # type: ignore[operator]
        if not isinstance(loaded, RSAPublicKey):
            raise UnsupportedKeyTypeError(
                "unsupported key type"
            )  # pragma: no cover
        return RSAJWK(loaded, **jwk_options)
    except ValueError as why:
        raise UnsupportedKeyTypeError("could not deserialize") from why
-
-
def jwk_from_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    public_loader: PublicKeyLoaderT,
    *,
    private_password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from serialized key bytes, private key first.

    ``content`` is first handed to ``jwk_from_private_bytes``; when that
    raises UnsupportedKeyTypeError it is retried as a public key with
    ``jwk_from_public_bytes``.
    """
    # Keywords that both loading attempts receive unchanged.
    shared_kwargs = {"backend": backend, "options": options}
    try:
        return jwk_from_private_bytes(
            content,
            private_loader,
            password=private_password,
            **shared_kwargs,
        )
    except UnsupportedKeyTypeError:
        # The fallback stays inside the handler so that a failure here is
        # still chained to the private-key error.
        return jwk_from_public_bytes(content, public_loader, **shared_kwargs)
-
-
def jwk_from_pem(
    pem_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from PEM-encoded key bytes.

    Delegates to ``jwk_from_bytes`` with the PEM loaders selected by name.
    """
    private_loader_name = "load_pem_private_key"
    public_loader_name = "load_pem_public_key"
    return jwk_from_bytes(
        pem_content,
        private_loader_name,
        public_loader_name,
        private_password=private_password,
        backend=None,
        options=options,
    )
-
-
def jwk_from_der(
    der_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a JWK from DER-encoded key bytes.

    Delegates to ``jwk_from_bytes`` with the DER loaders selected by name.
    """
    private_loader_name = "load_der_private_key"
    public_loader_name = "load_der_public_key"
    return jwk_from_bytes(
        der_content,
        private_loader_name,
        public_loader_name,
        private_password=private_password,
        backend=None,
        options=options,
    )
|