You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

166 lines
4.7 KiB

  1. import base64
  2. import re
  3. import struct
  4. # Piggyback of the backends implementation of the function that converts a long
  5. # to a bytes stream. Some plumbing is necessary to have the signatures match.
  6. try:
  7. from cryptography.utils import int_to_bytes as _long_to_bytes
  8. def long_to_bytes(n, blocksize=0):
  9. return _long_to_bytes(n, blocksize or None)
  10. except ImportError:
  11. from ecdsa.ecdsa import int_to_string as _long_to_bytes
  12. def long_to_bytes(n, blocksize=0):
  13. ret = _long_to_bytes(n)
  14. if blocksize == 0:
  15. return ret
  16. else:
  17. assert len(ret) <= blocksize
  18. padding = blocksize - len(ret)
  19. return b"\x00" * padding + ret
  20. def long_to_base64(data, size=0):
  21. return base64.urlsafe_b64encode(long_to_bytes(data, size)).strip(b"=")
  22. def int_arr_to_long(arr):
  23. return int("".join(["%02x" % byte for byte in arr]), 16)
  24. def base64_to_long(data):
  25. if isinstance(data, str):
  26. data = data.encode("ascii")
  27. # urlsafe_b64decode will happily convert b64encoded data
  28. _d = base64.urlsafe_b64decode(bytes(data) + b"==")
  29. return int_arr_to_long(struct.unpack("%sB" % len(_d), _d))
  30. def calculate_at_hash(access_token, hash_alg):
  31. """Helper method for calculating an access token
  32. hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
  33. Its value is the base64url encoding of the left-most half of the hash of the octets
  34. of the ASCII representation of the access_token value, where the hash algorithm
  35. used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
  36. Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
  37. then take the left-most 128 bits and base64url encode them. The at_hash value is a
  38. case sensitive string.
  39. Args:
  40. access_token (str): An access token string.
  41. hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
  42. """
  43. hash_digest = hash_alg(access_token.encode("utf-8")).digest()
  44. cut_at = int(len(hash_digest) / 2)
  45. truncated = hash_digest[:cut_at]
  46. at_hash = base64url_encode(truncated)
  47. return at_hash.decode("utf-8")
  48. def base64url_decode(input):
  49. """Helper method to base64url_decode a string.
  50. Args:
  51. input (bytes): A base64url_encoded string (bytes) to decode.
  52. """
  53. rem = len(input) % 4
  54. if rem > 0:
  55. input += b"=" * (4 - rem)
  56. return base64.urlsafe_b64decode(input)
  57. def base64url_encode(input):
  58. """Helper method to base64url_encode a string.
  59. Args:
  60. input (bytes): A base64url_encoded string (bytes) to encode.
  61. """
  62. return base64.urlsafe_b64encode(input).replace(b"=", b"")
  63. def timedelta_total_seconds(delta):
  64. """Helper method to determine the total number of seconds
  65. from a timedelta.
  66. Args:
  67. delta (timedelta): A timedelta to convert to seconds.
  68. """
  69. return delta.days * 24 * 60 * 60 + delta.seconds
  70. def ensure_binary(s):
  71. """Coerce **s** to bytes."""
  72. if isinstance(s, bytes):
  73. return s
  74. if isinstance(s, str):
  75. return s.encode("utf-8", "strict")
  76. raise TypeError(f"not expecting type '{type(s)}'")
  77. # The following was copied from PyJWT:
  78. # https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc
  79. # Based on:
  80. # https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
  81. _PEMS = {
  82. b"CERTIFICATE",
  83. b"TRUSTED CERTIFICATE",
  84. b"PRIVATE KEY",
  85. b"PUBLIC KEY",
  86. b"ENCRYPTED PRIVATE KEY",
  87. b"OPENSSH PRIVATE KEY",
  88. b"DSA PRIVATE KEY",
  89. b"RSA PRIVATE KEY",
  90. b"RSA PUBLIC KEY",
  91. b"EC PRIVATE KEY",
  92. b"DH PARAMETERS",
  93. b"NEW CERTIFICATE REQUEST",
  94. b"CERTIFICATE REQUEST",
  95. b"SSH2 PUBLIC KEY",
  96. b"SSH2 ENCRYPTED PRIVATE KEY",
  97. b"X509 CRL",
  98. }
  99. _PEM_RE = re.compile(
  100. b"----[- ]BEGIN (" + b"|".join(re.escape(pem) for pem in _PEMS) + b")[- ]----",
  101. )
  102. def is_pem_format(key: bytes) -> bool:
  103. return bool(_PEM_RE.search(key))
  104. # Based on
  105. # https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b
  106. # /src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
  107. _CERT_SUFFIX = b"-cert-v01@openssh.com"
  108. _SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
  109. _SSH_KEY_FORMATS = [
  110. b"ssh-ed25519",
  111. b"ssh-rsa",
  112. b"ssh-dss",
  113. b"ecdsa-sha2-nistp256",
  114. b"ecdsa-sha2-nistp384",
  115. b"ecdsa-sha2-nistp521",
  116. ]
  117. def is_ssh_key(key: bytes) -> bool:
  118. if any(string_value in key for string_value in _SSH_KEY_FORMATS):
  119. return True
  120. ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
  121. if ssh_pubkey_match:
  122. key_type = ssh_pubkey_match.group(1)
  123. if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
  124. return True
  125. return False