You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

282 lines
8.7 KiB

  1. # -*- coding: utf-8 -*-
  2. #
  3. # Copyright 2017 Gehirn Inc.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import hmac
  17. from typing import (
  18. Callable,
  19. Union,
  20. )
  21. from cryptography.exceptions import InvalidSignature
  22. from cryptography.hazmat.backends import default_backend
  23. from cryptography.hazmat.primitives.asymmetric import padding
  24. from cryptography.hazmat.primitives.asymmetric.rsa import (
  25. rsa_crt_dmp1,
  26. rsa_crt_dmq1,
  27. rsa_crt_iqmp,
  28. rsa_recover_prime_factors,
  29. RSAPrivateKey,
  30. RSAPrivateNumbers,
  31. RSAPublicKey,
  32. RSAPublicNumbers,
  33. )
  34. from cryptography.hazmat.primitives.serialization import (
  35. load_pem_private_key,
  36. load_pem_public_key,
  37. )
  38. from .exceptions import (
  39. MalformedJWKError,
  40. UnsupportedKeyTypeError,
  41. )
  42. from .utils import (
  43. b64encode,
  44. b64decode,
  45. uint_b64encode,
  46. uint_b64decode,
  47. )
  48. class AbstractJWKBase:
  49. def get_kty(self):
  50. raise NotImplementedError() # pragma: no cover
  51. def get_kid(self):
  52. raise NotImplementedError() # pragma: no cover
  53. def is_sign_key(self) -> bool:
  54. raise NotImplementedError() # pragma: no cover
  55. def sign(self, message: bytes, **options) -> bytes:
  56. raise NotImplementedError() # pragma: no cover
  57. def verify(self, message: bytes, signature: bytes, **options) -> bool:
  58. raise NotImplementedError() # pragma: no cover
  59. def to_dict(self, public_only=True):
  60. raise NotImplementedError() # pragma: no cover
  61. @classmethod
  62. def from_dict(cls, dct):
  63. raise NotImplementedError() # pragma: no cover
  64. class OctetJWK(AbstractJWKBase):
  65. def __init__(self, key: bytes, kid=None, **options) -> None:
  66. super(AbstractJWKBase, self).__init__()
  67. self.key = key
  68. self.kid = kid
  69. optnames = {'use', 'key_ops', 'alg', 'x5u', 'x5c', 'x5t', 'x5t#s256'}
  70. self.options = {k: v for k, v in options.items() if k in optnames}
  71. def get_kty(self):
  72. return 'oct'
  73. def get_kid(self):
  74. return self.kid
  75. def is_sign_key(self) -> bool:
  76. return True
  77. def sign(self, message: bytes,
  78. signer: Callable[[bytes, bytes], bytes] = None,
  79. **options) -> bytes:
  80. return signer(message, self.key)
  81. def verify(self, message: bytes, signature: bytes,
  82. signer: Callable[[bytes, bytes], bytes] = None,
  83. **options) -> bool:
  84. return hmac.compare_digest(signature, signer(message, self.key))
  85. def to_dict(self, public_only=True):
  86. dct = {
  87. 'kty': 'oct',
  88. 'k': b64encode(self.key),
  89. }
  90. dct.update(self.options)
  91. if self.kid:
  92. dct['kid'] = self.kid
  93. return dct
  94. @classmethod
  95. def from_dict(cls, dct):
  96. try:
  97. return cls(b64decode(dct['k']), **dct)
  98. except KeyError as why:
  99. raise MalformedJWKError('k is required')
  100. class RSAJWK(AbstractJWKBase):
  101. """
  102. https://tools.ietf.org/html/rfc7518.html#section-6.3.1
  103. """
  104. def __init__(self, keyobj: Union[RSAPrivateKey, RSAPublicKey],
  105. **options) -> None:
  106. super(AbstractJWKBase, self).__init__()
  107. self.keyobj = keyobj
  108. optnames = {'use', 'key_ops', 'alg', 'kid',
  109. 'x5u', 'x5c', 'x5t', 'x5t#s256'}
  110. self.options = {k: v for k, v in options.items() if k in optnames}
  111. def is_sign_key(self) -> bool:
  112. return isinstance(self.keyobj, RSAPrivateKey)
  113. def sign(self, message: bytes, hash_fun: Callable = None,
  114. **options) -> bytes:
  115. return self.keyobj.sign(message, padding.PKCS1v15(), hash_fun())
  116. def verify(self, message: bytes, signature: bytes,
  117. hash_fun: Callable = None, **options) -> bool:
  118. if self.is_sign_key():
  119. pubkey = self.keyobj.public_key()
  120. else:
  121. pubkey = self.keyobj
  122. try:
  123. pubkey.verify(signature, message, padding.PKCS1v15(), hash_fun())
  124. return True
  125. except InvalidSignature:
  126. return False
  127. def get_kty(self):
  128. return 'RSA'
  129. def get_kid(self):
  130. return self.options.get('kid')
  131. def to_dict(self, public_only=True):
  132. dct = {
  133. 'kty': 'RSA',
  134. }
  135. dct.update(self.options)
  136. if isinstance(self.keyobj, RSAPrivateKey):
  137. priv_numbers = self.keyobj.private_numbers()
  138. pub_numbers = priv_numbers.public_numbers
  139. dct.update({
  140. 'e': uint_b64encode(pub_numbers.e),
  141. 'n': uint_b64encode(pub_numbers.n),
  142. })
  143. if not public_only:
  144. dct.update({
  145. 'e': uint_b64encode(pub_numbers.e),
  146. 'n': uint_b64encode(pub_numbers.n),
  147. 'd': uint_b64encode(priv_numbers.d),
  148. 'p': uint_b64encode(priv_numbers.p),
  149. 'q': uint_b64encode(priv_numbers.q),
  150. 'dp': uint_b64encode(priv_numbers.dmp1),
  151. 'dq': uint_b64encode(priv_numbers.dmq1),
  152. 'qi': uint_b64encode(priv_numbers.iqmp),
  153. })
  154. return dct
  155. pub_numbers = self.keyobj.public_numbers()
  156. dct.update({
  157. 'e': uint_b64encode(pub_numbers.e),
  158. 'n': uint_b64encode(pub_numbers.n),
  159. })
  160. return dct
  161. @classmethod
  162. def from_dict(cls, dct):
  163. if 'oth' in dct:
  164. raise UnsupportedKeyTypeError(
  165. 'RSA keys with multiples primes are not supported')
  166. try:
  167. e = uint_b64decode(dct['e'])
  168. n = uint_b64decode(dct['n'])
  169. except KeyError as why:
  170. raise MalformedJWKError('e and n are required')
  171. pub_numbers = RSAPublicNumbers(e, n)
  172. if 'd' not in dct:
  173. return cls(
  174. pub_numbers.public_key(backend=default_backend()), **dct)
  175. d = uint_b64decode(dct['d'])
  176. privparams = {'p', 'q', 'dp', 'dq', 'qi'}
  177. product = set(dct.keys()) & privparams
  178. if len(product) == 0:
  179. p, q = rsa_recover_prime_factors(n, e, d)
  180. priv_numbers = RSAPrivateNumbers(
  181. d=d,
  182. p=p,
  183. q=q,
  184. dmp1=rsa_crt_dmp1(d, p),
  185. dmq1=rsa_crt_dmq1(d, q),
  186. iqmp=rsa_crt_iqmp(p, q),
  187. public_numbers=pub_numbers)
  188. elif product == privparams:
  189. priv_numbers = RSAPrivateNumbers(
  190. d=d,
  191. p=uint_b64decode(dct['p']),
  192. q=uint_b64decode(dct['q']),
  193. dmp1=uint_b64decode(dct['dp']),
  194. dmq1=uint_b64decode(dct['dq']),
  195. iqmp=uint_b64decode(dct['qi']),
  196. public_numbers=pub_numbers)
  197. else:
  198. # If the producer includes any of the other private key parameters,
  199. # then all of the others MUST be present, with the exception of
  200. # "oth", which MUST only be present when more than two prime
  201. # factors were used.
  202. raise MalformedJWKError(
  203. 'p, q, dp, dq, qi MUST be present or'
  204. 'all of them MUST be absent')
  205. return cls(priv_numbers.private_key(backend=default_backend()), **dct)
  206. def supported_key_types():
  207. return {
  208. 'oct': OctetJWK,
  209. 'RSA': RSAJWK,
  210. }
  211. def jwk_from_dict(dct):
  212. if 'kty' not in dct:
  213. raise MalformedJWKError('kty MUST be present')
  214. supported = supported_key_types()
  215. kty = dct['kty']
  216. if kty not in supported:
  217. raise UnsupportedKeyTypeError('unsupported key type: {}'.format(kty))
  218. return supported[kty].from_dict(dct)
  219. def jwk_from_pem(pem_content: bytes) -> AbstractJWKBase:
  220. try:
  221. privkey = load_pem_private_key(
  222. pem_content, password=None, backend=default_backend())
  223. if isinstance(privkey, RSAPrivateKey):
  224. return RSAJWK(privkey)
  225. raise UnsupportedKeyTypeError(
  226. 'unsupported key type') # pragma: no cover
  227. except ValueError:
  228. pass
  229. try:
  230. pubkey = load_pem_public_key(pem_content, backend=default_backend())
  231. if isinstance(pubkey, RSAPublicKey):
  232. return RSAJWK(pubkey)
  233. raise UnsupportedKeyTypeError(
  234. 'unsupported key type') # pragma: no cover
  235. except ValueError as why:
  236. raise UnsupportedKeyTypeError('could not deserialize PEM') from why