You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

144 lines
4.7 KiB

  1. # Copyright 2017 Gehirn Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from datetime import (
  16. datetime,
  17. timezone,
  18. )
  19. from typing import (
  20. AbstractSet,
  21. Any,
  22. Optional,
  23. )
  24. from jwt.utils import (
  25. get_time_from_int,
  26. )
  27. from .exceptions import (
  28. JWSDecodeError,
  29. JWSEncodeError,
  30. JWTDecodeError,
  31. JWTEncodeError,
  32. )
  33. from .jwk import AbstractJWKBase
  34. from .jws import JWS
  35. class JWT:
  36. def __init__(self):
  37. self._jws = JWS()
  38. def encode(
  39. self,
  40. payload: dict[str, Any],
  41. key: Optional[AbstractJWKBase] = None,
  42. alg="HS256",
  43. optional_headers: Optional[dict[str, str]] = None,
  44. ) -> str:
  45. if not isinstance(self, JWT): # pragma: no cover
  46. # https://github.com/GehirnInc/python-jwt/issues/15
  47. raise RuntimeError(
  48. "encode must be called on a jwt.JWT() instance. "
  49. "Do jwt.JWT().encode(...)"
  50. )
  51. if not isinstance(payload, dict): # pragma: no cover
  52. raise TypeError("payload must be a dict")
  53. if not (
  54. key is None or isinstance(key, AbstractJWKBase)
  55. ): # pragma: no cover
  56. raise TypeError(
  57. "key must be an instance of a class implements "
  58. "jwt.AbstractJWKBase"
  59. )
  60. if not (
  61. optional_headers is None or isinstance(optional_headers, dict)
  62. ): # pragma: no cover
  63. raise TypeError("optional_headers must be a dict")
  64. try:
  65. message = json.dumps(payload).encode("utf-8")
  66. except ValueError as why:
  67. raise JWTEncodeError(
  68. "payload must be able to be encoded to JSON"
  69. ) from why
  70. optional_headers = optional_headers and optional_headers.copy() or {}
  71. optional_headers["typ"] = "JWT"
  72. try:
  73. return self._jws.encode(message, key, alg, optional_headers)
  74. except JWSEncodeError as why:
  75. raise JWTEncodeError("failed to encode to JWT") from why
  76. def decode(
  77. self,
  78. message: str,
  79. key: Optional[AbstractJWKBase] = None,
  80. do_verify=True,
  81. algorithms: Optional[AbstractSet[str]] = None,
  82. do_time_check: bool = True,
  83. ) -> dict[str, Any]:
  84. if not isinstance(self, JWT): # pragma: no cover
  85. # https://github.com/GehirnInc/python-jwt/issues/15
  86. raise RuntimeError(
  87. "decode must be called on a jwt.JWT() instance. "
  88. "Do jwt.JWT().decode(...)"
  89. )
  90. if not isinstance(message, str): # pragma: no cover
  91. raise TypeError("message must be a str")
  92. if not (
  93. key is None or isinstance(key, AbstractJWKBase)
  94. ): # pragma: no cover
  95. raise TypeError(
  96. "key must be an instance of a class implements "
  97. "jwt.AbstractJWKBase"
  98. )
  99. # utc now with timezone
  100. now = datetime.now(timezone.utc)
  101. try:
  102. message_bin = self._jws.decode(message, key, do_verify, algorithms)
  103. except JWSDecodeError as why:
  104. raise JWTDecodeError("failed to decode JWT") from why
  105. try:
  106. payload = json.loads(message_bin.decode("utf-8"))
  107. except ValueError as why:
  108. raise JWTDecodeError(
  109. "a payload of the JWT is not valid JSON"
  110. ) from why
  111. # The "exp" (expiration time) claim identifies the expiration time on
  112. # or after which the JWT MUST NOT be accepted for processing.
  113. if "exp" in payload and do_time_check:
  114. try:
  115. exp = get_time_from_int(payload["exp"])
  116. except TypeError:
  117. raise JWTDecodeError("Invalid Expired value")
  118. if now >= exp:
  119. raise JWTDecodeError("JWT Expired")
  120. # The "nbf" (not before) claim identifies the time before which the JWT
  121. # MUST NOT be accepted for processing.
  122. if "nbf" in payload and do_time_check:
  123. try:
  124. nbf = get_time_from_int(payload["nbf"])
  125. except TypeError:
  126. raise JWTDecodeError('Invalid "Not valid yet" value')
  127. if now < nbf:
  128. raise JWTDecodeError("JWT Not valid yet")
  129. return payload