Non puoi selezionare più di 25 argomenti. Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 

151 righe
5.5 KiB

  1. import enum
  2. import json
  3. import typing
  4. from starlette.requests import HTTPConnection
  5. from starlette.types import Message, Receive, Scope, Send
  6. class WebSocketState(enum.Enum):
  7. CONNECTING = 0
  8. CONNECTED = 1
  9. DISCONNECTED = 2
  10. class WebSocketDisconnect(Exception):
  11. def __init__(self, code: int = 1000) -> None:
  12. self.code = code
  13. class WebSocket(HTTPConnection):
  14. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  15. super().__init__(scope)
  16. assert scope["type"] == "websocket"
  17. self._receive = receive
  18. self._send = send
  19. self.client_state = WebSocketState.CONNECTING
  20. self.application_state = WebSocketState.CONNECTING
  21. async def receive(self) -> Message:
  22. """
  23. Receive ASGI websocket messages, ensuring valid state transitions.
  24. """
  25. if self.client_state == WebSocketState.CONNECTING:
  26. message = await self._receive()
  27. message_type = message["type"]
  28. assert message_type == "websocket.connect"
  29. self.client_state = WebSocketState.CONNECTED
  30. return message
  31. elif self.client_state == WebSocketState.CONNECTED:
  32. message = await self._receive()
  33. message_type = message["type"]
  34. assert message_type in {"websocket.receive", "websocket.disconnect"}
  35. if message_type == "websocket.disconnect":
  36. self.client_state = WebSocketState.DISCONNECTED
  37. return message
  38. else:
  39. raise RuntimeError(
  40. 'Cannot call "receive" once a disconnect message has been received.'
  41. )
  42. async def send(self, message: Message) -> None:
  43. """
  44. Send ASGI websocket messages, ensuring valid state transitions.
  45. """
  46. if self.application_state == WebSocketState.CONNECTING:
  47. message_type = message["type"]
  48. assert message_type in {"websocket.accept", "websocket.close"}
  49. if message_type == "websocket.close":
  50. self.application_state = WebSocketState.DISCONNECTED
  51. else:
  52. self.application_state = WebSocketState.CONNECTED
  53. await self._send(message)
  54. elif self.application_state == WebSocketState.CONNECTED:
  55. message_type = message["type"]
  56. assert message_type in {"websocket.send", "websocket.close"}
  57. if message_type == "websocket.close":
  58. self.application_state = WebSocketState.DISCONNECTED
  59. await self._send(message)
  60. else:
  61. raise RuntimeError('Cannot call "send" once a close message has been sent.')
  62. async def accept(self, subprotocol: str = None) -> None:
  63. if self.client_state == WebSocketState.CONNECTING:
  64. # If we haven't yet seen the 'connect' message, then wait for it first.
  65. await self.receive()
  66. await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
  67. def _raise_on_disconnect(self, message: Message) -> None:
  68. if message["type"] == "websocket.disconnect":
  69. raise WebSocketDisconnect(message["code"])
  70. async def receive_text(self) -> str:
  71. assert self.application_state == WebSocketState.CONNECTED
  72. message = await self.receive()
  73. self._raise_on_disconnect(message)
  74. return message["text"]
  75. async def receive_bytes(self) -> bytes:
  76. assert self.application_state == WebSocketState.CONNECTED
  77. message = await self.receive()
  78. self._raise_on_disconnect(message)
  79. return message["bytes"]
  80. async def receive_json(self, mode: str = "text") -> typing.Any:
  81. assert mode in ["text", "binary"]
  82. assert self.application_state == WebSocketState.CONNECTED
  83. message = await self.receive()
  84. self._raise_on_disconnect(message)
  85. if mode == "text":
  86. text = message["text"]
  87. else:
  88. text = message["bytes"].decode("utf-8")
  89. return json.loads(text)
  90. async def iter_text(self) -> typing.AsyncIterator[str]:
  91. try:
  92. while True:
  93. yield await self.receive_text()
  94. except WebSocketDisconnect:
  95. pass
  96. async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
  97. try:
  98. while True:
  99. yield await self.receive_bytes()
  100. except WebSocketDisconnect:
  101. pass
  102. async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
  103. try:
  104. while True:
  105. yield await self.receive_json()
  106. except WebSocketDisconnect:
  107. pass
  108. async def send_text(self, data: str) -> None:
  109. await self.send({"type": "websocket.send", "text": data})
  110. async def send_bytes(self, data: bytes) -> None:
  111. await self.send({"type": "websocket.send", "bytes": data})
  112. async def send_json(self, data: typing.Any, mode: str = "text") -> None:
  113. assert mode in ["text", "binary"]
  114. text = json.dumps(data)
  115. if mode == "text":
  116. await self.send({"type": "websocket.send", "text": text})
  117. else:
  118. await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
  119. async def close(self, code: int = 1000) -> None:
  120. await self.send({"type": "websocket.close", "code": code})
  121. class WebSocketClose:
  122. def __init__(self, code: int = 1000) -> None:
  123. self.code = code
  124. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  125. await send({"type": "websocket.close", "code": self.code})