Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.
 
 
 
 

346 linhas
12 KiB

  1. from __future__ import annotations
  2. import codecs
  3. import queue
  4. import threading
  5. from typing import Any, Callable, Iterable, Iterator, Literal, overload
  6. from ..exceptions import ConcurrencyError
  7. from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
  8. from ..typing import Data
  9. from .utils import Deadline
  10. __all__ = ["Assembler"]
  11. UTF8Decoder = codecs.getincrementaldecoder("utf-8")
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames. The stream of frames must
    respect the protocol; if it doesn't, the behavior is undefined.

    Frames are buffered in a :class:`queue.SimpleQueue`. :meth:`put` adds
    frames; :meth:`get` and :meth:`get_iter` consume them one message at a
    time. Only one of :meth:`get` / :meth:`get_iter` may run at a time.

    Args:
        high: High water mark, as a number of buffered frames. When more than
            ``high`` frames are buffered, ``pause`` is called. If only ``low``
            is given, ``high`` defaults to ``low * 4``.
        low: Low water mark, as a number of buffered frames. While paused,
            ``resume`` is called once ``low`` frames or fewer are buffered.
            If only ``high`` is given, ``low`` defaults to ``high // 4``.
            When both ``high`` and ``low`` are :obj:`None`, flow control is
            disabled.
        pause: Called when the buffer of frames goes above the high water mark;
            should pause reading from the network.
        resume: Called when the buffer of frames goes below the low water mark;
            should resume reading from the network.

    Raises:
        ValueError: If the effective ``low`` is negative or ``high`` is lower
            than ``low``.

    """

    def __init__(
        self,
        high: int | None = None,
        low: int | None = None,
        pause: Callable[[], Any] = lambda: None,
        resume: Callable[[], Any] = lambda: None,
    ) -> None:
        # Serialize reads and writes -- except for reads via synchronization
        # primitives provided by the threading and queue modules.
        self.mutex = threading.Lock()

        # Queue of incoming frames. close() enqueues None as a sentinel to
        # wake up a get() or get_iter() that is blocked on the queue.
        self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue()

        # We cannot put a hard limit on the size of the queue because a single
        # call to Protocol.data_received() could produce thousands of frames,
        # which must be buffered. Instead, we pause reading when the buffer goes
        # above the high limit and we resume when it goes under the low limit.
        # When only one limit is given, derive the other one with a 4:1 ratio.
        if high is not None and low is None:
            low = high // 4
        if high is None and low is not None:
            high = low * 4
        if high is not None and low is not None:
            if low < 0:
                raise ValueError("low must be positive or equal to zero")
            if high < low:
                raise ValueError("high must be greater than or equal to low")
        # When both are None, maybe_pause() and maybe_resume() are no-ops.
        self.high, self.low = high, low
        self.pause = pause
        self.resume = resume
        # True after pause() was called and until resume() is called.
        self.paused = False

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # This flag marks the end of the connection.
        self.closed = False

    def get_next_frame(self, timeout: float | None = None) -> Frame:
        # Helper to factor out the logic for getting the next frame from the
        # queue, while handling timeouts and reaching the end of the stream.
        #
        # - timeout=None blocks until a frame is available;
        # - timeout <= 0 only returns a frame that is already queued;
        # - raises TimeoutError when no frame arrives in time (stream open);
        # - raises EOFError when the stream is closed and no frame is left,
        #   or when the None sentinel enqueued by close() is dequeued.
        if self.closed:
            # After close(), never block: only drain frames still buffered.
            try:
                frame = self.frames.get(block=False)
            except queue.Empty:
                raise EOFError("stream of frames ended") from None
        else:
            try:
                # Check for a frame that's already received if timeout <= 0.
                # SimpleQueue.get() doesn't support negative timeout values.
                if timeout is not None and timeout <= 0:
                    frame = self.frames.get(block=False)
                else:
                    frame = self.frames.get(block=True, timeout=timeout)
            except queue.Empty:
                raise TimeoutError(f"timed out in {timeout:.1f}s") from None
        if frame is None:
            raise EOFError("stream of frames ended")
        return frame

    def reset_queue(self, frames: Iterable[Frame]) -> None:
        # Helper to put frames back into the queue after they were fetched.
        # This happens only when the queue is empty. However, by the time
        # we acquire self.mutex, put() may have added items in the queue.
        # Therefore, we must handle the case where the queue is not empty.
        # Resulting order: ``frames`` first, then whatever was queued in the
        # meantime, so the partially received message stays at the front.
        # NOTE(review): maybe_pause() isn't called after re-queueing; the high
        # water mark is next re-evaluated by put().
        frame: Frame | None
        with self.mutex:
            # Drain items that were queued concurrently...
            queued = []
            try:
                while True:
                    queued.append(self.frames.get(block=False))
            except queue.Empty:
                pass
            # ...put the already fetched frames back first...
            for frame in frames:
                self.frames.put(frame)
            # ...then re-append the drained items after them.
            # This loop runs only when a race condition occurs.
            for frame in queued:  # pragma: no cover
                self.frames.put(frame)

    # This overload structure is required to avoid the error:
    # "parameter without a default follows parameter with a default"

    @overload
    def get(self, timeout: float | None, decode: Literal[True]) -> str: ...

    @overload
    def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ...

    @overload
    def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...

    @overload
    def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ...

    @overload
    def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...

    def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Args:
            timeout: If a timeout is provided and elapses before a complete
                message is received, :meth:`get` raises :exc:`TimeoutError`.
                Frames of a partially received message are then put back into
                the queue, so that a later call can return the whole message.
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`. :obj:`None` (the
                default) decodes text messages and returns binary messages
                as :class:`bytes`.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If the message is decoded and contains
                invalid UTF-8.
            ConcurrencyError: If two threads run :meth:`get` or
                :meth:`get_iter` concurrently, or if an iterator returned by
                :meth:`get_iter` wasn't fully consumed.
            TimeoutError: If a timeout is provided and elapses before a
                complete message is received.

        """
        # Check-and-set under the mutex so that two callers can't both pass.
        with self.mutex:
            if self.get_in_progress:
                raise ConcurrencyError("get() or get_iter() is already running")
            self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get() fetches a complete message or times out.

        try:
            # One deadline covers all frames of the message. Assumption: with
            # raise_if_elapsed=False, Deadline.timeout() returns the remaining
            # time (possibly <= 0) instead of raising -- see .utils.Deadline.
            deadline = Deadline(timeout)

            # First frame
            frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False))
            with self.mutex:
                self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            frames = [frame]

            # Following frames, for fragmented messages
            while not frame.fin:
                try:
                    frame = self.get_next_frame(
                        deadline.timeout(raise_if_elapsed=False)
                    )
                except TimeoutError:
                    # Put frames already received back into the queue
                    # so that future calls to get() can return them.
                    self.reset_queue(frames)
                    raise
                with self.mutex:
                    self.maybe_resume()
                assert frame.opcode is OP_CONT
                frames.append(frame)

        finally:
            # Release the "lock" on every exit path, including EOFError and
            # TimeoutError, so that get() can be called again.
            self.get_in_progress = False

        data = b"".join(frame.data for frame in frames)
        if decode:
            return data.decode()
        else:
            return data

    @overload
    def get_iter(self, decode: Literal[True]) -> Iterator[str]: ...

    @overload
    def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ...

    @overload
    def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...

    def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` yields a :class:`str` or
        :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Since this is a generator, nothing happens -- not even the concurrency
        check -- until iteration starts. There is no timeout: while the stream
        is open, each step blocks until the next frame is available.

        Args:
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`. :obj:`None` (the
                default) decodes text messages only.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If the message is decoded and contains
                invalid UTF-8.
            ConcurrencyError: If two threads run :meth:`get` or
                :meth:`get_iter` concurrently, or if a previous iterator
                wasn't fully consumed.

        """
        with self.mutex:
            if self.get_in_progress:
                raise ConcurrencyError("get() or get_iter() is already running")
            self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get_iter() fetches a complete message or times out.

        # If get_iter() raises an exception e.g. in decoder.decode(),
        # get_in_progress remains set and the connection becomes unusable.

        # First frame
        frame = self.get_next_frame()
        with self.mutex:
            self.maybe_resume()
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
        if decode is None:
            decode = frame.opcode is OP_TEXT
        if decode:
            # Incremental decoder: a multi-byte character may span frames;
            # passing frame.fin as ``final`` flushes it on the last frame.
            decoder = UTF8Decoder()
            yield decoder.decode(frame.data, frame.fin)
        else:
            yield frame.data

        # Following frames, for fragmented messages
        while not frame.fin:
            frame = self.get_next_frame()
            with self.mutex:
                self.maybe_resume()
            assert frame.opcode is OP_CONT
            if decode:
                yield decoder.decode(frame.data, frame.fin)
            else:
                yield frame.data

        # Reached only when the iterator is fully consumed; an abandoned
        # iterator leaves get_in_progress set (hence ConcurrencyError later).
        self.get_in_progress = False

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        Calls ``pause`` when the buffer goes above the high water mark.

        Raises:
            EOFError: If the stream of frames has ended.

        """
        with self.mutex:
            if self.closed:
                raise EOFError("stream of frames ended")

            self.frames.put(frame)
            self.maybe_pause()

    # put() and get/get_iter() call maybe_pause() and maybe_resume() while
    # holding self.mutex. This guarantees that the calls interleave properly.
    # Specifically, it prevents a race condition where maybe_resume() would
    # run before maybe_pause(), leaving the connection incorrectly paused.

    # A race condition is possible when get/get_iter() call self.frames.get()
    # without holding self.mutex. However, it's harmless — and even beneficial!
    # It can only result in popping an item from the queue before maybe_resume()
    # runs and skipping a pause() - resume() cycle that would otherwise occur.

    def maybe_pause(self) -> None:
        """
        Pause the writer if queue is above the high water mark.

        Calls ``pause`` at most once until the next ``resume``. No-op when
        flow control is disabled. The caller must hold ``self.mutex``.

        """
        # Skip if flow control is disabled
        if self.high is None:
            return

        assert self.mutex.locked()

        # Check for "> high" to support high = 0
        if self.frames.qsize() > self.high and not self.paused:
            self.paused = True
            self.pause()

    def maybe_resume(self) -> None:
        """
        Resume the writer if queue is below the low water mark.

        Calls ``resume`` only if currently paused. No-op when flow control is
        disabled. The caller must hold ``self.mutex``.

        """
        # Skip if flow control is disabled
        if self.low is None:
            return

        assert self.mutex.locked()

        # Check for "<= low" to support low = 0
        if self.frames.qsize() <= self.low and self.paused:
            self.paused = False
            self.resume()

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        Calling :meth:`close` more than once has no effect. Frames already
        buffered can still be read after closing; :exc:`EOFError` is raised
        once they are exhausted.

        """
        with self.mutex:
            if self.closed:
                return

            self.closed = True

            if self.get_in_progress:
                # Unblock get() or get_iter().
                # The None sentinel makes get_next_frame() raise EOFError.
                self.frames.put(None)

            if self.paused:
                # Unblock recv_events().
                # Closing lifts flow control so the reader isn't left paused.
                self.paused = False
                self.resume()
  269. return
  270. self.closed = True
  271. if self.get_in_progress:
  272. # Unblock get() or get_iter().
  273. self.frames.put(None)
  274. if self.paused:
  275. # Unblock recv_events().
  276. self.paused = False
  277. self.resume()