Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.
 
 
 
 

746 wierszy
27 KiB

  1. from __future__ import annotations
  2. import contextlib
  3. import inspect
  4. import io
  5. import json
  6. import math
  7. import sys
  8. import warnings
  9. from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence
  10. from concurrent.futures import Future
  11. from contextlib import AbstractContextManager
  12. from types import GeneratorType
  13. from typing import (
  14. Any,
  15. Callable,
  16. Literal,
  17. TypedDict,
  18. Union,
  19. cast,
  20. )
  21. from urllib.parse import unquote, urljoin
  22. import anyio
  23. import anyio.abc
  24. import anyio.from_thread
  25. from anyio.streams.stapled import StapledObjectStream
  26. from starlette._utils import is_async_callable
  27. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  28. from starlette.websockets import WebSocketDisconnect
  29. if sys.version_info >= (3, 10): # pragma: no cover
  30. from typing import TypeGuard
  31. else: # pragma: no cover
  32. from typing_extensions import TypeGuard
  33. if sys.version_info >= (3, 11): # pragma: no cover
  34. from typing import Self
  35. else: # pragma: no cover
  36. from typing_extensions import Self
  37. try:
  38. import httpx
  39. except ModuleNotFoundError: # pragma: no cover
  40. raise RuntimeError(
  41. "The starlette.testclient module requires the httpx package to be installed.\n"
  42. "You can install this with:\n"
  43. " $ pip install httpx\n"
  44. )
  45. _PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
  46. ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
  47. ASGI2App = Callable[[Scope], ASGIInstance]
  48. ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
  49. _RequestData = Mapping[str, Union[str, Iterable[str], bytes]]
  50. def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
  51. if inspect.isclass(app):
  52. return hasattr(app, "__await__")
  53. return is_async_callable(app)
  54. class _WrapASGI2:
  55. """
  56. Provide an ASGI3 interface onto an ASGI2 app.
  57. """
  58. def __init__(self, app: ASGI2App) -> None:
  59. self.app = app
  60. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  61. instance = self.app(scope)
  62. await instance(receive, send)
  63. class _AsyncBackend(TypedDict):
  64. backend: str
  65. backend_options: dict[str, Any]
  66. class _Upgrade(Exception):
  67. def __init__(self, session: WebSocketTestSession) -> None:
  68. self.session = session
  69. class WebSocketDenialResponse( # type: ignore[misc]
  70. httpx.Response,
  71. WebSocketDisconnect,
  72. ):
  73. """
  74. A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
  75. `WebSocket` is closed before being accepted with a `send_denial_response()`.
  76. """
  77. class WebSocketTestSession:
  78. def __init__(
  79. self,
  80. app: ASGI3App,
  81. scope: Scope,
  82. portal_factory: _PortalFactoryType,
  83. ) -> None:
  84. self.app = app
  85. self.scope = scope
  86. self.accepted_subprotocol = None
  87. self.portal_factory = portal_factory
  88. self.extra_headers = None
  89. def __enter__(self) -> WebSocketTestSession:
  90. with contextlib.ExitStack() as stack:
  91. self.portal = portal = stack.enter_context(self.portal_factory())
  92. fut, cs = portal.start_task(self._run)
  93. stack.callback(fut.result)
  94. stack.callback(portal.call, cs.cancel)
  95. self.send({"type": "websocket.connect"})
  96. message = self.receive()
  97. self._raise_on_close(message)
  98. self.accepted_subprotocol = message.get("subprotocol", None)
  99. self.extra_headers = message.get("headers", None)
  100. stack.callback(self.close, 1000)
  101. self.exit_stack = stack.pop_all()
  102. return self
  103. def __exit__(self, *args: Any) -> bool | None:
  104. return self.exit_stack.__exit__(*args)
  105. async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
  106. """
  107. The sub-thread in which the websocket session runs.
  108. """
  109. send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
  110. send_tx, send_rx = send
  111. receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
  112. receive_tx, receive_rx = receive
  113. with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
  114. self._receive_tx = receive_tx
  115. self._send_rx = send_rx
  116. task_status.started(cs)
  117. await self.app(self.scope, receive_rx.receive, send_tx.send)
  118. # wait for cs.cancel to be called before closing streams
  119. await anyio.sleep_forever()
  120. def _raise_on_close(self, message: Message) -> None:
  121. if message["type"] == "websocket.close":
  122. raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
  123. elif message["type"] == "websocket.http.response.start":
  124. status_code: int = message["status"]
  125. headers: list[tuple[bytes, bytes]] = message["headers"]
  126. body: list[bytes] = []
  127. while True:
  128. message = self.receive()
  129. assert message["type"] == "websocket.http.response.body"
  130. body.append(message["body"])
  131. if not message.get("more_body", False):
  132. break
  133. raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
  134. def send(self, message: Message) -> None:
  135. self.portal.call(self._receive_tx.send, message)
  136. def send_text(self, data: str) -> None:
  137. self.send({"type": "websocket.receive", "text": data})
  138. def send_bytes(self, data: bytes) -> None:
  139. self.send({"type": "websocket.receive", "bytes": data})
  140. def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
  141. text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
  142. if mode == "text":
  143. self.send({"type": "websocket.receive", "text": text})
  144. else:
  145. self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
  146. def close(self, code: int = 1000, reason: str | None = None) -> None:
  147. self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
  148. def receive(self) -> Message:
  149. return self.portal.call(self._send_rx.receive)
  150. def receive_text(self) -> str:
  151. message = self.receive()
  152. self._raise_on_close(message)
  153. return cast(str, message["text"])
  154. def receive_bytes(self) -> bytes:
  155. message = self.receive()
  156. self._raise_on_close(message)
  157. return cast(bytes, message["bytes"])
  158. def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
  159. message = self.receive()
  160. self._raise_on_close(message)
  161. if mode == "text":
  162. text = message["text"]
  163. else:
  164. text = message["bytes"].decode("utf-8")
  165. return json.loads(text)
  166. class _TestClientTransport(httpx.BaseTransport):
  167. def __init__(
  168. self,
  169. app: ASGI3App,
  170. portal_factory: _PortalFactoryType,
  171. raise_server_exceptions: bool = True,
  172. root_path: str = "",
  173. *,
  174. client: tuple[str, int],
  175. app_state: dict[str, Any],
  176. ) -> None:
  177. self.app = app
  178. self.raise_server_exceptions = raise_server_exceptions
  179. self.root_path = root_path
  180. self.portal_factory = portal_factory
  181. self.app_state = app_state
  182. self.client = client
  183. def handle_request(self, request: httpx.Request) -> httpx.Response:
  184. scheme = request.url.scheme
  185. netloc = request.url.netloc.decode(encoding="ascii")
  186. path = request.url.path
  187. raw_path = request.url.raw_path
  188. query = request.url.query.decode(encoding="ascii")
  189. default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
  190. if ":" in netloc:
  191. host, port_string = netloc.split(":", 1)
  192. port = int(port_string)
  193. else:
  194. host = netloc
  195. port = default_port
  196. # Include the 'host' header.
  197. if "host" in request.headers:
  198. headers: list[tuple[bytes, bytes]] = []
  199. elif port == default_port: # pragma: no cover
  200. headers = [(b"host", host.encode())]
  201. else: # pragma: no cover
  202. headers = [(b"host", (f"{host}:{port}").encode())]
  203. # Include other request headers.
  204. headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
  205. scope: dict[str, Any]
  206. if scheme in {"ws", "wss"}:
  207. subprotocol = request.headers.get("sec-websocket-protocol", None)
  208. if subprotocol is None:
  209. subprotocols: Sequence[str] = []
  210. else:
  211. subprotocols = [value.strip() for value in subprotocol.split(",")]
  212. scope = {
  213. "type": "websocket",
  214. "path": unquote(path),
  215. "raw_path": raw_path.split(b"?", 1)[0],
  216. "root_path": self.root_path,
  217. "scheme": scheme,
  218. "query_string": query.encode(),
  219. "headers": headers,
  220. "client": self.client,
  221. "server": [host, port],
  222. "subprotocols": subprotocols,
  223. "state": self.app_state.copy(),
  224. "extensions": {"websocket.http.response": {}},
  225. }
  226. session = WebSocketTestSession(self.app, scope, self.portal_factory)
  227. raise _Upgrade(session)
  228. scope = {
  229. "type": "http",
  230. "http_version": "1.1",
  231. "method": request.method,
  232. "path": unquote(path),
  233. "raw_path": raw_path.split(b"?", 1)[0],
  234. "root_path": self.root_path,
  235. "scheme": scheme,
  236. "query_string": query.encode(),
  237. "headers": headers,
  238. "client": self.client,
  239. "server": [host, port],
  240. "extensions": {"http.response.debug": {}},
  241. "state": self.app_state.copy(),
  242. }
  243. request_complete = False
  244. response_started = False
  245. response_complete: anyio.Event
  246. raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
  247. template = None
  248. context = None
  249. async def receive() -> Message:
  250. nonlocal request_complete
  251. if request_complete:
  252. if not response_complete.is_set():
  253. await response_complete.wait()
  254. return {"type": "http.disconnect"}
  255. body = request.read()
  256. if isinstance(body, str):
  257. body_bytes: bytes = body.encode("utf-8") # pragma: no cover
  258. elif body is None:
  259. body_bytes = b"" # pragma: no cover
  260. elif isinstance(body, GeneratorType):
  261. try: # pragma: no cover
  262. chunk = body.send(None)
  263. if isinstance(chunk, str):
  264. chunk = chunk.encode("utf-8")
  265. return {"type": "http.request", "body": chunk, "more_body": True}
  266. except StopIteration: # pragma: no cover
  267. request_complete = True
  268. return {"type": "http.request", "body": b""}
  269. else:
  270. body_bytes = body
  271. request_complete = True
  272. return {"type": "http.request", "body": body_bytes}
  273. async def send(message: Message) -> None:
  274. nonlocal raw_kwargs, response_started, template, context
  275. if message["type"] == "http.response.start":
  276. assert not response_started, 'Received multiple "http.response.start" messages.'
  277. raw_kwargs["status_code"] = message["status"]
  278. raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
  279. response_started = True
  280. elif message["type"] == "http.response.body":
  281. assert response_started, 'Received "http.response.body" without "http.response.start".'
  282. assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
  283. body = message.get("body", b"")
  284. more_body = message.get("more_body", False)
  285. if request.method != "HEAD":
  286. raw_kwargs["stream"].write(body)
  287. if not more_body:
  288. raw_kwargs["stream"].seek(0)
  289. response_complete.set()
  290. elif message["type"] == "http.response.debug":
  291. template = message["info"]["template"]
  292. context = message["info"]["context"]
  293. try:
  294. with self.portal_factory() as portal:
  295. response_complete = portal.call(anyio.Event)
  296. portal.call(self.app, scope, receive, send)
  297. except BaseException as exc:
  298. if self.raise_server_exceptions:
  299. raise exc
  300. if self.raise_server_exceptions:
  301. assert response_started, "TestClient did not receive any response."
  302. elif not response_started:
  303. raw_kwargs = {
  304. "status_code": 500,
  305. "headers": [],
  306. "stream": io.BytesIO(),
  307. }
  308. raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
  309. response = httpx.Response(**raw_kwargs, request=request)
  310. if template is not None:
  311. response.template = template # type: ignore[attr-defined]
  312. response.context = context # type: ignore[attr-defined]
  313. return response
  314. class TestClient(httpx.Client):
  315. __test__ = False
  316. task: Future[None]
  317. portal: anyio.abc.BlockingPortal | None = None
  318. def __init__(
  319. self,
  320. app: ASGIApp,
  321. base_url: str = "http://testserver",
  322. raise_server_exceptions: bool = True,
  323. root_path: str = "",
  324. backend: Literal["asyncio", "trio"] = "asyncio",
  325. backend_options: dict[str, Any] | None = None,
  326. cookies: httpx._types.CookieTypes | None = None,
  327. headers: dict[str, str] | None = None,
  328. follow_redirects: bool = True,
  329. client: tuple[str, int] = ("testclient", 50000),
  330. ) -> None:
  331. self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
  332. if _is_asgi3(app):
  333. asgi_app = app
  334. else:
  335. app = cast(ASGI2App, app) # type: ignore[assignment]
  336. asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
  337. self.app = asgi_app
  338. self.app_state: dict[str, Any] = {}
  339. transport = _TestClientTransport(
  340. self.app,
  341. portal_factory=self._portal_factory,
  342. raise_server_exceptions=raise_server_exceptions,
  343. root_path=root_path,
  344. app_state=self.app_state,
  345. client=client,
  346. )
  347. if headers is None:
  348. headers = {}
  349. headers.setdefault("user-agent", "testclient")
  350. super().__init__(
  351. base_url=base_url,
  352. headers=headers,
  353. transport=transport,
  354. follow_redirects=follow_redirects,
  355. cookies=cookies,
  356. )
  357. @contextlib.contextmanager
  358. def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
  359. if self.portal is not None:
  360. yield self.portal
  361. else:
  362. with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
  363. yield portal
  364. def request( # type: ignore[override]
  365. self,
  366. method: str,
  367. url: httpx._types.URLTypes,
  368. *,
  369. content: httpx._types.RequestContent | None = None,
  370. data: _RequestData | None = None,
  371. files: httpx._types.RequestFiles | None = None,
  372. json: Any = None,
  373. params: httpx._types.QueryParamTypes | None = None,
  374. headers: httpx._types.HeaderTypes | None = None,
  375. cookies: httpx._types.CookieTypes | None = None,
  376. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  377. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  378. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  379. extensions: dict[str, Any] | None = None,
  380. ) -> httpx.Response:
  381. if timeout is not httpx.USE_CLIENT_DEFAULT:
  382. warnings.warn(
  383. "You should not use the 'timeout' argument with the TestClient. "
  384. "See https://github.com/encode/starlette/issues/1108 for more information.",
  385. DeprecationWarning,
  386. )
  387. url = self._merge_url(url)
  388. return super().request(
  389. method,
  390. url,
  391. content=content,
  392. data=data,
  393. files=files,
  394. json=json,
  395. params=params,
  396. headers=headers,
  397. cookies=cookies,
  398. auth=auth,
  399. follow_redirects=follow_redirects,
  400. timeout=timeout,
  401. extensions=extensions,
  402. )
  403. def get( # type: ignore[override]
  404. self,
  405. url: httpx._types.URLTypes,
  406. *,
  407. params: httpx._types.QueryParamTypes | None = None,
  408. headers: httpx._types.HeaderTypes | None = None,
  409. cookies: httpx._types.CookieTypes | None = None,
  410. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  411. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  412. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  413. extensions: dict[str, Any] | None = None,
  414. ) -> httpx.Response:
  415. return super().get(
  416. url,
  417. params=params,
  418. headers=headers,
  419. cookies=cookies,
  420. auth=auth,
  421. follow_redirects=follow_redirects,
  422. timeout=timeout,
  423. extensions=extensions,
  424. )
  425. def options( # type: ignore[override]
  426. self,
  427. url: httpx._types.URLTypes,
  428. *,
  429. params: httpx._types.QueryParamTypes | None = None,
  430. headers: httpx._types.HeaderTypes | None = None,
  431. cookies: httpx._types.CookieTypes | None = None,
  432. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  433. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  434. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  435. extensions: dict[str, Any] | None = None,
  436. ) -> httpx.Response:
  437. return super().options(
  438. url,
  439. params=params,
  440. headers=headers,
  441. cookies=cookies,
  442. auth=auth,
  443. follow_redirects=follow_redirects,
  444. timeout=timeout,
  445. extensions=extensions,
  446. )
  447. def head( # type: ignore[override]
  448. self,
  449. url: httpx._types.URLTypes,
  450. *,
  451. params: httpx._types.QueryParamTypes | None = None,
  452. headers: httpx._types.HeaderTypes | None = None,
  453. cookies: httpx._types.CookieTypes | None = None,
  454. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  455. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  456. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  457. extensions: dict[str, Any] | None = None,
  458. ) -> httpx.Response:
  459. return super().head(
  460. url,
  461. params=params,
  462. headers=headers,
  463. cookies=cookies,
  464. auth=auth,
  465. follow_redirects=follow_redirects,
  466. timeout=timeout,
  467. extensions=extensions,
  468. )
  469. def post( # type: ignore[override]
  470. self,
  471. url: httpx._types.URLTypes,
  472. *,
  473. content: httpx._types.RequestContent | None = None,
  474. data: _RequestData | None = None,
  475. files: httpx._types.RequestFiles | None = None,
  476. json: Any = None,
  477. params: httpx._types.QueryParamTypes | None = None,
  478. headers: httpx._types.HeaderTypes | None = None,
  479. cookies: httpx._types.CookieTypes | None = None,
  480. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  481. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  482. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  483. extensions: dict[str, Any] | None = None,
  484. ) -> httpx.Response:
  485. return super().post(
  486. url,
  487. content=content,
  488. data=data,
  489. files=files,
  490. json=json,
  491. params=params,
  492. headers=headers,
  493. cookies=cookies,
  494. auth=auth,
  495. follow_redirects=follow_redirects,
  496. timeout=timeout,
  497. extensions=extensions,
  498. )
  499. def put( # type: ignore[override]
  500. self,
  501. url: httpx._types.URLTypes,
  502. *,
  503. content: httpx._types.RequestContent | None = None,
  504. data: _RequestData | None = None,
  505. files: httpx._types.RequestFiles | None = None,
  506. json: Any = None,
  507. params: httpx._types.QueryParamTypes | None = None,
  508. headers: httpx._types.HeaderTypes | None = None,
  509. cookies: httpx._types.CookieTypes | None = None,
  510. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  511. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  512. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  513. extensions: dict[str, Any] | None = None,
  514. ) -> httpx.Response:
  515. return super().put(
  516. url,
  517. content=content,
  518. data=data,
  519. files=files,
  520. json=json,
  521. params=params,
  522. headers=headers,
  523. cookies=cookies,
  524. auth=auth,
  525. follow_redirects=follow_redirects,
  526. timeout=timeout,
  527. extensions=extensions,
  528. )
  529. def patch( # type: ignore[override]
  530. self,
  531. url: httpx._types.URLTypes,
  532. *,
  533. content: httpx._types.RequestContent | None = None,
  534. data: _RequestData | None = None,
  535. files: httpx._types.RequestFiles | None = None,
  536. json: Any = None,
  537. params: httpx._types.QueryParamTypes | None = None,
  538. headers: httpx._types.HeaderTypes | None = None,
  539. cookies: httpx._types.CookieTypes | None = None,
  540. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  541. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  542. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  543. extensions: dict[str, Any] | None = None,
  544. ) -> httpx.Response:
  545. return super().patch(
  546. url,
  547. content=content,
  548. data=data,
  549. files=files,
  550. json=json,
  551. params=params,
  552. headers=headers,
  553. cookies=cookies,
  554. auth=auth,
  555. follow_redirects=follow_redirects,
  556. timeout=timeout,
  557. extensions=extensions,
  558. )
  559. def delete( # type: ignore[override]
  560. self,
  561. url: httpx._types.URLTypes,
  562. *,
  563. params: httpx._types.QueryParamTypes | None = None,
  564. headers: httpx._types.HeaderTypes | None = None,
  565. cookies: httpx._types.CookieTypes | None = None,
  566. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  567. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  568. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  569. extensions: dict[str, Any] | None = None,
  570. ) -> httpx.Response:
  571. return super().delete(
  572. url,
  573. params=params,
  574. headers=headers,
  575. cookies=cookies,
  576. auth=auth,
  577. follow_redirects=follow_redirects,
  578. timeout=timeout,
  579. extensions=extensions,
  580. )
  581. def websocket_connect(
  582. self,
  583. url: str,
  584. subprotocols: Sequence[str] | None = None,
  585. **kwargs: Any,
  586. ) -> WebSocketTestSession:
  587. url = urljoin("ws://testserver", url)
  588. headers = kwargs.get("headers", {})
  589. headers.setdefault("connection", "upgrade")
  590. headers.setdefault("sec-websocket-key", "testserver==")
  591. headers.setdefault("sec-websocket-version", "13")
  592. if subprotocols is not None:
  593. headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
  594. kwargs["headers"] = headers
  595. try:
  596. super().request("GET", url, **kwargs)
  597. except _Upgrade as exc:
  598. session = exc.session
  599. else:
  600. raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
  601. return session
  602. def __enter__(self) -> Self:
  603. with contextlib.ExitStack() as stack:
  604. self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
  605. @stack.callback
  606. def reset_portal() -> None:
  607. self.portal = None
  608. send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
  609. anyio.create_memory_object_stream(math.inf)
  610. )
  611. receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
  612. math.inf
  613. )
  614. for channel in (*send, *receive):
  615. stack.callback(channel.close)
  616. self.stream_send = StapledObjectStream(*send)
  617. self.stream_receive = StapledObjectStream(*receive)
  618. self.task = portal.start_task_soon(self.lifespan)
  619. portal.call(self.wait_startup)
  620. @stack.callback
  621. def wait_shutdown() -> None:
  622. portal.call(self.wait_shutdown)
  623. self.exit_stack = stack.pop_all()
  624. return self
  625. def __exit__(self, *args: Any) -> None:
  626. self.exit_stack.close()
  627. async def lifespan(self) -> None:
  628. scope = {"type": "lifespan", "state": self.app_state}
  629. try:
  630. await self.app(scope, self.stream_receive.receive, self.stream_send.send)
  631. finally:
  632. await self.stream_send.send(None)
  633. async def wait_startup(self) -> None:
  634. await self.stream_receive.send({"type": "lifespan.startup"})
  635. async def receive() -> Any:
  636. message = await self.stream_send.receive()
  637. if message is None:
  638. self.task.result()
  639. return message
  640. message = await receive()
  641. assert message["type"] in (
  642. "lifespan.startup.complete",
  643. "lifespan.startup.failed",
  644. )
  645. if message["type"] == "lifespan.startup.failed":
  646. await receive()
  647. async def wait_shutdown(self) -> None:
  648. async def receive() -> Any:
  649. message = await self.stream_send.receive()
  650. if message is None:
  651. self.task.result()
  652. return message
  653. await self.stream_receive.send({"type": "lifespan.shutdown"})
  654. message = await receive()
  655. assert message["type"] in (
  656. "lifespan.shutdown.complete",
  657. "lifespan.shutdown.failed",
  658. )
  659. if message["type"] == "lifespan.shutdown.failed":
  660. await receive()