# -*- coding: utf-8 -*-
#
# Copyright 2017 Gehirn Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import hmac
from typing import Callable

from cryptography.hazmat.primitives.hashes import (
    SHA256,
    SHA384,
    SHA512,
)

from .exceptions import InvalidKeyTypeError
from .jwk import AbstractJWKBase


def std_hash_by_alg(alg: str) -> Callable[[bytes], object]:
    """Return the ``hashlib`` constructor matching the suffix of *alg*.

    :param alg: an algorithm name ending in ``S256``, ``S384`` or ``S512``
        (for example ``'HS256'`` or ``'RS512'``).
    :returns: ``hashlib.sha256``, ``hashlib.sha384`` or ``hashlib.sha512``.
    :raises ValueError: if *alg* ends with none of the known suffixes.
    """
    if alg.endswith('S256'):
        return hashlib.sha256
    if alg.endswith('S384'):
        return hashlib.sha384
    if alg.endswith('S512'):
        # hashlib only exposes the lower-case constructor; ``hashlib.SHA512``
        # does not exist and would raise AttributeError.
        return hashlib.sha512
    raise ValueError('{} is not supported'.format(alg))


class AbstractSigningAlgorithm:
    """Interface shared by the signing algorithms defined in this module."""

    def sign(self, message: bytes, key: AbstractJWKBase) -> bytes:
        """Return the signature of *message* made with *key*."""
        raise NotImplementedError()  # pragma: no cover

    def verify(self, message: bytes, key: AbstractJWKBase,
               signature: bytes) -> bool:
        """Return whether *signature* is valid for *message* under *key*."""
        raise NotImplementedError()  # pragma: no cover


class NoneAlgorithm(AbstractSigningAlgorithm):
    """The ``none`` algorithm: the signature is always the empty string.

    *key* is accepted for interface compatibility but never used.
    """

    def sign(self, message: bytes, key: AbstractJWKBase) -> bytes:
        """Return an empty signature regardless of *message* and *key*."""
        return b''

    def verify(self, message: bytes, key: AbstractJWKBase,
               signature: bytes) -> bool:
        """Return ``True`` only when *signature* is empty."""
        return hmac.compare_digest(signature, b'')


none = NoneAlgorithm()


class HMACAlgorithm(AbstractSigningAlgorithm):
    """HMAC based signing using the digest constructor given at creation."""

    def __init__(self, hash_fun: Callable) -> None:
        """Store *hash_fun*, which is later passed to ``hmac.new``."""
        self.hash_fun = hash_fun

    def _check_key(self, key: AbstractJWKBase) -> None:
        """Raise ``InvalidKeyTypeError`` unless *key* has key type ``oct``."""
        if key.get_kty() != 'oct':
            raise InvalidKeyTypeError((
                'an octet key is required, but passed is {}'
            ).format(key.get_kty()))

    def _sign(self, message: bytes, key: bytes) -> bytes:
        """Return the HMAC digest of *message* keyed with the raw *key*."""
        return hmac.new(key, message, self.hash_fun).digest()

    def sign(self, message: bytes, key: AbstractJWKBase) -> bytes:
        """Validate the key type, then delegate to ``key.sign``.

        :raises InvalidKeyTypeError: if *key* is not an ``oct`` key.
        """
        self._check_key(key)
        return key.sign(message, signer=self._sign)

    def verify(self, message: bytes, key: AbstractJWKBase,
               signature: bytes) -> bool:
        """Validate the key type, then delegate to ``key.verify``.

        :raises InvalidKeyTypeError: if *key* is not an ``oct`` key.
        """
        self._check_key(key)
        return key.verify(message, signature, signer=self._sign)


HS256 = HMACAlgorithm(hashlib.sha256)
HS384 = HMACAlgorithm(hashlib.sha384)
HS512 = HMACAlgorithm(hashlib.sha512)


class RSAAlgorithm(AbstractSigningAlgorithm):
    """RSA based signing; the hash object is forwarded to the key."""

    def __init__(self, hash_fun: object) -> None:
        """Store *hash_fun*, later passed as ``hash_fun=`` to the key."""
        self.hash_fun = hash_fun

    def _check_key(self, key: AbstractJWKBase, must_sign_key=False) -> None:
        """Ensure *key* is an RSA key usable for the requested operation.

        :param must_sign_key: when true, additionally require that
            ``key.is_sign_key()`` is truthy.
        :raises InvalidKeyTypeError: if the key type is not ``RSA``, or a
            signing key is required and *key* is not one.
        """
        if key.get_kty() != 'RSA':
            raise InvalidKeyTypeError((
                'a RSA key is required, but passed is {}'
            ).format(key.get_kty()))
        if must_sign_key and not key.is_sign_key():
            raise InvalidKeyTypeError(
                'a RSA private key is required, but passed is RSA public key')

    def sign(self, message: bytes, key: AbstractJWKBase) -> bytes:
        """Validate that *key* can sign, then delegate to ``key.sign``.

        :raises InvalidKeyTypeError: if *key* is not an RSA signing key.
        """
        self._check_key(key, must_sign_key=True)
        return key.sign(message, hash_fun=self.hash_fun)

    def verify(self, message: bytes, key: AbstractJWKBase,
               signature: bytes) -> bool:
        """Validate the key type, then delegate to ``key.verify``.

        :raises InvalidKeyTypeError: if *key* is not an RSA key.
        """
        self._check_key(key)
        return key.verify(message, signature, hash_fun=self.hash_fun)


RS256 = RSAAlgorithm(SHA256)
RS384 = RSAAlgorithm(SHA384)
RS512 = RSAAlgorithm(SHA512)


def supported_signing_algorithms():
    """Return a new dict mapping algorithm names to algorithm instances."""
    # NOTE(yosida95): exclude vulnerable 'none' algorithm by default.
    return {
        'HS256': HS256,
        'HS384': HS384,
        'HS512': HS512,
        'RS256': RS256,
        'RS384': RS384,
        'RS512': RS512,
    }