您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

215 行
5.9 KiB

  1. # Copyright 2017 Gehirn Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hashlib
  15. import hmac
  16. from typing import (
  17. Any,
  18. Callable,
  19. Optional,
  20. )
  21. from cryptography.hazmat.primitives.asymmetric import padding
  22. from cryptography.hazmat.primitives.hashes import (
  23. SHA256,
  24. SHA384,
  25. SHA512,
  26. )
  27. from .exceptions import InvalidKeyTypeError
  28. from .jwk import AbstractJWKBase
  29. def std_hash_by_alg(alg: str) -> Callable[[bytes], object]:
  30. if alg.endswith("S256"):
  31. return hashlib.sha256
  32. if alg.endswith("S384"):
  33. return hashlib.sha384
  34. if alg.endswith("S512"):
  35. return hashlib.sha512
  36. raise ValueError(f"{alg} is not supported")
  37. class AbstractSigningAlgorithm:
  38. def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
  39. raise NotImplementedError() # pragma: no cover
  40. def verify(
  41. self,
  42. message: bytes,
  43. key: Optional[AbstractJWKBase],
  44. signature: bytes,
  45. ) -> bool:
  46. raise NotImplementedError() # pragma: no cover
  47. class NoneAlgorithm(AbstractSigningAlgorithm):
  48. def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
  49. return b""
  50. def verify(
  51. self,
  52. message: bytes,
  53. key: Optional[AbstractJWKBase],
  54. signature: bytes,
  55. ) -> bool:
  56. return hmac.compare_digest(signature, b"")
  57. none = NoneAlgorithm()
  58. class HMACAlgorithm(AbstractSigningAlgorithm):
  59. def __init__(self, hash_fun: Callable[[], Any]) -> None:
  60. self.hash_fun = hash_fun
  61. def _check_key(self, key: Optional[AbstractJWKBase]) -> AbstractJWKBase:
  62. if not key or key.get_kty() != "oct":
  63. raise InvalidKeyTypeError("Octet key is required")
  64. return key
  65. def _sign(self, message: bytes, key: bytes) -> bytes:
  66. return hmac.new(key, message, self.hash_fun).digest()
  67. def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
  68. key = self._check_key(key)
  69. return key.sign(message, signer=self._sign)
  70. def verify(
  71. self,
  72. message: bytes,
  73. key: Optional[AbstractJWKBase],
  74. signature: bytes,
  75. ) -> bool:
  76. key = self._check_key(key)
  77. return key.verify(message, signature, signer=self._sign)
  78. HS256 = HMACAlgorithm(hashlib.sha256)
  79. HS384 = HMACAlgorithm(hashlib.sha384)
  80. HS512 = HMACAlgorithm(hashlib.sha512)
  81. class RSAAlgorithm(AbstractSigningAlgorithm):
  82. def __init__(self, hash_fun: object) -> None:
  83. self.hash_fun = hash_fun
  84. def _check_key(
  85. self,
  86. key: Optional[AbstractJWKBase],
  87. must_sign_key: bool = False,
  88. ) -> AbstractJWKBase:
  89. if not key or key.get_kty() != "RSA":
  90. raise InvalidKeyTypeError("RSA key is required")
  91. if must_sign_key and not key.is_sign_key():
  92. raise InvalidKeyTypeError(
  93. "a RSA private key is required, but passed is RSA public key"
  94. )
  95. return key
  96. def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
  97. key = self._check_key(key, must_sign_key=True)
  98. return key.sign(
  99. message, hash_fun=self.hash_fun, padding=padding.PKCS1v15()
  100. )
  101. def verify(
  102. self,
  103. message: bytes,
  104. key: Optional[AbstractJWKBase],
  105. signature: bytes,
  106. ) -> bool:
  107. key = self._check_key(key)
  108. return key.verify(
  109. message,
  110. signature,
  111. hash_fun=self.hash_fun,
  112. padding=padding.PKCS1v15(),
  113. )
  114. RS256 = RSAAlgorithm(SHA256)
  115. RS384 = RSAAlgorithm(SHA384)
  116. RS512 = RSAAlgorithm(SHA512)
  117. class PSSRSAAlgorithm(AbstractSigningAlgorithm):
  118. def __init__(self, hash_fun: Callable[[], Any]) -> None:
  119. self.hash_fun = hash_fun
  120. def _check_key(
  121. self,
  122. key: Optional[AbstractJWKBase],
  123. must_sign_key: bool = False,
  124. ) -> AbstractJWKBase:
  125. if not key or key.get_kty() != "RSA":
  126. raise InvalidKeyTypeError("RSA key is required")
  127. if must_sign_key and not key.is_sign_key():
  128. raise InvalidKeyTypeError(
  129. "a RSA private key is required, but passed is RSA public key"
  130. )
  131. return key
  132. def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
  133. key = self._check_key(key, must_sign_key=True)
  134. return key.sign(
  135. message,
  136. hash_fun=self.hash_fun,
  137. padding=padding.PSS(
  138. mgf=padding.MGF1(self.hash_fun()),
  139. salt_length=self.hash_fun().digest_size,
  140. ),
  141. )
  142. def verify(
  143. self,
  144. message: bytes,
  145. key: Optional[AbstractJWKBase],
  146. signature: bytes,
  147. ) -> bool:
  148. key = self._check_key(key)
  149. return key.verify(
  150. message,
  151. signature,
  152. hash_fun=self.hash_fun,
  153. padding=padding.PSS(
  154. mgf=padding.MGF1(self.hash_fun()),
  155. salt_length=self.hash_fun().digest_size,
  156. ),
  157. )
  158. PS256 = PSSRSAAlgorithm(SHA256)
  159. PS384 = PSSRSAAlgorithm(SHA384)
  160. PS512 = PSSRSAAlgorithm(SHA512)
  161. def supported_signing_algorithms() -> dict[str, AbstractSigningAlgorithm]:
  162. # NOTE(yosida95): exclude vulnerable 'none' algorithm by default.
  163. return {
  164. "HS256": HS256,
  165. "HS384": HS384,
  166. "HS512": HS512,
  167. "RS256": RS256,
  168. "RS384": RS384,
  169. "RS512": RS512,
  170. "PS256": PS256,
  171. "PS384": PS384,
  172. "PS512": PS512,
  173. }