You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

118 lines
4.5 KiB

  1. import asyncio
  2. import json
  3. import typing
  4. from starlette import status
  5. from starlette.concurrency import run_in_threadpool
  6. from starlette.exceptions import HTTPException
  7. from starlette.requests import Request
  8. from starlette.responses import PlainTextResponse, Response
  9. from starlette.types import Message, Receive, Scope, Send
  10. from starlette.websockets import WebSocket
  11. class HTTPEndpoint:
  12. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  13. assert scope["type"] == "http"
  14. self.scope = scope
  15. self.receive = receive
  16. self.send = send
  17. def __await__(self) -> typing.Generator:
  18. return self.dispatch().__await__()
  19. async def dispatch(self) -> None:
  20. request = Request(self.scope, receive=self.receive)
  21. handler_name = "get" if request.method == "HEAD" else request.method.lower()
  22. handler = getattr(self, handler_name, self.method_not_allowed)
  23. is_async = asyncio.iscoroutinefunction(handler)
  24. if is_async:
  25. response = await handler(request)
  26. else:
  27. response = await run_in_threadpool(handler, request)
  28. await response(self.scope, self.receive, self.send)
  29. async def method_not_allowed(self, request: Request) -> Response:
  30. # If we're running inside a starlette application then raise an
  31. # exception, so that the configurable exception handler can deal with
  32. # returning the response. For plain ASGI apps, just return the response.
  33. if "app" in self.scope:
  34. raise HTTPException(status_code=405)
  35. return PlainTextResponse("Method Not Allowed", status_code=405)
  36. class WebSocketEndpoint:
  37. encoding: typing.Optional[str] = None # May be "text", "bytes", or "json".
  38. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  39. assert scope["type"] == "websocket"
  40. self.scope = scope
  41. self.receive = receive
  42. self.send = send
  43. def __await__(self) -> typing.Generator:
  44. return self.dispatch().__await__()
  45. async def dispatch(self) -> None:
  46. websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
  47. await self.on_connect(websocket)
  48. close_code = status.WS_1000_NORMAL_CLOSURE
  49. try:
  50. while True:
  51. message = await websocket.receive()
  52. if message["type"] == "websocket.receive":
  53. data = await self.decode(websocket, message)
  54. await self.on_receive(websocket, data)
  55. elif message["type"] == "websocket.disconnect":
  56. close_code = int(message.get("code", status.WS_1000_NORMAL_CLOSURE))
  57. break
  58. except Exception as exc:
  59. close_code = status.WS_1011_INTERNAL_ERROR
  60. raise exc
  61. finally:
  62. await self.on_disconnect(websocket, close_code)
  63. async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
  64. if self.encoding == "text":
  65. if "text" not in message:
  66. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  67. raise RuntimeError("Expected text websocket messages, but got bytes")
  68. return message["text"]
  69. elif self.encoding == "bytes":
  70. if "bytes" not in message:
  71. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  72. raise RuntimeError("Expected bytes websocket messages, but got text")
  73. return message["bytes"]
  74. elif self.encoding == "json":
  75. if message.get("text") is not None:
  76. text = message["text"]
  77. else:
  78. text = message["bytes"].decode("utf-8")
  79. try:
  80. return json.loads(text)
  81. except json.decoder.JSONDecodeError:
  82. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  83. raise RuntimeError("Malformed JSON data received.")
  84. assert (
  85. self.encoding is None
  86. ), f"Unsupported 'encoding' attribute {self.encoding}"
  87. return message["text"] if message.get("text") else message["bytes"]
  88. async def on_connect(self, websocket: WebSocket) -> None:
  89. """Override to handle an incoming websocket connection"""
  90. await websocket.accept()
  91. async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
  92. """Override to handle an incoming websocket message"""
  93. async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
  94. """Override to handle a disconnecting websocket"""