|
- from __future__ import annotations
-
- import enum
- import json
- from collections.abc import AsyncIterator, Iterable
- from typing import Any, cast
-
- from starlette.requests import HTTPConnection
- from starlette.responses import Response
- from starlette.types import Message, Receive, Scope, Send
-
-
class WebSocketState(enum.Enum):
    """
    Lifecycle phases of one side of a WebSocket connection.

    `WebSocket` keeps two independent values of this enum: `client_state`
    (advanced by messages that are received) and `application_state`
    (advanced by messages that are sent).
    """

    # Initial state of both sides; the handshake has not completed yet.
    CONNECTING = 0
    # "websocket.connect" was received / "websocket.accept" was sent.
    CONNECTED = 1
    # A disconnect was received, a close was sent, or a denial response ended.
    DISCONNECTED = 2
    # "websocket.http.response.start" was sent; only response body messages
    # may follow (application side only).
    RESPONSE = 3
-
-
class WebSocketDisconnect(Exception):
    """
    Raised when the WebSocket connection has been closed.

    Attributes:
        code: The WebSocket close code (defaults to 1000).
        reason: The close reason; `None` is normalised to an empty string.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason or ""
        # Forward the details to `Exception` so that `args`, `str()` and
        # `repr()` carry the close code and reason instead of being empty.
        super().__init__(self.code, self.reason)
-
-
class WebSocket(HTTPConnection):
    """
    An ASGI WebSocket connection that validates message ordering.

    Two independent state machines are tracked with `WebSocketState`:
    `client_state` follows the messages obtained through `receive()`, and
    `application_state` follows the messages passed to `send()`.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not allowed in the current
                client state, or if a disconnect has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connection request.
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not allowed in the current
                application state, or if the connection was already closed.
            WebSocketDisconnect: With code 1006 if sending on an accepted
                connection fails with `OSError`.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # Translate the failed send into a disconnect, keeping the
                # original error attached as the explicit cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            # The denial response is finished once no further body follows.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the connection by sending a "websocket.accept" message.

        Args:
            subprotocol: Value sent under the "subprotocol" key.
            headers: Value sent under the "headers" key; a falsy value is
                replaced by an empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:  # pragma: no branch
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """
        Raise `WebSocketDisconnect` if `message` is a "websocket.disconnect".

        A missing "code" key falls back to 1005, the close code RFC 6455
        reserves for "no status code was present", rather than surfacing as
        a `KeyError`.
        """
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive one message and return its "text" payload.

        Raises:
            RuntimeError: If the connection has not been accepted.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its "bytes" payload.

        Raises:
            RuntimeError: If the connection has not been accepted.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> Any:
        """
        Receive one message and decode its payload as JSON.

        Args:
            mode: "text" to read the "text" payload, or "binary" to read the
                "bytes" payload and decode it as UTF-8.

        Raises:
            RuntimeError: If `mode` is invalid or the connection has not
                been accepted.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """
        Yield text payloads until the connection disconnects.
        """
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Yield bytes payloads until the connection disconnects.
        """
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_json(self) -> AsyncIterator[Any]:
        """
        Yield JSON-decoded payloads (text mode) until the connection disconnects.
        """
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """
        Send `data` as the "text" payload of a "websocket.send" message.
        """
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """
        Send `data` as the "bytes" payload of a "websocket.send" message.
        """
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """
        Serialise `data` as compact JSON and send it.

        Args:
            data: The object to serialise with `json.dumps`.
            mode: "text" to send the JSON as a "text" payload, or "binary"
                to send it UTF-8 encoded as a "bytes" payload.

        Raises:
            RuntimeError: If `mode` is neither "text" nor "binary".
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """
        Send a "websocket.close" message; a `None` reason is sent as "".
        """
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending `response` over this connection.

        Raises:
            RuntimeError: If the scope's "extensions" do not include
                "websocket.http.response".
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
-
-
class WebSocketClose:
    """
    ASGI application that closes a WebSocket with a fixed code and reason.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # A missing (or empty) reason is stored as an empty string.
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Send a single "websocket.close" message built from the stored values.
        """
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)
|