You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

450 lines
14 KiB

  1. # Copyright 2017 Gehirn Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hmac
  15. from abc import (
  16. ABC,
  17. abstractmethod,
  18. )
  19. from collections.abc import Mapping
  20. from functools import wraps
  21. from typing import (
  22. Any,
  23. Callable,
  24. Optional,
  25. TypeVar,
  26. Union,
  27. )
  28. from warnings import warn
  29. import cryptography.hazmat.primitives.serialization as serialization_module
  30. from cryptography.exceptions import InvalidSignature
  31. from cryptography.hazmat.backends import default_backend
  32. from cryptography.hazmat.primitives.asymmetric import padding
  33. from cryptography.hazmat.primitives.asymmetric.rsa import (
  34. RSAPrivateKey,
  35. RSAPrivateNumbers,
  36. RSAPublicKey,
  37. RSAPublicNumbers,
  38. rsa_crt_dmp1,
  39. rsa_crt_dmq1,
  40. rsa_crt_iqmp,
  41. rsa_recover_prime_factors,
  42. )
  43. from cryptography.hazmat.primitives.hashes import HashAlgorithm
  44. from .exceptions import (
  45. MalformedJWKError,
  46. UnsupportedKeyTypeError,
  47. )
  48. from .utils import (
  49. b64decode,
  50. b64encode,
  51. uint_b64decode,
  52. uint_b64encode,
  53. )
  54. _AJWK = TypeVar("_AJWK", bound="AbstractJWKBase")
  55. _T = TypeVar("_T")
  56. class AbstractJWKBase(ABC):
  57. @abstractmethod
  58. def get_kty(self) -> str:
  59. pass # pragma: no cover
  60. @abstractmethod
  61. def get_kid(self) -> str:
  62. pass # pragma: no cover
  63. @abstractmethod
  64. def is_sign_key(self) -> bool:
  65. pass # pragma: no cover
  66. @abstractmethod
  67. def sign(self, message: bytes, **options) -> bytes:
  68. pass # pragma: no cover
  69. @abstractmethod
  70. def verify(self, message: bytes, signature: bytes, **options) -> bool:
  71. pass # pragma: no cover
  72. @abstractmethod
  73. def to_dict(self, public_only: bool = True) -> dict[str, str]:
  74. pass # pragma: no cover
  75. @classmethod
  76. @abstractmethod
  77. def from_dict(cls: type[_AJWK], dct: dict[str, object]) -> _AJWK:
  78. pass # pragma: no cover
  79. class OctetJWK(AbstractJWKBase):
  80. def __init__(self, key: bytes, kid=None, **options) -> None:
  81. super(AbstractJWKBase, self).__init__()
  82. self.key = key
  83. self.kid = kid
  84. optnames = {"use", "key_ops", "alg", "x5u", "x5c", "x5t", "x5t#s256"}
  85. self.options = {k: v for k, v in options.items() if k in optnames}
  86. def get_kty(self):
  87. return "oct"
  88. def get_kid(self):
  89. return self.kid
  90. def is_sign_key(self) -> bool:
  91. return True
  92. def _get_signer(self, options) -> Callable[[bytes, bytes], bytes]:
  93. return options["signer"]
  94. def sign(self, message: bytes, **options) -> bytes:
  95. signer = self._get_signer(options)
  96. return signer(message, self.key)
  97. def verify(self, message: bytes, signature: bytes, **options) -> bool:
  98. signer = self._get_signer(options)
  99. return hmac.compare_digest(signature, signer(message, self.key))
  100. def to_dict(self, public_only=True):
  101. dct = {
  102. "kty": "oct",
  103. "k": b64encode(self.key),
  104. }
  105. dct.update(self.options)
  106. if self.kid:
  107. dct["kid"] = self.kid
  108. return dct
  109. @classmethod
  110. def from_dict(cls, dct):
  111. try:
  112. return cls(b64decode(dct["k"]), **dct)
  113. except KeyError as why:
  114. raise MalformedJWKError("k is required") from why
  115. class RSAJWK(AbstractJWKBase):
  116. """
  117. https://tools.ietf.org/html/rfc7518.html#section-6.3.1
  118. """
  119. def __init__(
  120. self, keyobj: Union[RSAPrivateKey, RSAPublicKey], **options
  121. ) -> None:
  122. super(AbstractJWKBase, self).__init__()
  123. self.keyobj = keyobj
  124. optnames = {
  125. "use",
  126. "key_ops",
  127. "alg",
  128. "kid",
  129. "x5u",
  130. "x5c",
  131. "x5t",
  132. "x5t#s256",
  133. }
  134. self.options = {k: v for k, v in options.items() if k in optnames}
  135. def is_sign_key(self) -> bool:
  136. return isinstance(self.keyobj, RSAPrivateKey)
  137. def _get_hash_fun(self, options) -> Callable[[], HashAlgorithm]:
  138. return options["hash_fun"]
  139. def _get_padding(self, options) -> padding.AsymmetricPadding:
  140. try:
  141. return options["padding"]
  142. except KeyError:
  143. warn(
  144. "you should not use RSAJWK.verify/sign without jwa "
  145. "intermiediary, used legacy padding"
  146. )
  147. return padding.PKCS1v15()
  148. def sign(self, message: bytes, **options) -> bytes:
  149. if isinstance(self.keyobj, RSAPublicKey):
  150. raise ValueError("Requires a private key.")
  151. hash_fun = self._get_hash_fun(options)
  152. _padding = self._get_padding(options)
  153. return self.keyobj.sign(message, _padding, hash_fun())
  154. def verify(self, message: bytes, signature: bytes, **options) -> bool:
  155. hash_fun = self._get_hash_fun(options)
  156. _padding = self._get_padding(options)
  157. if isinstance(self.keyobj, RSAPrivateKey):
  158. pubkey = self.keyobj.public_key()
  159. else:
  160. pubkey = self.keyobj
  161. try:
  162. pubkey.verify(signature, message, _padding, hash_fun())
  163. return True
  164. except InvalidSignature:
  165. return False
  166. def get_kty(self):
  167. return "RSA"
  168. def get_kid(self):
  169. return self.options.get("kid")
  170. def to_dict(self, public_only=True):
  171. dct = {
  172. "kty": "RSA",
  173. }
  174. dct.update(self.options)
  175. if isinstance(self.keyobj, RSAPrivateKey):
  176. priv_numbers = self.keyobj.private_numbers()
  177. pub_numbers = priv_numbers.public_numbers
  178. dct.update(
  179. {
  180. "e": uint_b64encode(pub_numbers.e),
  181. "n": uint_b64encode(pub_numbers.n),
  182. }
  183. )
  184. if not public_only:
  185. dct.update(
  186. {
  187. "e": uint_b64encode(pub_numbers.e),
  188. "n": uint_b64encode(pub_numbers.n),
  189. "d": uint_b64encode(priv_numbers.d),
  190. "p": uint_b64encode(priv_numbers.p),
  191. "q": uint_b64encode(priv_numbers.q),
  192. "dp": uint_b64encode(priv_numbers.dmp1),
  193. "dq": uint_b64encode(priv_numbers.dmq1),
  194. "qi": uint_b64encode(priv_numbers.iqmp),
  195. }
  196. )
  197. return dct
  198. pub_numbers = self.keyobj.public_numbers()
  199. dct.update(
  200. {
  201. "e": uint_b64encode(pub_numbers.e),
  202. "n": uint_b64encode(pub_numbers.n),
  203. }
  204. )
  205. return dct
  206. @classmethod
  207. def from_dict(cls, dct):
  208. if "oth" in dct:
  209. raise UnsupportedKeyTypeError(
  210. "RSA keys with multiples primes are not supported"
  211. )
  212. try:
  213. e = uint_b64decode(dct["e"])
  214. n = uint_b64decode(dct["n"])
  215. except KeyError as why:
  216. raise MalformedJWKError("e and n are required") from why
  217. pub_numbers = RSAPublicNumbers(e, n)
  218. if "d" not in dct:
  219. return cls(
  220. pub_numbers.public_key(backend=default_backend()), **dct
  221. )
  222. d = uint_b64decode(dct["d"])
  223. privparams = {"p", "q", "dp", "dq", "qi"}
  224. product = set(dct.keys()) & privparams
  225. if len(product) == 0:
  226. p, q = rsa_recover_prime_factors(n, e, d)
  227. priv_numbers = RSAPrivateNumbers(
  228. d=d,
  229. p=p,
  230. q=q,
  231. dmp1=rsa_crt_dmp1(d, p),
  232. dmq1=rsa_crt_dmq1(d, q),
  233. iqmp=rsa_crt_iqmp(p, q),
  234. public_numbers=pub_numbers,
  235. )
  236. elif product == privparams:
  237. priv_numbers = RSAPrivateNumbers(
  238. d=d,
  239. p=uint_b64decode(dct["p"]),
  240. q=uint_b64decode(dct["q"]),
  241. dmp1=uint_b64decode(dct["dp"]),
  242. dmq1=uint_b64decode(dct["dq"]),
  243. iqmp=uint_b64decode(dct["qi"]),
  244. public_numbers=pub_numbers,
  245. )
  246. else:
  247. # If the producer includes any of the other private key parameters,
  248. # then all of the others MUST be present, with the exception of
  249. # "oth", which MUST only be present when more than two prime
  250. # factors were used.
  251. raise MalformedJWKError(
  252. "p, q, dp, dq, qi MUST be present or"
  253. "all of them MUST be absent"
  254. )
  255. return cls(priv_numbers.private_key(backend=default_backend()), **dct)
  256. def supported_key_types() -> dict[str, type[AbstractJWKBase]]:
  257. return {
  258. "oct": OctetJWK,
  259. "RSA": RSAJWK,
  260. }
  261. def jwk_from_dict(dct: Mapping[str, Any]) -> AbstractJWKBase:
  262. if not isinstance(dct, dict): # pragma: no cover
  263. raise TypeError("dct must be a dict")
  264. if "kty" not in dct:
  265. raise MalformedJWKError("kty MUST be present")
  266. supported = supported_key_types()
  267. kty = dct["kty"]
  268. if kty not in supported:
  269. raise UnsupportedKeyTypeError(f"unsupported key type: {kty}")
  270. return supported[kty].from_dict(dct)
  271. PublicKeyLoaderT = Union[str, Callable[[bytes, object], object]]
  272. PrivateKeyLoaderT = Union[
  273. str, Callable[[bytes, Optional[str], object], object]
  274. ]
  275. _Loader = TypeVar("_Loader", PublicKeyLoaderT, PrivateKeyLoaderT)
  276. _C = TypeVar("_C", bound=Callable[..., Any])
  277. # The above LoaderTs should actually not be Union, and this function should be
  278. # typed something like this. But, this will lose any kwargs from the typing
  279. # information. Probably needs: https://github.com/python/mypy/issues/3157
  280. # (func: Callable[[bytes, _Loader], _T])
  281. # -> Callable[[bytes, Union[str, _Loader]], _T]
  282. def jwk_from_bytes_argument_conversion(func: _C) -> _C:
  283. if not ("private" in func.__name__ or "public" in func.__name__):
  284. raise Exception(
  285. "the wrapped function must have either public"
  286. " or private in it's name"
  287. )
  288. @wraps(func)
  289. def wrapper(content, loader, **kwargs):
  290. # now convert it to a Callable if it's a string
  291. if isinstance(loader, str):
  292. loader = getattr(serialization_module, loader)
  293. if kwargs.get("options") is None:
  294. kwargs["options"] = {}
  295. return func(content, loader, **kwargs)
  296. return wrapper # type: ignore[return-value]
  297. @jwk_from_bytes_argument_conversion
  298. def jwk_from_private_bytes(
  299. content: bytes,
  300. private_loader: PrivateKeyLoaderT,
  301. *,
  302. password: Optional[str] = None,
  303. backend: Optional[object] = None,
  304. options: Optional[Mapping[str, object]] = None,
  305. ) -> AbstractJWKBase:
  306. """This function is meant to be called from jwk_from_bytes"""
  307. if options is None:
  308. options = {}
  309. try:
  310. privkey = private_loader(content, password, backend) # type: ignore[operator] # noqa: E501
  311. if isinstance(privkey, RSAPrivateKey):
  312. return RSAJWK(privkey, **options)
  313. raise UnsupportedKeyTypeError("unsupported key type")
  314. except ValueError as ex:
  315. raise UnsupportedKeyTypeError("this is probably a public key") from ex
  316. @jwk_from_bytes_argument_conversion
  317. def jwk_from_public_bytes(
  318. content: bytes,
  319. public_loader: PublicKeyLoaderT,
  320. *,
  321. backend: Optional[object] = None,
  322. options: Optional[Mapping[str, object]] = None,
  323. ) -> AbstractJWKBase:
  324. """This function is meant to be called from jwk_from_bytes"""
  325. if options is None:
  326. options = {}
  327. try:
  328. pubkey = public_loader(content, backend) # type: ignore[operator]
  329. if isinstance(pubkey, RSAPublicKey):
  330. return RSAJWK(pubkey, **options)
  331. raise UnsupportedKeyTypeError(
  332. "unsupported key type"
  333. ) # pragma: no cover
  334. except ValueError as why:
  335. raise UnsupportedKeyTypeError("could not deserialize") from why
  336. def jwk_from_bytes(
  337. content: bytes,
  338. private_loader: PrivateKeyLoaderT,
  339. public_loader: PublicKeyLoaderT,
  340. *,
  341. private_password: Optional[str] = None,
  342. backend: Optional[object] = None,
  343. options: Optional[Mapping[str, object]] = None,
  344. ) -> AbstractJWKBase:
  345. try:
  346. return jwk_from_private_bytes(
  347. content,
  348. private_loader,
  349. password=private_password,
  350. backend=backend,
  351. options=options,
  352. )
  353. except UnsupportedKeyTypeError:
  354. return jwk_from_public_bytes(
  355. content,
  356. public_loader,
  357. backend=backend,
  358. options=options,
  359. )
  360. def jwk_from_pem(
  361. pem_content: bytes,
  362. private_password: Optional[str] = None,
  363. options: Optional[Mapping[str, object]] = None,
  364. ) -> AbstractJWKBase:
  365. return jwk_from_bytes(
  366. pem_content,
  367. private_loader="load_pem_private_key",
  368. public_loader="load_pem_public_key",
  369. private_password=private_password,
  370. backend=None,
  371. options=options,
  372. )
  373. def jwk_from_der(
  374. der_content: bytes,
  375. private_password: Optional[str] = None,
  376. options: Optional[Mapping[str, object]] = None,
  377. ) -> AbstractJWKBase:
  378. return jwk_from_bytes(
  379. der_content,
  380. private_loader="load_der_private_key",
  381. public_loader="load_der_public_key",
  382. private_password=private_password,
  383. backend=None,
  384. options=options,
  385. )