You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

159 lines
5.2 KiB

  1. from __future__ import annotations
  2. import base64
  3. import binascii
  4. from ..datastructures import Headers, MultipleValuesError
  5. from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
  6. from ..headers import parse_connection, parse_upgrade
  7. from ..typing import ConnectionOption, UpgradeProtocol
  8. from ..utils import accept_key as accept, generate_key
# Public API of this module: client-side helpers (build_request to create the
# request headers, check_response to validate the server's reply) and
# server-side helpers (check_request to validate the client's request,
# build_response to create the reply headers).
__all__ = ["build_request", "check_request", "build_response", "check_response"]
  10. def build_request(headers: Headers) -> str:
  11. """
  12. Build a handshake request to send to the server.
  13. Update request headers passed in argument.
  14. Args:
  15. headers: Handshake request headers.
  16. Returns:
  17. ``key`` that must be passed to :func:`check_response`.
  18. """
  19. key = generate_key()
  20. headers["Upgrade"] = "websocket"
  21. headers["Connection"] = "Upgrade"
  22. headers["Sec-WebSocket-Key"] = key
  23. headers["Sec-WebSocket-Version"] = "13"
  24. return key
  25. def check_request(headers: Headers) -> str:
  26. """
  27. Check a handshake request received from the client.
  28. This function doesn't verify that the request is an HTTP/1.1 or higher GET
  29. request and doesn't perform ``Host`` and ``Origin`` checks. These controls
  30. are usually performed earlier in the HTTP request handling code. They're
  31. the responsibility of the caller.
  32. Args:
  33. headers: Handshake request headers.
  34. Returns:
  35. ``key`` that must be passed to :func:`build_response`.
  36. Raises:
  37. InvalidHandshake: If the handshake request is invalid.
  38. Then, the server must return a 400 Bad Request error.
  39. """
  40. connection: list[ConnectionOption] = sum(
  41. [parse_connection(value) for value in headers.get_all("Connection")], []
  42. )
  43. if not any(value.lower() == "upgrade" for value in connection):
  44. raise InvalidUpgrade("Connection", ", ".join(connection))
  45. upgrade: list[UpgradeProtocol] = sum(
  46. [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
  47. )
  48. # For compatibility with non-strict implementations, ignore case when
  49. # checking the Upgrade header. The RFC always uses "websocket", except
  50. # in section 11.2. (IANA registration) where it uses "WebSocket".
  51. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
  52. raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
  53. try:
  54. s_w_key = headers["Sec-WebSocket-Key"]
  55. except KeyError as exc:
  56. raise InvalidHeader("Sec-WebSocket-Key") from exc
  57. except MultipleValuesError as exc:
  58. raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
  59. try:
  60. raw_key = base64.b64decode(s_w_key.encode(), validate=True)
  61. except binascii.Error as exc:
  62. raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
  63. if len(raw_key) != 16:
  64. raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
  65. try:
  66. s_w_version = headers["Sec-WebSocket-Version"]
  67. except KeyError as exc:
  68. raise InvalidHeader("Sec-WebSocket-Version") from exc
  69. except MultipleValuesError as exc:
  70. raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
  71. if s_w_version != "13":
  72. raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
  73. return s_w_key
  74. def build_response(headers: Headers, key: str) -> None:
  75. """
  76. Build a handshake response to send to the client.
  77. Update response headers passed in argument.
  78. Args:
  79. headers: Handshake response headers.
  80. key: Returned by :func:`check_request`.
  81. """
  82. headers["Upgrade"] = "websocket"
  83. headers["Connection"] = "Upgrade"
  84. headers["Sec-WebSocket-Accept"] = accept(key)
  85. def check_response(headers: Headers, key: str) -> None:
  86. """
  87. Check a handshake response received from the server.
  88. This function doesn't verify that the response is an HTTP/1.1 or higher
  89. response with a 101 status code. These controls are the responsibility of
  90. the caller.
  91. Args:
  92. headers: Handshake response headers.
  93. key: Returned by :func:`build_request`.
  94. Raises:
  95. InvalidHandshake: If the handshake response is invalid.
  96. """
  97. connection: list[ConnectionOption] = sum(
  98. [parse_connection(value) for value in headers.get_all("Connection")], []
  99. )
  100. if not any(value.lower() == "upgrade" for value in connection):
  101. raise InvalidUpgrade("Connection", " ".join(connection))
  102. upgrade: list[UpgradeProtocol] = sum(
  103. [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
  104. )
  105. # For compatibility with non-strict implementations, ignore case when
  106. # checking the Upgrade header. The RFC always uses "websocket", except
  107. # in section 11.2. (IANA registration) where it uses "WebSocket".
  108. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
  109. raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
  110. try:
  111. s_w_accept = headers["Sec-WebSocket-Accept"]
  112. except KeyError as exc:
  113. raise InvalidHeader("Sec-WebSocket-Accept") from exc
  114. except MultipleValuesError as exc:
  115. raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
  116. if s_w_accept != accept(key):
  117. raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)