您最多选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。
 
 
 
 

577 行
19 KiB

  1. import asyncio
  2. import contextlib
  3. import http
  4. import inspect
  5. import io
  6. import json
  7. import math
  8. import queue
  9. import sys
  10. import types
  11. import typing
  12. from concurrent.futures import Future
  13. from urllib.parse import unquote, urljoin, urlsplit
  14. import anyio.abc
  15. import requests
  16. from anyio.streams.stapled import StapledObjectStream
  17. from starlette.types import Message, Receive, Scope, Send
  18. from starlette.websockets import WebSocketDisconnect
  19. if sys.version_info >= (3, 8): # pragma: no cover
  20. from typing import TypedDict
  21. else: # pragma: no cover
  22. from typing_extensions import TypedDict
  23. _PortalFactoryType = typing.Callable[
  24. [], typing.ContextManager[anyio.abc.BlockingPortal]
  25. ]
  26. # Annotations for `Session.request()`
  27. Cookies = typing.Union[
  28. typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
  29. ]
  30. Params = typing.Union[bytes, typing.MutableMapping[str, str]]
  31. DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
  32. TimeOut = typing.Union[float, typing.Tuple[float, float]]
  33. FileType = typing.MutableMapping[str, typing.IO]
  34. AuthType = typing.Union[
  35. typing.Tuple[str, str],
  36. requests.auth.AuthBase,
  37. typing.Callable[[requests.PreparedRequest], requests.PreparedRequest],
  38. ]
  39. ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
  40. ASGI2App = typing.Callable[[Scope], ASGIInstance]
  41. ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
  42. class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
  43. def get_all(self, key: str, default: str) -> str:
  44. return self.getheaders(key)
  45. class _MockOriginalResponse:
  46. """
  47. We have to jump through some hoops to present the response as if
  48. it was made using urllib3.
  49. """
  50. def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
  51. self.msg = _HeaderDict(headers)
  52. self.closed = False
  53. def isclosed(self) -> bool:
  54. return self.closed
  55. class _Upgrade(Exception):
  56. def __init__(self, session: "WebSocketTestSession") -> None:
  57. self.session = session
  58. def _get_reason_phrase(status_code: int) -> str:
  59. try:
  60. return http.HTTPStatus(status_code).phrase
  61. except ValueError:
  62. return ""
  63. def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
  64. if inspect.isclass(app):
  65. return hasattr(app, "__await__")
  66. elif inspect.isfunction(app):
  67. return asyncio.iscoroutinefunction(app)
  68. call = getattr(app, "__call__", None)
  69. return asyncio.iscoroutinefunction(call)
  70. class _WrapASGI2:
  71. """
  72. Provide an ASGI3 interface onto an ASGI2 app.
  73. """
  74. def __init__(self, app: ASGI2App) -> None:
  75. self.app = app
  76. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  77. instance = self.app(scope)
  78. await instance(receive, send)
  79. class _AsyncBackend(TypedDict):
  80. backend: str
  81. backend_options: typing.Dict[str, typing.Any]
  82. class _ASGIAdapter(requests.adapters.HTTPAdapter):
  83. def __init__(
  84. self,
  85. app: ASGI3App,
  86. portal_factory: _PortalFactoryType,
  87. raise_server_exceptions: bool = True,
  88. root_path: str = "",
  89. ) -> None:
  90. self.app = app
  91. self.raise_server_exceptions = raise_server_exceptions
  92. self.root_path = root_path
  93. self.portal_factory = portal_factory
  94. def send(
  95. self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
  96. ) -> requests.Response:
  97. scheme, netloc, path, query, fragment = (
  98. str(item) for item in urlsplit(request.url)
  99. )
  100. default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
  101. if ":" in netloc:
  102. host, port_string = netloc.split(":", 1)
  103. port = int(port_string)
  104. else:
  105. host = netloc
  106. port = default_port
  107. # Include the 'host' header.
  108. if "host" in request.headers:
  109. headers: typing.List[typing.Tuple[bytes, bytes]] = []
  110. elif port == default_port:
  111. headers = [(b"host", host.encode())]
  112. else:
  113. headers = [(b"host", (f"{host}:{port}").encode())]
  114. # Include other request headers.
  115. headers += [
  116. (key.lower().encode(), value.encode())
  117. for key, value in request.headers.items()
  118. ]
  119. if scheme in {"ws", "wss"}:
  120. subprotocol = request.headers.get("sec-websocket-protocol", None)
  121. if subprotocol is None:
  122. subprotocols: typing.Sequence[str] = []
  123. else:
  124. subprotocols = [value.strip() for value in subprotocol.split(",")]
  125. scope = {
  126. "type": "websocket",
  127. "path": unquote(path),
  128. "root_path": self.root_path,
  129. "scheme": scheme,
  130. "query_string": query.encode(),
  131. "headers": headers,
  132. "client": ["testclient", 50000],
  133. "server": [host, port],
  134. "subprotocols": subprotocols,
  135. }
  136. session = WebSocketTestSession(self.app, scope, self.portal_factory)
  137. raise _Upgrade(session)
  138. scope = {
  139. "type": "http",
  140. "http_version": "1.1",
  141. "method": request.method,
  142. "path": unquote(path),
  143. "root_path": self.root_path,
  144. "scheme": scheme,
  145. "query_string": query.encode(),
  146. "headers": headers,
  147. "client": ["testclient", 50000],
  148. "server": [host, port],
  149. "extensions": {"http.response.template": {}},
  150. }
  151. request_complete = False
  152. response_started = False
  153. response_complete: anyio.Event
  154. raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
  155. template = None
  156. context = None
  157. async def receive() -> Message:
  158. nonlocal request_complete
  159. if request_complete:
  160. if not response_complete.is_set():
  161. await response_complete.wait()
  162. return {"type": "http.disconnect"}
  163. body = request.body
  164. if isinstance(body, str):
  165. body_bytes: bytes = body.encode("utf-8")
  166. elif body is None:
  167. body_bytes = b""
  168. elif isinstance(body, types.GeneratorType):
  169. try:
  170. chunk = body.send(None)
  171. if isinstance(chunk, str):
  172. chunk = chunk.encode("utf-8")
  173. return {"type": "http.request", "body": chunk, "more_body": True}
  174. except StopIteration:
  175. request_complete = True
  176. return {"type": "http.request", "body": b""}
  177. else:
  178. body_bytes = body
  179. request_complete = True
  180. return {"type": "http.request", "body": body_bytes}
  181. async def send(message: Message) -> None:
  182. nonlocal raw_kwargs, response_started, template, context
  183. if message["type"] == "http.response.start":
  184. assert (
  185. not response_started
  186. ), 'Received multiple "http.response.start" messages.'
  187. raw_kwargs["version"] = 11
  188. raw_kwargs["status"] = message["status"]
  189. raw_kwargs["reason"] = _get_reason_phrase(message["status"])
  190. raw_kwargs["headers"] = [
  191. (key.decode(), value.decode())
  192. for key, value in message.get("headers", [])
  193. ]
  194. raw_kwargs["preload_content"] = False
  195. raw_kwargs["original_response"] = _MockOriginalResponse(
  196. raw_kwargs["headers"]
  197. )
  198. response_started = True
  199. elif message["type"] == "http.response.body":
  200. assert (
  201. response_started
  202. ), 'Received "http.response.body" without "http.response.start".'
  203. assert (
  204. not response_complete.is_set()
  205. ), 'Received "http.response.body" after response completed.'
  206. body = message.get("body", b"")
  207. more_body = message.get("more_body", False)
  208. if request.method != "HEAD":
  209. raw_kwargs["body"].write(body)
  210. if not more_body:
  211. raw_kwargs["body"].seek(0)
  212. response_complete.set()
  213. elif message["type"] == "http.response.template":
  214. template = message["template"]
  215. context = message["context"]
  216. try:
  217. with self.portal_factory() as portal:
  218. response_complete = portal.call(anyio.Event)
  219. portal.call(self.app, scope, receive, send)
  220. except BaseException as exc:
  221. if self.raise_server_exceptions:
  222. raise exc
  223. if self.raise_server_exceptions:
  224. assert response_started, "TestClient did not receive any response."
  225. elif not response_started:
  226. raw_kwargs = {
  227. "version": 11,
  228. "status": 500,
  229. "reason": "Internal Server Error",
  230. "headers": [],
  231. "preload_content": False,
  232. "original_response": _MockOriginalResponse([]),
  233. "body": io.BytesIO(),
  234. }
  235. raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
  236. response = self.build_response(request, raw)
  237. if template is not None:
  238. response.template = template
  239. response.context = context
  240. return response
  241. class WebSocketTestSession:
  242. def __init__(
  243. self,
  244. app: ASGI3App,
  245. scope: Scope,
  246. portal_factory: _PortalFactoryType,
  247. ) -> None:
  248. self.app = app
  249. self.scope = scope
  250. self.accepted_subprotocol = None
  251. self.portal_factory = portal_factory
  252. self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
  253. self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
  254. def __enter__(self) -> "WebSocketTestSession":
  255. self.exit_stack = contextlib.ExitStack()
  256. self.portal = self.exit_stack.enter_context(self.portal_factory())
  257. try:
  258. _: "Future[None]" = self.portal.start_task_soon(self._run)
  259. self.send({"type": "websocket.connect"})
  260. message = self.receive()
  261. self._raise_on_close(message)
  262. except Exception:
  263. self.exit_stack.close()
  264. raise
  265. self.accepted_subprotocol = message.get("subprotocol", None)
  266. return self
  267. def __exit__(self, *args: typing.Any) -> None:
  268. try:
  269. self.close(1000)
  270. finally:
  271. self.exit_stack.close()
  272. while not self._send_queue.empty():
  273. message = self._send_queue.get()
  274. if isinstance(message, BaseException):
  275. raise message
  276. async def _run(self) -> None:
  277. """
  278. The sub-thread in which the websocket session runs.
  279. """
  280. scope = self.scope
  281. receive = self._asgi_receive
  282. send = self._asgi_send
  283. try:
  284. await self.app(scope, receive, send)
  285. except BaseException as exc:
  286. self._send_queue.put(exc)
  287. raise
  288. async def _asgi_receive(self) -> Message:
  289. while self._receive_queue.empty():
  290. await anyio.sleep(0)
  291. return self._receive_queue.get()
  292. async def _asgi_send(self, message: Message) -> None:
  293. self._send_queue.put(message)
  294. def _raise_on_close(self, message: Message) -> None:
  295. if message["type"] == "websocket.close":
  296. raise WebSocketDisconnect(message.get("code", 1000))
  297. def send(self, message: Message) -> None:
  298. self._receive_queue.put(message)
  299. def send_text(self, data: str) -> None:
  300. self.send({"type": "websocket.receive", "text": data})
  301. def send_bytes(self, data: bytes) -> None:
  302. self.send({"type": "websocket.receive", "bytes": data})
  303. def send_json(self, data: typing.Any, mode: str = "text") -> None:
  304. assert mode in ["text", "binary"]
  305. text = json.dumps(data)
  306. if mode == "text":
  307. self.send({"type": "websocket.receive", "text": text})
  308. else:
  309. self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
  310. def close(self, code: int = 1000) -> None:
  311. self.send({"type": "websocket.disconnect", "code": code})
  312. def receive(self) -> Message:
  313. message = self._send_queue.get()
  314. if isinstance(message, BaseException):
  315. raise message
  316. return message
  317. def receive_text(self) -> str:
  318. message = self.receive()
  319. self._raise_on_close(message)
  320. return message["text"]
  321. def receive_bytes(self) -> bytes:
  322. message = self.receive()
  323. self._raise_on_close(message)
  324. return message["bytes"]
  325. def receive_json(self, mode: str = "text") -> typing.Any:
  326. assert mode in ["text", "binary"]
  327. message = self.receive()
  328. self._raise_on_close(message)
  329. if mode == "text":
  330. text = message["text"]
  331. else:
  332. text = message["bytes"].decode("utf-8")
  333. return json.loads(text)
  334. class TestClient(requests.Session):
  335. __test__ = False # For pytest to not discover this up.
  336. task: "Future[None]"
  337. portal: typing.Optional[anyio.abc.BlockingPortal] = None
  338. def __init__(
  339. self,
  340. app: typing.Union[ASGI2App, ASGI3App],
  341. base_url: str = "http://testserver",
  342. raise_server_exceptions: bool = True,
  343. root_path: str = "",
  344. backend: str = "asyncio",
  345. backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
  346. ) -> None:
  347. super().__init__()
  348. self.async_backend = _AsyncBackend(
  349. backend=backend, backend_options=backend_options or {}
  350. )
  351. if _is_asgi3(app):
  352. app = typing.cast(ASGI3App, app)
  353. asgi_app = app
  354. else:
  355. app = typing.cast(ASGI2App, app)
  356. asgi_app = _WrapASGI2(app) #  type: ignore
  357. adapter = _ASGIAdapter(
  358. asgi_app,
  359. portal_factory=self._portal_factory,
  360. raise_server_exceptions=raise_server_exceptions,
  361. root_path=root_path,
  362. )
  363. self.mount("http://", adapter)
  364. self.mount("https://", adapter)
  365. self.mount("ws://", adapter)
  366. self.mount("wss://", adapter)
  367. self.headers.update({"user-agent": "testclient"})
  368. self.app = asgi_app
  369. self.base_url = base_url
  370. @contextlib.contextmanager
  371. def _portal_factory(
  372. self,
  373. ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
  374. if self.portal is not None:
  375. yield self.portal
  376. else:
  377. with anyio.start_blocking_portal(**self.async_backend) as portal:
  378. yield portal
  379. def request( # type: ignore
  380. self,
  381. method: str,
  382. url: str,
  383. params: Params = None,
  384. data: DataType = None,
  385. headers: typing.MutableMapping[str, str] = None,
  386. cookies: Cookies = None,
  387. files: FileType = None,
  388. auth: AuthType = None,
  389. timeout: TimeOut = None,
  390. allow_redirects: bool = None,
  391. proxies: typing.MutableMapping[str, str] = None,
  392. hooks: typing.Any = None,
  393. stream: bool = None,
  394. verify: typing.Union[bool, str] = None,
  395. cert: typing.Union[str, typing.Tuple[str, str]] = None,
  396. json: typing.Any = None,
  397. ) -> requests.Response:
  398. url = urljoin(self.base_url, url)
  399. return super().request(
  400. method,
  401. url,
  402. params=params,
  403. data=data,
  404. headers=headers,
  405. cookies=cookies,
  406. files=files,
  407. auth=auth,
  408. timeout=timeout,
  409. allow_redirects=allow_redirects,
  410. proxies=proxies,
  411. hooks=hooks,
  412. stream=stream,
  413. verify=verify,
  414. cert=cert,
  415. json=json,
  416. )
  417. def websocket_connect(
  418. self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
  419. ) -> typing.Any:
  420. url = urljoin("ws://testserver", url)
  421. headers = kwargs.get("headers", {})
  422. headers.setdefault("connection", "upgrade")
  423. headers.setdefault("sec-websocket-key", "testserver==")
  424. headers.setdefault("sec-websocket-version", "13")
  425. if subprotocols is not None:
  426. headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
  427. kwargs["headers"] = headers
  428. try:
  429. super().request("GET", url, **kwargs)
  430. except _Upgrade as exc:
  431. session = exc.session
  432. else:
  433. raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
  434. return session
  435. def __enter__(self) -> "TestClient":
  436. with contextlib.ExitStack() as stack:
  437. self.portal = portal = stack.enter_context(
  438. anyio.start_blocking_portal(**self.async_backend)
  439. )
  440. @stack.callback
  441. def reset_portal() -> None:
  442. self.portal = None
  443. self.stream_send = StapledObjectStream(
  444. *anyio.create_memory_object_stream(math.inf)
  445. )
  446. self.stream_receive = StapledObjectStream(
  447. *anyio.create_memory_object_stream(math.inf)
  448. )
  449. self.task = portal.start_task_soon(self.lifespan)
  450. portal.call(self.wait_startup)
  451. @stack.callback
  452. def wait_shutdown() -> None:
  453. portal.call(self.wait_shutdown)
  454. self.exit_stack = stack.pop_all()
  455. return self
  456. def __exit__(self, *args: typing.Any) -> None:
  457. self.exit_stack.close()
  458. async def lifespan(self) -> None:
  459. scope = {"type": "lifespan"}
  460. try:
  461. await self.app(scope, self.stream_receive.receive, self.stream_send.send)
  462. finally:
  463. await self.stream_send.send(None)
  464. async def wait_startup(self) -> None:
  465. await self.stream_receive.send({"type": "lifespan.startup"})
  466. async def receive() -> typing.Any:
  467. message = await self.stream_send.receive()
  468. if message is None:
  469. self.task.result()
  470. return message
  471. message = await receive()
  472. assert message["type"] in (
  473. "lifespan.startup.complete",
  474. "lifespan.startup.failed",
  475. )
  476. if message["type"] == "lifespan.startup.failed":
  477. await receive()
  478. async def wait_shutdown(self) -> None:
  479. async def receive() -> typing.Any:
  480. message = await self.stream_send.receive()
  481. if message is None:
  482. self.task.result()
  483. return message
  484. async with self.stream_send:
  485. await self.stream_receive.send({"type": "lifespan.shutdown"})
  486. message = await receive()
  487. assert message["type"] in (
  488. "lifespan.shutdown.complete",
  489. "lifespan.shutdown.failed",
  490. )
  491. if message["type"] == "lifespan.shutdown.failed":
  492. await receive()