You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

587 lines
21 KiB

  1. """
  2. :mod:`websockets.client` defines the WebSocket client APIs.
  3. """
  4. import asyncio
  5. import collections.abc
  6. import functools
  7. import logging
  8. import warnings
  9. from types import TracebackType
  10. from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast
  11. from .exceptions import (
  12. InvalidHandshake,
  13. InvalidHeader,
  14. InvalidMessage,
  15. InvalidStatusCode,
  16. NegotiationError,
  17. RedirectHandshake,
  18. SecurityError,
  19. )
  20. from .extensions.base import ClientExtensionFactory, Extension
  21. from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
  22. from .handshake import build_request, check_response
  23. from .headers import (
  24. ExtensionHeader,
  25. build_authorization_basic,
  26. build_extension,
  27. build_subprotocol,
  28. parse_extension,
  29. parse_subprotocol,
  30. )
  31. from .http import USER_AGENT, Headers, HeadersLike, read_response
  32. from .protocol import WebSocketCommonProtocol
  33. from .typing import Origin, Subprotocol
  34. from .uri import WebSocketURI, parse_uri
# Names exported by a star-import of this module.
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
# Module-level logger, named after this module; used for handshake debug logs.
logger = logging.getLogger(__name__)
  37. class WebSocketClientProtocol(WebSocketCommonProtocol):
  38. """
  39. :class:`~asyncio.Protocol` subclass implementing a WebSocket client.
  40. This class inherits most of its methods from
  41. :class:`~websockets.protocol.WebSocketCommonProtocol`.
  42. """
  43. is_client = True
  44. side = "client"
  45. def __init__(
  46. self,
  47. *,
  48. origin: Optional[Origin] = None,
  49. extensions: Optional[Sequence[ClientExtensionFactory]] = None,
  50. subprotocols: Optional[Sequence[Subprotocol]] = None,
  51. extra_headers: Optional[HeadersLike] = None,
  52. **kwargs: Any,
  53. ) -> None:
  54. self.origin = origin
  55. self.available_extensions = extensions
  56. self.available_subprotocols = subprotocols
  57. self.extra_headers = extra_headers
  58. super().__init__(**kwargs)
  59. def write_http_request(self, path: str, headers: Headers) -> None:
  60. """
  61. Write request line and headers to the HTTP request.
  62. """
  63. self.path = path
  64. self.request_headers = headers
  65. logger.debug("%s > GET %s HTTP/1.1", self.side, path)
  66. logger.debug("%s > %r", self.side, headers)
  67. # Since the path and headers only contain ASCII characters,
  68. # we can keep this simple.
  69. request = f"GET {path} HTTP/1.1\r\n"
  70. request += str(headers)
  71. self.writer.write(request.encode())
  72. async def read_http_response(self) -> Tuple[int, Headers]:
  73. """
  74. Read status line and headers from the HTTP response.
  75. If the response contains a body, it may be read from ``self.reader``
  76. after this coroutine returns.
  77. :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
  78. malformed or isn't an HTTP/1.1 GET response
  79. """
  80. try:
  81. status_code, reason, headers = await read_response(self.reader)
  82. except Exception as exc:
  83. raise InvalidMessage("did not receive a valid HTTP response") from exc
  84. logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
  85. logger.debug("%s < %r", self.side, headers)
  86. self.response_headers = headers
  87. return status_code, self.response_headers
  88. @staticmethod
  89. def process_extensions(
  90. headers: Headers,
  91. available_extensions: Optional[Sequence[ClientExtensionFactory]],
  92. ) -> List[Extension]:
  93. """
  94. Handle the Sec-WebSocket-Extensions HTTP response header.
  95. Check that each extension is supported, as well as its parameters.
  96. Return the list of accepted extensions.
  97. Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
  98. connection.
  99. :rfc:`6455` leaves the rules up to the specification of each
  100. :extension.
  101. To provide this level of flexibility, for each extension accepted by
  102. the server, we check for a match with each extension available in the
  103. client configuration. If no match is found, an exception is raised.
  104. If several variants of the same extension are accepted by the server,
  105. it may be configured severel times, which won't make sense in general.
  106. Extensions must implement their own requirements. For this purpose,
  107. the list of previously accepted extensions is provided.
  108. Other requirements, for example related to mandatory extensions or the
  109. order of extensions, may be implemented by overriding this method.
  110. """
  111. accepted_extensions: List[Extension] = []
  112. header_values = headers.get_all("Sec-WebSocket-Extensions")
  113. if header_values:
  114. if available_extensions is None:
  115. raise InvalidHandshake("no extensions supported")
  116. parsed_header_values: List[ExtensionHeader] = sum(
  117. [parse_extension(header_value) for header_value in header_values], []
  118. )
  119. for name, response_params in parsed_header_values:
  120. for extension_factory in available_extensions:
  121. # Skip non-matching extensions based on their name.
  122. if extension_factory.name != name:
  123. continue
  124. # Skip non-matching extensions based on their params.
  125. try:
  126. extension = extension_factory.process_response_params(
  127. response_params, accepted_extensions
  128. )
  129. except NegotiationError:
  130. continue
  131. # Add matching extension to the final list.
  132. accepted_extensions.append(extension)
  133. # Break out of the loop once we have a match.
  134. break
  135. # If we didn't break from the loop, no extension in our list
  136. # matched what the server sent. Fail the connection.
  137. else:
  138. raise NegotiationError(
  139. f"Unsupported extension: "
  140. f"name = {name}, params = {response_params}"
  141. )
  142. return accepted_extensions
  143. @staticmethod
  144. def process_subprotocol(
  145. headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
  146. ) -> Optional[Subprotocol]:
  147. """
  148. Handle the Sec-WebSocket-Protocol HTTP response header.
  149. Check that it contains exactly one supported subprotocol.
  150. Return the selected subprotocol.
  151. """
  152. subprotocol: Optional[Subprotocol] = None
  153. header_values = headers.get_all("Sec-WebSocket-Protocol")
  154. if header_values:
  155. if available_subprotocols is None:
  156. raise InvalidHandshake("no subprotocols supported")
  157. parsed_header_values: Sequence[Subprotocol] = sum(
  158. [parse_subprotocol(header_value) for header_value in header_values], []
  159. )
  160. if len(parsed_header_values) > 1:
  161. subprotocols = ", ".join(parsed_header_values)
  162. raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
  163. subprotocol = parsed_header_values[0]
  164. if subprotocol not in available_subprotocols:
  165. raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
  166. return subprotocol
  167. async def handshake(
  168. self,
  169. wsuri: WebSocketURI,
  170. origin: Optional[Origin] = None,
  171. available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
  172. available_subprotocols: Optional[Sequence[Subprotocol]] = None,
  173. extra_headers: Optional[HeadersLike] = None,
  174. ) -> None:
  175. """
  176. Perform the client side of the opening handshake.
  177. :param origin: sets the Origin HTTP header
  178. :param available_extensions: list of supported extensions in the order
  179. in which they should be used
  180. :param available_subprotocols: list of supported subprotocols in order
  181. of decreasing preference
  182. :param extra_headers: sets additional HTTP request headers; it must be
  183. a :class:`~websockets.http.Headers` instance, a
  184. :class:`~collections.abc.Mapping`, or an iterable of ``(name,
  185. value)`` pairs
  186. :raises ~websockets.exceptions.InvalidHandshake: if the handshake
  187. fails
  188. """
  189. request_headers = Headers()
  190. if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
  191. request_headers["Host"] = wsuri.host
  192. else:
  193. request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"
  194. if wsuri.user_info:
  195. request_headers["Authorization"] = build_authorization_basic(
  196. *wsuri.user_info
  197. )
  198. if origin is not None:
  199. request_headers["Origin"] = origin
  200. key = build_request(request_headers)
  201. if available_extensions is not None:
  202. extensions_header = build_extension(
  203. [
  204. (extension_factory.name, extension_factory.get_request_params())
  205. for extension_factory in available_extensions
  206. ]
  207. )
  208. request_headers["Sec-WebSocket-Extensions"] = extensions_header
  209. if available_subprotocols is not None:
  210. protocol_header = build_subprotocol(available_subprotocols)
  211. request_headers["Sec-WebSocket-Protocol"] = protocol_header
  212. if extra_headers is not None:
  213. if isinstance(extra_headers, Headers):
  214. extra_headers = extra_headers.raw_items()
  215. elif isinstance(extra_headers, collections.abc.Mapping):
  216. extra_headers = extra_headers.items()
  217. for name, value in extra_headers:
  218. request_headers[name] = value
  219. request_headers.setdefault("User-Agent", USER_AGENT)
  220. self.write_http_request(wsuri.resource_name, request_headers)
  221. status_code, response_headers = await self.read_http_response()
  222. if status_code in (301, 302, 303, 307, 308):
  223. if "Location" not in response_headers:
  224. raise InvalidHeader("Location")
  225. raise RedirectHandshake(response_headers["Location"])
  226. elif status_code != 101:
  227. raise InvalidStatusCode(status_code)
  228. check_response(response_headers, key)
  229. self.extensions = self.process_extensions(
  230. response_headers, available_extensions
  231. )
  232. self.subprotocol = self.process_subprotocol(
  233. response_headers, available_subprotocols
  234. )
  235. self.connection_open()
  236. class Connect:
  237. """
  238. Connect to the WebSocket server at the given ``uri``.
  239. Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
  240. can then be used to send and receive messages.
  241. :func:`connect` can also be used as a asynchronous context manager. In
  242. that case, the connection is closed when exiting the context.
  243. :func:`connect` is a wrapper around the event loop's
  244. :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
  245. are passed to :meth:`~asyncio.loop.create_connection`.
  246. For example, you can set the ``ssl`` keyword argument to a
  247. :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
  248. a ``wss://`` URI, if this argument isn't provided explicitly, it's set to
  249. ``True``, which means Python's default :class:`~ssl.SSLContext` is used.
  250. You can connect to a different host and port from those found in ``uri``
  251. by setting ``host`` and ``port`` keyword arguments. This only changes the
  252. destination of the TCP connection. The host name from ``uri`` is still
  253. used in the TLS handshake for secure connections and in the ``Host`` HTTP
  254. header.
  255. The ``create_protocol`` parameter allows customizing the
  256. :class:`~asyncio.Protocol` that manages the connection. It should be a
  257. callable or class accepting the same arguments as
  258. :class:`WebSocketClientProtocol` and returning an instance of
  259. :class:`WebSocketClientProtocol` or a subclass. It defaults to
  260. :class:`WebSocketClientProtocol`.
  261. The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
  262. ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
  263. described in :class:`~websockets.protocol.WebSocketCommonProtocol`.
  264. :func:`connect` also accepts the following optional arguments:
  265. * ``compression`` is a shortcut to configure compression extensions;
  266. by default it enables the "permessage-deflate" extension; set it to
  267. ``None`` to disable compression
  268. * ``origin`` sets the Origin HTTP header
  269. * ``extensions`` is a list of supported extensions in order of
  270. decreasing preference
  271. * ``subprotocols`` is a list of supported subprotocols in order of
  272. decreasing preference
  273. * ``extra_headers`` sets additional HTTP request headers; it can be a
  274. :class:`~websockets.http.Headers` instance, a
  275. :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
  276. pairs
  277. :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
  278. :raises ~websockets.handshake.InvalidHandshake: if the opening handshake
  279. fails
  280. """
  281. MAX_REDIRECTS_ALLOWED = 10
  282. def __init__(
  283. self,
  284. uri: str,
  285. *,
  286. path: Optional[str] = None,
  287. create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
  288. ping_interval: float = 20,
  289. ping_timeout: float = 20,
  290. close_timeout: Optional[float] = None,
  291. max_size: int = 2 ** 20,
  292. max_queue: int = 2 ** 5,
  293. read_limit: int = 2 ** 16,
  294. write_limit: int = 2 ** 16,
  295. loop: Optional[asyncio.AbstractEventLoop] = None,
  296. legacy_recv: bool = False,
  297. klass: Optional[Type[WebSocketClientProtocol]] = None,
  298. timeout: Optional[float] = None,
  299. compression: Optional[str] = "deflate",
  300. origin: Optional[Origin] = None,
  301. extensions: Optional[Sequence[ClientExtensionFactory]] = None,
  302. subprotocols: Optional[Sequence[Subprotocol]] = None,
  303. extra_headers: Optional[HeadersLike] = None,
  304. **kwargs: Any,
  305. ) -> None:
  306. # Backwards compatibility: close_timeout used to be called timeout.
  307. if timeout is None:
  308. timeout = 10
  309. else:
  310. warnings.warn("rename timeout to close_timeout", DeprecationWarning)
  311. # If both are specified, timeout is ignored.
  312. if close_timeout is None:
  313. close_timeout = timeout
  314. # Backwards compatibility: create_protocol used to be called klass.
  315. if klass is None:
  316. klass = WebSocketClientProtocol
  317. else:
  318. warnings.warn("rename klass to create_protocol", DeprecationWarning)
  319. # If both are specified, klass is ignored.
  320. if create_protocol is None:
  321. create_protocol = klass
  322. if loop is None:
  323. loop = asyncio.get_event_loop()
  324. wsuri = parse_uri(uri)
  325. if wsuri.secure:
  326. kwargs.setdefault("ssl", True)
  327. elif kwargs.get("ssl") is not None:
  328. raise ValueError(
  329. "connect() received a ssl argument for a ws:// URI, "
  330. "use a wss:// URI to enable TLS"
  331. )
  332. if compression == "deflate":
  333. if extensions is None:
  334. extensions = []
  335. if not any(
  336. extension_factory.name == ClientPerMessageDeflateFactory.name
  337. for extension_factory in extensions
  338. ):
  339. extensions = list(extensions) + [
  340. ClientPerMessageDeflateFactory(client_max_window_bits=True)
  341. ]
  342. elif compression is not None:
  343. raise ValueError(f"unsupported compression: {compression}")
  344. factory = functools.partial(
  345. create_protocol,
  346. ping_interval=ping_interval,
  347. ping_timeout=ping_timeout,
  348. close_timeout=close_timeout,
  349. max_size=max_size,
  350. max_queue=max_queue,
  351. read_limit=read_limit,
  352. write_limit=write_limit,
  353. loop=loop,
  354. host=wsuri.host,
  355. port=wsuri.port,
  356. secure=wsuri.secure,
  357. legacy_recv=legacy_recv,
  358. origin=origin,
  359. extensions=extensions,
  360. subprotocols=subprotocols,
  361. extra_headers=extra_headers,
  362. )
  363. if path is None:
  364. host: Optional[str]
  365. port: Optional[int]
  366. if kwargs.get("sock") is None:
  367. host, port = wsuri.host, wsuri.port
  368. else:
  369. # If sock is given, host and port shouldn't be specified.
  370. host, port = None, None
  371. # If host and port are given, override values from the URI.
  372. host = kwargs.pop("host", host)
  373. port = kwargs.pop("port", port)
  374. create_connection = functools.partial(
  375. loop.create_connection, factory, host, port, **kwargs
  376. )
  377. else:
  378. create_connection = functools.partial(
  379. loop.create_unix_connection, factory, path, **kwargs
  380. )
  381. # This is a coroutine function.
  382. self._create_connection = create_connection
  383. self._wsuri = wsuri
  384. self._origin = origin
  385. def handle_redirect(self, uri: str) -> None:
  386. # Update the state of this instance to connect to a new URI.
  387. old_wsuri = self._wsuri
  388. new_wsuri = parse_uri(uri)
  389. # Forbid TLS downgrade.
  390. if old_wsuri.secure and not new_wsuri.secure:
  391. raise SecurityError("redirect from WSS to WS")
  392. same_origin = (
  393. old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
  394. )
  395. # Rewrite the host and port arguments for cross-origin redirects.
  396. # This preserves connection overrides with the host and port
  397. # arguments if the redirect points to the same host and port.
  398. if not same_origin:
  399. # Replace the host and port argument passed to the protocol factory.
  400. factory = self._create_connection.args[0]
  401. factory = functools.partial(
  402. factory.func,
  403. *factory.args,
  404. **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
  405. )
  406. # Replace the host and port argument passed to create_connection.
  407. self._create_connection = functools.partial(
  408. self._create_connection.func,
  409. *(factory, new_wsuri.host, new_wsuri.port),
  410. **self._create_connection.keywords,
  411. )
  412. # Set the new WebSocket URI. This suffices for same-origin redirects.
  413. self._wsuri = new_wsuri
  414. # async with connect(...)
  415. async def __aenter__(self) -> WebSocketClientProtocol:
  416. return await self
  417. async def __aexit__(
  418. self,
  419. exc_type: Optional[Type[BaseException]],
  420. exc_value: Optional[BaseException],
  421. traceback: Optional[TracebackType],
  422. ) -> None:
  423. await self.ws_client.close()
  424. # await connect(...)
  425. def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
  426. # Create a suitable iterator by calling __await__ on a coroutine.
  427. return self.__await_impl__().__await__()
  428. async def __await_impl__(self) -> WebSocketClientProtocol:
  429. for redirects in range(self.MAX_REDIRECTS_ALLOWED):
  430. transport, protocol = await self._create_connection()
  431. # https://github.com/python/typeshed/pull/2756
  432. transport = cast(asyncio.Transport, transport)
  433. protocol = cast(WebSocketClientProtocol, protocol)
  434. try:
  435. try:
  436. await protocol.handshake(
  437. self._wsuri,
  438. origin=self._origin,
  439. available_extensions=protocol.available_extensions,
  440. available_subprotocols=protocol.available_subprotocols,
  441. extra_headers=protocol.extra_headers,
  442. )
  443. except Exception:
  444. protocol.fail_connection()
  445. await protocol.wait_closed()
  446. raise
  447. else:
  448. self.ws_client = protocol
  449. return protocol
  450. except RedirectHandshake as exc:
  451. self.handle_redirect(exc.uri)
  452. else:
  453. raise SecurityError("too many redirects")
  454. # yield from connect(...)
  455. __iter__ = __await__
# Function-style public alias for the Connect class.
connect = Connect
  457. def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect:
  458. """
  459. Similar to :func:`connect`, but for connecting to a Unix socket.
  460. This function calls the event loop's
  461. :meth:`~asyncio.loop.create_unix_connection` method.
  462. It is only available on Unix.
  463. It's mainly useful for debugging servers listening on Unix sockets.
  464. :param path: file system path to the Unix socket
  465. :param uri: WebSocket URI
  466. """
  467. return connect(uri=uri, path=path, **kwargs)