You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

390 lines
13 KiB

  1. from __future__ import annotations
  2. import os
  3. import random
  4. import warnings
  5. from collections.abc import Generator, Sequence
  6. from typing import Any
  7. from .datastructures import Headers, MultipleValuesError
  8. from .exceptions import (
  9. InvalidHandshake,
  10. InvalidHeader,
  11. InvalidHeaderValue,
  12. InvalidMessage,
  13. InvalidStatus,
  14. InvalidUpgrade,
  15. NegotiationError,
  16. )
  17. from .extensions import ClientExtensionFactory, Extension
  18. from .headers import (
  19. build_authorization_basic,
  20. build_extension,
  21. build_host,
  22. build_subprotocol,
  23. parse_connection,
  24. parse_extension,
  25. parse_subprotocol,
  26. parse_upgrade,
  27. )
  28. from .http11 import Request, Response
  29. from .imports import lazy_import
  30. from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
  31. from .typing import (
  32. ConnectionOption,
  33. ExtensionHeader,
  34. LoggerLike,
  35. Origin,
  36. Subprotocol,
  37. UpgradeProtocol,
  38. )
  39. from .uri import WebSocketURI
  40. from .utils import accept_key, generate_key
  41. __all__ = ["ClientProtocol"]
  42. class ClientProtocol(Protocol):
  43. """
  44. Sans-I/O implementation of a WebSocket client connection.
  45. Args:
  46. uri: URI of the WebSocket server, parsed
  47. with :func:`~websockets.uri.parse_uri`.
  48. origin: Value of the ``Origin`` header. This is useful when connecting
  49. to a server that validates the ``Origin`` header to defend against
  50. Cross-Site WebSocket Hijacking attacks.
  51. extensions: List of supported extensions, in order in which they
  52. should be tried.
  53. subprotocols: List of supported subprotocols, in order of decreasing
  54. preference.
  55. state: Initial state of the WebSocket connection.
  56. max_size: Maximum size of incoming messages in bytes;
  57. :obj:`None` disables the limit.
  58. logger: Logger for this connection;
  59. defaults to ``logging.getLogger("websockets.client")``;
  60. see the :doc:`logging guide <../../topics/logging>` for details.
  61. """
  62. def __init__(
  63. self,
  64. uri: WebSocketURI,
  65. *,
  66. origin: Origin | None = None,
  67. extensions: Sequence[ClientExtensionFactory] | None = None,
  68. subprotocols: Sequence[Subprotocol] | None = None,
  69. state: State = CONNECTING,
  70. max_size: int | None = 2**20,
  71. logger: LoggerLike | None = None,
  72. ) -> None:
  73. super().__init__(
  74. side=CLIENT,
  75. state=state,
  76. max_size=max_size,
  77. logger=logger,
  78. )
  79. self.uri = uri
  80. self.origin = origin
  81. self.available_extensions = extensions
  82. self.available_subprotocols = subprotocols
  83. self.key = generate_key()
  84. def connect(self) -> Request:
  85. """
  86. Create a handshake request to open a connection.
  87. You must send the handshake request with :meth:`send_request`.
  88. You can modify it before sending it, for example to add HTTP headers.
  89. Returns:
  90. WebSocket handshake request event to send to the server.
  91. """
  92. headers = Headers()
  93. headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure)
  94. if self.uri.user_info:
  95. headers["Authorization"] = build_authorization_basic(*self.uri.user_info)
  96. if self.origin is not None:
  97. headers["Origin"] = self.origin
  98. headers["Upgrade"] = "websocket"
  99. headers["Connection"] = "Upgrade"
  100. headers["Sec-WebSocket-Key"] = self.key
  101. headers["Sec-WebSocket-Version"] = "13"
  102. if self.available_extensions is not None:
  103. headers["Sec-WebSocket-Extensions"] = build_extension(
  104. [
  105. (extension_factory.name, extension_factory.get_request_params())
  106. for extension_factory in self.available_extensions
  107. ]
  108. )
  109. if self.available_subprotocols is not None:
  110. headers["Sec-WebSocket-Protocol"] = build_subprotocol(
  111. self.available_subprotocols
  112. )
  113. return Request(self.uri.resource_name, headers)
  114. def process_response(self, response: Response) -> None:
  115. """
  116. Check a handshake response.
  117. Args:
  118. request: WebSocket handshake response received from the server.
  119. Raises:
  120. InvalidHandshake: If the handshake response is invalid.
  121. """
  122. if response.status_code != 101:
  123. raise InvalidStatus(response)
  124. headers = response.headers
  125. connection: list[ConnectionOption] = sum(
  126. [parse_connection(value) for value in headers.get_all("Connection")], []
  127. )
  128. if not any(value.lower() == "upgrade" for value in connection):
  129. raise InvalidUpgrade(
  130. "Connection", ", ".join(connection) if connection else None
  131. )
  132. upgrade: list[UpgradeProtocol] = sum(
  133. [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
  134. )
  135. # For compatibility with non-strict implementations, ignore case when
  136. # checking the Upgrade header. It's supposed to be 'WebSocket'.
  137. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
  138. raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
  139. try:
  140. s_w_accept = headers["Sec-WebSocket-Accept"]
  141. except KeyError:
  142. raise InvalidHeader("Sec-WebSocket-Accept") from None
  143. except MultipleValuesError:
  144. raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None
  145. if s_w_accept != accept_key(self.key):
  146. raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
  147. self.extensions = self.process_extensions(headers)
  148. self.subprotocol = self.process_subprotocol(headers)
  149. def process_extensions(self, headers: Headers) -> list[Extension]:
  150. """
  151. Handle the Sec-WebSocket-Extensions HTTP response header.
  152. Check that each extension is supported, as well as its parameters.
  153. :rfc:`6455` leaves the rules up to the specification of each
  154. extension.
  155. To provide this level of flexibility, for each extension accepted by
  156. the server, we check for a match with each extension available in the
  157. client configuration. If no match is found, an exception is raised.
  158. If several variants of the same extension are accepted by the server,
  159. it may be configured several times, which won't make sense in general.
  160. Extensions must implement their own requirements. For this purpose,
  161. the list of previously accepted extensions is provided.
  162. Other requirements, for example related to mandatory extensions or the
  163. order of extensions, may be implemented by overriding this method.
  164. Args:
  165. headers: WebSocket handshake response headers.
  166. Returns:
  167. List of accepted extensions.
  168. Raises:
  169. InvalidHandshake: To abort the handshake.
  170. """
  171. accepted_extensions: list[Extension] = []
  172. extensions = headers.get_all("Sec-WebSocket-Extensions")
  173. if extensions:
  174. if self.available_extensions is None:
  175. raise NegotiationError("no extensions supported")
  176. parsed_extensions: list[ExtensionHeader] = sum(
  177. [parse_extension(header_value) for header_value in extensions], []
  178. )
  179. for name, response_params in parsed_extensions:
  180. for extension_factory in self.available_extensions:
  181. # Skip non-matching extensions based on their name.
  182. if extension_factory.name != name:
  183. continue
  184. # Skip non-matching extensions based on their params.
  185. try:
  186. extension = extension_factory.process_response_params(
  187. response_params, accepted_extensions
  188. )
  189. except NegotiationError:
  190. continue
  191. # Add matching extension to the final list.
  192. accepted_extensions.append(extension)
  193. # Break out of the loop once we have a match.
  194. break
  195. # If we didn't break from the loop, no extension in our list
  196. # matched what the server sent. Fail the connection.
  197. else:
  198. raise NegotiationError(
  199. f"Unsupported extension: "
  200. f"name = {name}, params = {response_params}"
  201. )
  202. return accepted_extensions
  203. def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
  204. """
  205. Handle the Sec-WebSocket-Protocol HTTP response header.
  206. If provided, check that it contains exactly one supported subprotocol.
  207. Args:
  208. headers: WebSocket handshake response headers.
  209. Returns:
  210. Subprotocol, if one was selected.
  211. """
  212. subprotocol: Subprotocol | None = None
  213. subprotocols = headers.get_all("Sec-WebSocket-Protocol")
  214. if subprotocols:
  215. if self.available_subprotocols is None:
  216. raise NegotiationError("no subprotocols supported")
  217. parsed_subprotocols: Sequence[Subprotocol] = sum(
  218. [parse_subprotocol(header_value) for header_value in subprotocols], []
  219. )
  220. if len(parsed_subprotocols) > 1:
  221. raise InvalidHeader(
  222. "Sec-WebSocket-Protocol",
  223. f"multiple values: {', '.join(parsed_subprotocols)}",
  224. )
  225. subprotocol = parsed_subprotocols[0]
  226. if subprotocol not in self.available_subprotocols:
  227. raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
  228. return subprotocol
  229. def send_request(self, request: Request) -> None:
  230. """
  231. Send a handshake request to the server.
  232. Args:
  233. request: WebSocket handshake request event.
  234. """
  235. if self.debug:
  236. self.logger.debug("> GET %s HTTP/1.1", request.path)
  237. for key, value in request.headers.raw_items():
  238. self.logger.debug("> %s: %s", key, value)
  239. self.writes.append(request.serialize())
  240. def parse(self) -> Generator[None]:
  241. if self.state is CONNECTING:
  242. try:
  243. response = yield from Response.parse(
  244. self.reader.read_line,
  245. self.reader.read_exact,
  246. self.reader.read_to_eof,
  247. )
  248. except Exception as exc:
  249. self.handshake_exc = InvalidMessage(
  250. "did not receive a valid HTTP response"
  251. )
  252. self.handshake_exc.__cause__ = exc
  253. self.send_eof()
  254. self.parser = self.discard()
  255. next(self.parser) # start coroutine
  256. yield
  257. if self.debug:
  258. code, phrase = response.status_code, response.reason_phrase
  259. self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
  260. for key, value in response.headers.raw_items():
  261. self.logger.debug("< %s: %s", key, value)
  262. if response.body:
  263. self.logger.debug("< [body] (%d bytes)", len(response.body))
  264. try:
  265. self.process_response(response)
  266. except InvalidHandshake as exc:
  267. response._exception = exc
  268. self.events.append(response)
  269. self.handshake_exc = exc
  270. self.send_eof()
  271. self.parser = self.discard()
  272. next(self.parser) # start coroutine
  273. yield
  274. assert self.state is CONNECTING
  275. self.state = OPEN
  276. self.events.append(response)
  277. yield from super().parse()
  278. class ClientConnection(ClientProtocol):
  279. def __init__(self, *args: Any, **kwargs: Any) -> None:
  280. warnings.warn( # deprecated in 11.0 - 2023-04-02
  281. "ClientConnection was renamed to ClientProtocol",
  282. DeprecationWarning,
  283. )
  284. super().__init__(*args, **kwargs)
  285. BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
  286. BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
  287. BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
  288. BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
  289. def backoff(
  290. initial_delay: float = BACKOFF_INITIAL_DELAY,
  291. min_delay: float = BACKOFF_MIN_DELAY,
  292. max_delay: float = BACKOFF_MAX_DELAY,
  293. factor: float = BACKOFF_FACTOR,
  294. ) -> Generator[float]:
  295. """
  296. Generate a series of backoff delays between reconnection attempts.
  297. Yields:
  298. How many seconds to wait before retrying to connect.
  299. """
  300. # Add a random initial delay between 0 and 5 seconds.
  301. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
  302. yield random.random() * initial_delay
  303. delay = min_delay
  304. while delay < max_delay:
  305. yield delay
  306. delay *= factor
  307. while True:
  308. yield max_delay
# Declare the names below as deprecated aliases whose target is the
# ``.legacy.client`` module, passing this module's globals to ``lazy_import``.
# NOTE(review): how and when the aliases are resolved (and whether a warning
# is emitted on access) is decided by ``lazy_import`` in ``.imports``, which
# is not visible here -- confirm there.
lazy_import(
    globals(),
    deprecated_aliases={
        # deprecated in 14.0 - 2024-11-09
        "WebSocketClientProtocol": ".legacy.client",
        "connect": ".legacy.client",
        "unix_connect": ".legacy.client",
    },
)