You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

161 lines
5.3 KiB

  1. """
  2. :mod:`websockets.auth` provides HTTP Basic Authentication according to
  3. :rfc:`7235` and :rfc:`7617`.
  4. """
import functools
import hmac
import http
from typing import (
    Any,
    Awaitable,
    Callable,
    Iterable,
    Iterator,
    Optional,
    Tuple,
    Type,
    Union,
)

from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol
  12. __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
  13. Credentials = Tuple[str, str]
  14. def is_credentials(value: Any) -> bool:
  15. try:
  16. username, password = value
  17. except (TypeError, ValueError):
  18. return False
  19. else:
  20. return isinstance(username, str) and isinstance(password, str)
  21. class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
  22. """
  23. WebSocket server protocol that enforces HTTP Basic Auth.
  24. """
  25. def __init__(
  26. self,
  27. *args: Any,
  28. realm: str,
  29. check_credentials: Callable[[str, str], Awaitable[bool]],
  30. **kwargs: Any,
  31. ) -> None:
  32. self.realm = realm
  33. self.check_credentials = check_credentials
  34. super().__init__(*args, **kwargs)
  35. async def process_request(
  36. self, path: str, request_headers: Headers
  37. ) -> Optional[HTTPResponse]:
  38. """
  39. Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed.
  40. If authentication succeeds, the username of the authenticated user is
  41. stored in the ``username`` attribute.
  42. """
  43. try:
  44. authorization = request_headers["Authorization"]
  45. except KeyError:
  46. return (
  47. http.HTTPStatus.UNAUTHORIZED,
  48. [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
  49. b"Missing credentials\n",
  50. )
  51. try:
  52. username, password = parse_authorization_basic(authorization)
  53. except InvalidHeader:
  54. return (
  55. http.HTTPStatus.UNAUTHORIZED,
  56. [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
  57. b"Unsupported credentials\n",
  58. )
  59. if not await self.check_credentials(username, password):
  60. return (
  61. http.HTTPStatus.UNAUTHORIZED,
  62. [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
  63. b"Invalid credentials\n",
  64. )
  65. self.username = username
  66. return await super().process_request(path, request_headers)
  67. def basic_auth_protocol_factory(
  68. realm: str,
  69. credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
  70. check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
  71. create_protocol: Type[
  72. BasicAuthWebSocketServerProtocol
  73. ] = BasicAuthWebSocketServerProtocol,
  74. ) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
  75. """
  76. Protocol factory that enforces HTTP Basic Auth.
  77. ``basic_auth_protocol_factory`` is designed to integrate with
  78. :func:`~websockets.server.serve` like this::
  79. websockets.serve(
  80. ...,
  81. create_protocol=websockets.basic_auth_protocol_factory(
  82. realm="my dev server",
  83. credentials=("hello", "iloveyou"),
  84. )
  85. )
  86. ``realm`` indicates the scope of protection. It should contain only ASCII
  87. characters because the encoding of non-ASCII characters is undefined.
  88. Refer to section 2.2 of :rfc:`7235` for details.
  89. ``credentials`` defines hard coded authorized credentials. It can be a
  90. ``(username, password)`` pair or a list of such pairs.
  91. ``check_credentials`` defines a coroutine that checks whether credentials
  92. are authorized. This coroutine receives ``username`` and ``password``
  93. arguments and returns a :class:`bool`.
  94. One of ``credentials`` or ``check_credentials`` must be provided but not
  95. both.
  96. By default, ``basic_auth_protocol_factory`` creates a factory for building
  97. :class:`BasicAuthWebSocketServerProtocol` instances. You can override this
  98. with the ``create_protocol`` parameter.
  99. :param realm: scope of protection
  100. :param credentials: hard coded credentials
  101. :param check_credentials: coroutine that verifies credentials
  102. :raises TypeError: if the credentials argument has the wrong type
  103. """
  104. if (credentials is None) == (check_credentials is None):
  105. raise TypeError("provide either credentials or check_credentials")
  106. if credentials is not None:
  107. if is_credentials(credentials):
  108. async def check_credentials(username: str, password: str) -> bool:
  109. return (username, password) == credentials
  110. elif isinstance(credentials, Iterable):
  111. credentials_list = list(credentials)
  112. if all(is_credentials(item) for item in credentials_list):
  113. credentials_dict = dict(credentials_list)
  114. async def check_credentials(username: str, password: str) -> bool:
  115. return credentials_dict.get(username) == password
  116. else:
  117. raise TypeError(f"invalid credentials argument: {credentials}")
  118. else:
  119. raise TypeError(f"invalid credentials argument: {credentials}")
  120. return functools.partial(
  121. create_protocol, realm=realm, check_credentials=check_credentials
  122. )