You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

151 lines
5.5 KiB

  1. import enum
  2. import json
  3. import typing
  4. from starlette.requests import HTTPConnection
  5. from starlette.types import Message, Receive, Scope, Send
  class WebSocketState(enum.Enum):
      """Lifecycle stage of one side (client or application) of a websocket."""

      # Handshake not finished: no "connect" received / no "accept" sent yet.
      CONNECTING = 0
      # Handshake finished: regular data messages may flow.
      CONNECTED = 1
      # "disconnect" received / "close" sent: no further messages allowed.
      DISCONNECTED = 2
  class WebSocketDisconnect(Exception):
      """Signal that the peer disconnected; the close code is kept on ``code``."""

      def __init__(self, code: int = 1000) -> None:
          # 1000 is the RFC 6455 "normal closure" status code.
          self.code = code
  class WebSocket(HTTPConnection):
      """
      ASGI websocket connection that enforces valid message ordering.

      Two independent state machines are tracked:

      * ``client_state`` is advanced by messages obtained through ``receive``:
        ``websocket.connect``, then any number of ``websocket.receive``, then
        ``websocket.disconnect``.
      * ``application_state`` is advanced by messages passed to ``send``:
        ``websocket.accept`` (or an immediate ``websocket.close``), then any
        number of ``websocket.send``, then ``websocket.close``.

      Protocol violations raise ``RuntimeError`` via explicit checks rather
      than ``assert``, so the checks still fire under ``python -O``.
      """

      def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
          super().__init__(scope)
          # Constructing this class with a non-websocket scope is a programming
          # error in the caller, so a plain assert is kept here.
          assert scope["type"] == "websocket"
          self._receive = receive
          self._send = send
          self.client_state = WebSocketState.CONNECTING
          self.application_state = WebSocketState.CONNECTING

      async def receive(self) -> Message:
          """
          Receive ASGI websocket messages, ensuring valid state transitions.

          The first message must be ``websocket.connect``. After that, only
          ``websocket.receive`` or ``websocket.disconnect`` are allowed.

          Raises:
              RuntimeError: if the message type is not valid for the current
                  state, or if called after a disconnect was received.
          """
          if self.client_state == WebSocketState.CONNECTING:
              message = await self._receive()
              message_type = message["type"]
              if message_type != "websocket.connect":
                  raise RuntimeError(
                      'Expected ASGI message "websocket.connect", '
                      f"but got {message_type!r}"
                  )
              self.client_state = WebSocketState.CONNECTED
              return message
          elif self.client_state == WebSocketState.CONNECTED:
              message = await self._receive()
              message_type = message["type"]
              if message_type not in {"websocket.receive", "websocket.disconnect"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.receive" or '
                      f'"websocket.disconnect", but got {message_type!r}'
                  )
              if message_type == "websocket.disconnect":
                  self.client_state = WebSocketState.DISCONNECTED
              return message
          else:
              raise RuntimeError(
                  'Cannot call "receive" once a disconnect message has been received.'
              )

      async def send(self, message: Message) -> None:
          """
          Send ASGI websocket messages, ensuring valid state transitions.

          Before the handshake, only ``websocket.accept`` or
          ``websocket.close`` may be sent. Afterwards, only ``websocket.send``
          or ``websocket.close`` may be sent.

          Raises:
              RuntimeError: if the message type is not valid for the current
                  state, or if called after a close message was sent.
          """
          if self.application_state == WebSocketState.CONNECTING:
              message_type = message["type"]
              if message_type not in {"websocket.accept", "websocket.close"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.accept" or '
                      f'"websocket.close", but got {message_type!r}'
                  )
              if message_type == "websocket.close":
                  self.application_state = WebSocketState.DISCONNECTED
              else:
                  self.application_state = WebSocketState.CONNECTED
              await self._send(message)
          elif self.application_state == WebSocketState.CONNECTED:
              message_type = message["type"]
              if message_type not in {"websocket.send", "websocket.close"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.send" or '
                      f'"websocket.close", but got {message_type!r}'
                  )
              if message_type == "websocket.close":
                  self.application_state = WebSocketState.DISCONNECTED
              await self._send(message)
          else:
              raise RuntimeError('Cannot call "send" once a close message has been sent.')

      async def accept(self, subprotocol: typing.Optional[str] = None) -> None:
          """
          Accept the websocket handshake, optionally selecting a subprotocol.

          If the client's ``websocket.connect`` message has not been received
          yet, it is awaited first.
          """
          if self.client_state == WebSocketState.CONNECTING:
              # If we haven't yet seen the 'connect' message, then wait for it first.
              await self.receive()
          await self.send({"type": "websocket.accept", "subprotocol": subprotocol})

      def _raise_on_disconnect(self, message: Message) -> None:
          # Turn a disconnect message into an exception so the receive_* helpers
          # only ever return data payloads.
          if message["type"] == "websocket.disconnect":
              raise WebSocketDisconnect(message["code"])

      def _require_connected(self) -> None:
          # Explicit check instead of ``assert`` so it survives ``python -O``.
          if self.application_state != WebSocketState.CONNECTED:
              raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')

      @staticmethod
      def _check_mode(mode: str) -> None:
          # Validate the JSON framing mode; anything else used to fall through
          # silently to the binary branch when asserts were disabled.
          if mode not in {"text", "binary"}:
              raise ValueError('The "mode" argument should be "text" or "binary".')

      async def receive_text(self) -> str:
          """
          Return the ``text`` payload of the next received message.

          Raises:
              RuntimeError: if the socket has not been accepted.
              WebSocketDisconnect: if the client disconnected.
          """
          self._require_connected()
          message = await self.receive()
          self._raise_on_disconnect(message)
          return message["text"]

      async def receive_bytes(self) -> bytes:
          """
          Return the ``bytes`` payload of the next received message.

          Raises:
              RuntimeError: if the socket has not been accepted.
              WebSocketDisconnect: if the client disconnected.
          """
          self._require_connected()
          message = await self.receive()
          self._raise_on_disconnect(message)
          return message["bytes"]

      async def receive_json(self, mode: str = "text") -> typing.Any:
          """
          Receive the next message and decode it as JSON.

          Args:
              mode: ``"text"`` reads the ``text`` payload; ``"binary"`` reads the
                  ``bytes`` payload and decodes it as UTF-8 first.

          Raises:
              ValueError: if ``mode`` is not ``"text"`` or ``"binary"``.
              RuntimeError: if the socket has not been accepted.
              WebSocketDisconnect: if the client disconnected.
          """
          self._check_mode(mode)
          self._require_connected()
          message = await self.receive()
          self._raise_on_disconnect(message)
          if mode == "text":
              text = message["text"]
          else:
              text = message["bytes"].decode("utf-8")
          return json.loads(text)

      async def iter_text(self) -> typing.AsyncIterator[str]:
          """Yield text payloads until the client disconnects, then stop cleanly."""
          try:
              while True:
                  yield await self.receive_text()
          except WebSocketDisconnect:
              # A disconnect is the normal end of the stream, not an error.
              pass

      async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
          """Yield bytes payloads until the client disconnects, then stop cleanly."""
          try:
              while True:
                  yield await self.receive_bytes()
          except WebSocketDisconnect:
              # A disconnect is the normal end of the stream, not an error.
              pass

      async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
          """Yield decoded JSON text messages until the client disconnects."""
          try:
              while True:
                  yield await self.receive_json()
          except WebSocketDisconnect:
              # A disconnect is the normal end of the stream, not an error.
              pass

      async def send_text(self, data: str) -> None:
          """Send ``data`` as a text message."""
          await self.send({"type": "websocket.send", "text": data})

      async def send_bytes(self, data: bytes) -> None:
          """Send ``data`` as a binary message."""
          await self.send({"type": "websocket.send", "bytes": data})

      async def send_json(self, data: typing.Any, mode: str = "text") -> None:
          """
          Serialize ``data`` with ``json.dumps`` and send it.

          Args:
              mode: ``"text"`` sends a text message; ``"binary"`` sends the
                  UTF-8 encoded JSON as a binary message.

          Raises:
              ValueError: if ``mode`` is not ``"text"`` or ``"binary"``.
          """
          self._check_mode(mode)
          text = json.dumps(data)
          if mode == "text":
              await self.send({"type": "websocket.send", "text": text})
          else:
              await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

      async def close(self, code: int = 1000) -> None:
          """Send a ``websocket.close`` message with the given close ``code``."""
          await self.send({"type": "websocket.close", "code": code})
  class WebSocketClose:
      """
      Minimal ASGI app that closes a websocket with a fixed close code.

      ``scope`` and ``receive`` are ignored; a single ``websocket.close``
      message is sent.
      """

      def __init__(self, code: int = 1000) -> None:
          # Close code placed in the outgoing "websocket.close" message.
          self.code = code

      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
          close_message = {"type": "websocket.close", "code": self.code}
          await send(close_message)