You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

113 lines
3.4 KiB

  1. # Copyright 2017 Gehirn Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from typing import (
  16. AbstractSet,
  17. Optional,
  18. )
  19. from .exceptions import (
  20. JWSDecodeError,
  21. JWSEncodeError,
  22. )
  23. from .jwa import (
  24. AbstractSigningAlgorithm,
  25. supported_signing_algorithms,
  26. )
  27. from .jwk import AbstractJWKBase
  28. from .utils import (
  29. b64decode,
  30. b64encode,
  31. )
# Public API of this module: only ``JWS`` is exported by ``from ... import *``.
__all__ = ["JWS"]
  33. class JWS:
  34. def __init__(self) -> None:
  35. self._supported_algs = supported_signing_algorithms()
  36. def _retrieve_alg(self, alg: str) -> AbstractSigningAlgorithm:
  37. try:
  38. return self._supported_algs[alg]
  39. except KeyError:
  40. raise JWSDecodeError("Unsupported signing algorithm.")
  41. def encode(
  42. self,
  43. message: bytes,
  44. key: Optional[AbstractJWKBase] = None,
  45. alg="HS256",
  46. optional_headers: Optional[dict[str, str]] = None,
  47. ) -> str:
  48. if alg not in self._supported_algs: # pragma: no cover
  49. raise JWSEncodeError(f"unsupported algorithm: {alg}")
  50. alg_impl = self._retrieve_alg(alg)
  51. header = optional_headers.copy() if optional_headers else {}
  52. header["alg"] = alg
  53. header_b64 = b64encode(
  54. json.dumps(header, separators=(",", ":")).encode("ascii")
  55. )
  56. message_b64 = b64encode(message)
  57. signing_message = header_b64 + "." + message_b64
  58. signature = alg_impl.sign(signing_message.encode("ascii"), key)
  59. signature_b64 = b64encode(signature)
  60. return signing_message + "." + signature_b64
  61. def _decode_segments(
  62. self, message: str
  63. ) -> tuple[dict[str, str], bytes, bytes, str]:
  64. try:
  65. signing_message, signature_b64 = message.rsplit(".", 1)
  66. header_b64, message_b64 = signing_message.split(".")
  67. except ValueError:
  68. raise JWSDecodeError("malformed JWS payload")
  69. header = json.loads(b64decode(header_b64).decode("ascii"))
  70. message_bin = b64decode(message_b64)
  71. signature = b64decode(signature_b64)
  72. return header, message_bin, signature, signing_message
  73. def decode(
  74. self,
  75. message: str,
  76. key: Optional[AbstractJWKBase] = None,
  77. do_verify=True,
  78. algorithms: Optional[AbstractSet[str]] = None,
  79. ) -> bytes:
  80. if algorithms is None:
  81. algorithms = set(supported_signing_algorithms().keys())
  82. header, message_bin, signature, signing_message = (
  83. self._decode_segments(message)
  84. )
  85. alg_value = header["alg"]
  86. if alg_value not in algorithms:
  87. raise JWSDecodeError("Unsupported signing algorithm.")
  88. alg_impl = self._retrieve_alg(alg_value)
  89. if do_verify and not alg_impl.verify(
  90. signing_message.encode("ascii"), key, signature
  91. ):
  92. raise JWSDecodeError("JWS passed could not be validated")
  93. return message_bin