Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

124 řádky
5.0 KiB

  1. from __future__ import annotations
  2. import json
  3. from collections.abc import Generator
  4. from typing import Any, Callable
  5. from starlette import status
  6. from starlette._utils import is_async_callable
  7. from starlette.concurrency import run_in_threadpool
  8. from starlette.exceptions import HTTPException
  9. from starlette.requests import Request
  10. from starlette.responses import PlainTextResponse, Response
  11. from starlette.types import Message, Receive, Scope, Send
  12. from starlette.websockets import WebSocket
  13. class HTTPEndpoint:
  14. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  15. assert scope["type"] == "http"
  16. self.scope = scope
  17. self.receive = receive
  18. self.send = send
  19. self._allowed_methods = [
  20. method
  21. for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
  22. if getattr(self, method.lower(), None) is not None
  23. ]
  24. def __await__(self) -> Generator[Any, None, None]:
  25. return self.dispatch().__await__()
  26. async def dispatch(self) -> None:
  27. request = Request(self.scope, receive=self.receive)
  28. handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
  29. handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
  30. is_async = is_async_callable(handler)
  31. if is_async:
  32. response = await handler(request)
  33. else:
  34. response = await run_in_threadpool(handler, request)
  35. await response(self.scope, self.receive, self.send)
  36. async def method_not_allowed(self, request: Request) -> Response:
  37. # If we're running inside a starlette application then raise an
  38. # exception, so that the configurable exception handler can deal with
  39. # returning the response. For plain ASGI apps, just return the response.
  40. headers = {"Allow": ", ".join(self._allowed_methods)}
  41. if "app" in self.scope:
  42. raise HTTPException(status_code=405, headers=headers)
  43. return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
  44. class WebSocketEndpoint:
  45. encoding: str | None = None # May be "text", "bytes", or "json".
  46. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  47. assert scope["type"] == "websocket"
  48. self.scope = scope
  49. self.receive = receive
  50. self.send = send
  51. def __await__(self) -> Generator[Any, None, None]:
  52. return self.dispatch().__await__()
  53. async def dispatch(self) -> None:
  54. websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
  55. await self.on_connect(websocket)
  56. close_code = status.WS_1000_NORMAL_CLOSURE
  57. try:
  58. while True:
  59. message = await websocket.receive()
  60. if message["type"] == "websocket.receive":
  61. data = await self.decode(websocket, message)
  62. await self.on_receive(websocket, data)
  63. elif message["type"] == "websocket.disconnect": # pragma: no branch
  64. close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
  65. break
  66. except Exception as exc:
  67. close_code = status.WS_1011_INTERNAL_ERROR
  68. raise exc
  69. finally:
  70. await self.on_disconnect(websocket, close_code)
  71. async def decode(self, websocket: WebSocket, message: Message) -> Any:
  72. if self.encoding == "text":
  73. if "text" not in message:
  74. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  75. raise RuntimeError("Expected text websocket messages, but got bytes")
  76. return message["text"]
  77. elif self.encoding == "bytes":
  78. if "bytes" not in message:
  79. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  80. raise RuntimeError("Expected bytes websocket messages, but got text")
  81. return message["bytes"]
  82. elif self.encoding == "json":
  83. if message.get("text") is not None:
  84. text = message["text"]
  85. else:
  86. text = message["bytes"].decode("utf-8")
  87. try:
  88. return json.loads(text)
  89. except json.decoder.JSONDecodeError:
  90. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  91. raise RuntimeError("Malformed JSON data received.")
  92. assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
  93. return message["text"] if message.get("text") else message["bytes"]
  94. async def on_connect(self, websocket: WebSocket) -> None:
  95. """Override to handle an incoming websocket connection"""
  96. await websocket.accept()
  97. async def on_receive(self, websocket: WebSocket, data: Any) -> None:
  98. """Override to handle an incoming websocket message"""
  99. async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
  100. """Override to handle a disconnecting websocket"""