You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

589 lines
21 KiB

"""
:mod:`websockets.extensions.permessage_deflate` implements the Compression
Extensions for WebSocket as specified in :rfc:`7692`.

It provides the :class:`PerMessageDeflate` extension, which compresses
outgoing data frames and decompresses incoming ones. It also provides
client-side and server-side factories that negotiate the extension's
parameters during the opening handshake.

"""
import zlib
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from ..exceptions import (
    DuplicateParameter,
    InvalidParameterName,
    InvalidParameterValue,
    NegotiationError,
    PayloadTooBig,
    ProtocolError,
)
from ..framing import CTRL_OPCODES, OP_CONT, Frame
from ..typing import ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
# Public API of this module: the extension and its two negotiation factories.
__all__ = [
    "PerMessageDeflate",
    "ClientPerMessageDeflateFactory",
    "ServerPerMessageDeflateFactory",
]
  22. _EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
  23. _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
  24. class PerMessageDeflate(Extension):
  25. """
  26. Per-Message Deflate extension.
  27. """
  28. name = "permessage-deflate"
  29. def __init__(
  30. self,
  31. remote_no_context_takeover: bool,
  32. local_no_context_takeover: bool,
  33. remote_max_window_bits: int,
  34. local_max_window_bits: int,
  35. compress_settings: Optional[Dict[Any, Any]] = None,
  36. ) -> None:
  37. """
  38. Configure the Per-Message Deflate extension.
  39. """
  40. if compress_settings is None:
  41. compress_settings = {}
  42. assert remote_no_context_takeover in [False, True]
  43. assert local_no_context_takeover in [False, True]
  44. assert 8 <= remote_max_window_bits <= 15
  45. assert 8 <= local_max_window_bits <= 15
  46. assert "wbits" not in compress_settings
  47. self.remote_no_context_takeover = remote_no_context_takeover
  48. self.local_no_context_takeover = local_no_context_takeover
  49. self.remote_max_window_bits = remote_max_window_bits
  50. self.local_max_window_bits = local_max_window_bits
  51. self.compress_settings = compress_settings
  52. if not self.remote_no_context_takeover:
  53. self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
  54. if not self.local_no_context_takeover:
  55. self.encoder = zlib.compressobj(
  56. wbits=-self.local_max_window_bits, **self.compress_settings
  57. )
  58. # To handle continuation frames properly, we must keep track of
  59. # whether that initial frame was encoded.
  60. self.decode_cont_data = False
  61. # There's no need for self.encode_cont_data because we always encode
  62. # outgoing frames, so it would always be True.
  63. def __repr__(self) -> str:
  64. return (
  65. f"PerMessageDeflate("
  66. f"remote_no_context_takeover={self.remote_no_context_takeover}, "
  67. f"local_no_context_takeover={self.local_no_context_takeover}, "
  68. f"remote_max_window_bits={self.remote_max_window_bits}, "
  69. f"local_max_window_bits={self.local_max_window_bits})"
  70. )
  71. def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame:
  72. """
  73. Decode an incoming frame.
  74. """
  75. # Skip control frames.
  76. if frame.opcode in CTRL_OPCODES:
  77. return frame
  78. # Handle continuation data frames:
  79. # - skip if the initial data frame wasn't encoded
  80. # - reset "decode continuation data" flag if it's a final frame
  81. if frame.opcode == OP_CONT:
  82. if not self.decode_cont_data:
  83. return frame
  84. if frame.fin:
  85. self.decode_cont_data = False
  86. # Handle text and binary data frames:
  87. # - skip if the frame isn't encoded
  88. # - set "decode continuation data" flag if it's a non-final frame
  89. else:
  90. if not frame.rsv1:
  91. return frame
  92. if not frame.fin: # frame.rsv1 is True at this point
  93. self.decode_cont_data = True
  94. # Re-initialize per-message decoder.
  95. if self.remote_no_context_takeover:
  96. self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
  97. # Uncompress compressed frames. Protect against zip bombs by
  98. # preventing zlib from decompressing more than max_length bytes
  99. # (except when the limit is disabled with max_size = None).
  100. data = frame.data
  101. if frame.fin:
  102. data += _EMPTY_UNCOMPRESSED_BLOCK
  103. max_length = 0 if max_size is None else max_size
  104. data = self.decoder.decompress(data, max_length)
  105. if self.decoder.unconsumed_tail:
  106. raise PayloadTooBig(
  107. f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)"
  108. )
  109. # Allow garbage collection of the decoder if it won't be reused.
  110. if frame.fin and self.remote_no_context_takeover:
  111. del self.decoder
  112. return frame._replace(data=data, rsv1=False)
  113. def encode(self, frame: Frame) -> Frame:
  114. """
  115. Encode an outgoing frame.
  116. """
  117. # Skip control frames.
  118. if frame.opcode in CTRL_OPCODES:
  119. return frame
  120. # Since we always encode and never fragment messages, there's no logic
  121. # similar to decode() here at this time.
  122. if frame.opcode != OP_CONT:
  123. # Re-initialize per-message decoder.
  124. if self.local_no_context_takeover:
  125. self.encoder = zlib.compressobj(
  126. wbits=-self.local_max_window_bits, **self.compress_settings
  127. )
  128. # Compress data frames.
  129. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
  130. if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
  131. data = data[:-4]
  132. # Allow garbage collection of the encoder if it won't be reused.
  133. if frame.fin and self.local_no_context_takeover:
  134. del self.encoder
  135. return frame._replace(data=data, rsv1=True)
  136. def _build_parameters(
  137. server_no_context_takeover: bool,
  138. client_no_context_takeover: bool,
  139. server_max_window_bits: Optional[int],
  140. client_max_window_bits: Optional[Union[int, bool]],
  141. ) -> List[ExtensionParameter]:
  142. """
  143. Build a list of ``(name, value)`` pairs for some compression parameters.
  144. """
  145. params: List[ExtensionParameter] = []
  146. if server_no_context_takeover:
  147. params.append(("server_no_context_takeover", None))
  148. if client_no_context_takeover:
  149. params.append(("client_no_context_takeover", None))
  150. if server_max_window_bits:
  151. params.append(("server_max_window_bits", str(server_max_window_bits)))
  152. if client_max_window_bits is True: # only in handshake requests
  153. params.append(("client_max_window_bits", None))
  154. elif client_max_window_bits:
  155. params.append(("client_max_window_bits", str(client_max_window_bits)))
  156. return params
  157. def _extract_parameters(
  158. params: Sequence[ExtensionParameter], *, is_server: bool
  159. ) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]:
  160. """
  161. Extract compression parameters from a list of ``(name, value)`` pairs.
  162. If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided
  163. without a value. This is only allow in handshake requests.
  164. """
  165. server_no_context_takeover: bool = False
  166. client_no_context_takeover: bool = False
  167. server_max_window_bits: Optional[int] = None
  168. client_max_window_bits: Optional[Union[int, bool]] = None
  169. for name, value in params:
  170. if name == "server_no_context_takeover":
  171. if server_no_context_takeover:
  172. raise DuplicateParameter(name)
  173. if value is None:
  174. server_no_context_takeover = True
  175. else:
  176. raise InvalidParameterValue(name, value)
  177. elif name == "client_no_context_takeover":
  178. if client_no_context_takeover:
  179. raise DuplicateParameter(name)
  180. if value is None:
  181. client_no_context_takeover = True
  182. else:
  183. raise InvalidParameterValue(name, value)
  184. elif name == "server_max_window_bits":
  185. if server_max_window_bits is not None:
  186. raise DuplicateParameter(name)
  187. if value in _MAX_WINDOW_BITS_VALUES:
  188. server_max_window_bits = int(value)
  189. else:
  190. raise InvalidParameterValue(name, value)
  191. elif name == "client_max_window_bits":
  192. if client_max_window_bits is not None:
  193. raise DuplicateParameter(name)
  194. if is_server and value is None: # only in handshake requests
  195. client_max_window_bits = True
  196. elif value in _MAX_WINDOW_BITS_VALUES:
  197. client_max_window_bits = int(value)
  198. else:
  199. raise InvalidParameterValue(name, value)
  200. else:
  201. raise InvalidParameterName(name)
  202. return (
  203. server_no_context_takeover,
  204. client_no_context_takeover,
  205. server_max_window_bits,
  206. client_max_window_bits,
  207. )
  208. class ClientPerMessageDeflateFactory(ClientExtensionFactory):
  209. """
  210. Client-side extension factory for the Per-Message Deflate extension.
  211. Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
  212. ``True`` to include them in the negotiation offer without a value or to an
  213. integer value to include them with this value.
  214. .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
  215. :param server_no_context_takeover: defaults to ``False``
  216. :param client_no_context_takeover: defaults to ``False``
  217. :param server_max_window_bits: optional, defaults to ``None``
  218. :param client_max_window_bits: optional, defaults to ``None``
  219. :param compress_settings: optional, keyword arguments for
  220. :func:`zlib.compressobj`, excluding ``wbits``
  221. """
  222. name = "permessage-deflate"
  223. def __init__(
  224. self,
  225. server_no_context_takeover: bool = False,
  226. client_no_context_takeover: bool = False,
  227. server_max_window_bits: Optional[int] = None,
  228. client_max_window_bits: Optional[Union[int, bool]] = None,
  229. compress_settings: Optional[Dict[str, Any]] = None,
  230. ) -> None:
  231. """
  232. Configure the Per-Message Deflate extension factory.
  233. """
  234. if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
  235. raise ValueError("server_max_window_bits must be between 8 and 15")
  236. if not (
  237. client_max_window_bits is None
  238. or client_max_window_bits is True
  239. or 8 <= client_max_window_bits <= 15
  240. ):
  241. raise ValueError("client_max_window_bits must be between 8 and 15")
  242. if compress_settings is not None and "wbits" in compress_settings:
  243. raise ValueError(
  244. "compress_settings must not include wbits, "
  245. "set client_max_window_bits instead"
  246. )
  247. self.server_no_context_takeover = server_no_context_takeover
  248. self.client_no_context_takeover = client_no_context_takeover
  249. self.server_max_window_bits = server_max_window_bits
  250. self.client_max_window_bits = client_max_window_bits
  251. self.compress_settings = compress_settings
  252. def get_request_params(self) -> List[ExtensionParameter]:
  253. """
  254. Build request parameters.
  255. """
  256. return _build_parameters(
  257. self.server_no_context_takeover,
  258. self.client_no_context_takeover,
  259. self.server_max_window_bits,
  260. self.client_max_window_bits,
  261. )
  262. def process_response_params(
  263. self,
  264. params: Sequence[ExtensionParameter],
  265. accepted_extensions: Sequence["Extension"],
  266. ) -> PerMessageDeflate:
  267. """
  268. Process response parameters.
  269. Return an extension instance.
  270. """
  271. if any(other.name == self.name for other in accepted_extensions):
  272. raise NegotiationError(f"received duplicate {self.name}")
  273. # Request parameters are available in instance variables.
  274. # Load response parameters in local variables.
  275. (
  276. server_no_context_takeover,
  277. client_no_context_takeover,
  278. server_max_window_bits,
  279. client_max_window_bits,
  280. ) = _extract_parameters(params, is_server=False)
  281. # After comparing the request and the response, the final
  282. # configuration must be available in the local variables.
  283. # server_no_context_takeover
  284. #
  285. # Req. Resp. Result
  286. # ------ ------ --------------------------------------------------
  287. # False False False
  288. # False True True
  289. # True False Error!
  290. # True True True
  291. if self.server_no_context_takeover:
  292. if not server_no_context_takeover:
  293. raise NegotiationError("expected server_no_context_takeover")
  294. # client_no_context_takeover
  295. #
  296. # Req. Resp. Result
  297. # ------ ------ --------------------------------------------------
  298. # False False False
  299. # False True True
  300. # True False True - must change value
  301. # True True True
  302. if self.client_no_context_takeover:
  303. if not client_no_context_takeover:
  304. client_no_context_takeover = True
  305. # server_max_window_bits
  306. # Req. Resp. Result
  307. # ------ ------ --------------------------------------------------
  308. # None None None
  309. # None 8≤M≤15 M
  310. # 8≤N≤15 None Error!
  311. # 8≤N≤15 8≤M≤N M
  312. # 8≤N≤15 N<M≤15 Error!
  313. if self.server_max_window_bits is None:
  314. pass
  315. else:
  316. if server_max_window_bits is None:
  317. raise NegotiationError("expected server_max_window_bits")
  318. elif server_max_window_bits > self.server_max_window_bits:
  319. raise NegotiationError("unsupported server_max_window_bits")
  320. # client_max_window_bits
  321. # Req. Resp. Result
  322. # ------ ------ --------------------------------------------------
  323. # None None None
  324. # None 8≤M≤15 Error!
  325. # True None None
  326. # True 8≤M≤15 M
  327. # 8≤N≤15 None N - must change value
  328. # 8≤N≤15 8≤M≤N M
  329. # 8≤N≤15 N<M≤15 Error!
  330. if self.client_max_window_bits is None:
  331. if client_max_window_bits is not None:
  332. raise NegotiationError("unexpected client_max_window_bits")
  333. elif self.client_max_window_bits is True:
  334. pass
  335. else:
  336. if client_max_window_bits is None:
  337. client_max_window_bits = self.client_max_window_bits
  338. elif client_max_window_bits > self.client_max_window_bits:
  339. raise NegotiationError("unsupported client_max_window_bits")
  340. return PerMessageDeflate(
  341. server_no_context_takeover, # remote_no_context_takeover
  342. client_no_context_takeover, # local_no_context_takeover
  343. server_max_window_bits or 15, # remote_max_window_bits
  344. client_max_window_bits or 15, # local_max_window_bits
  345. self.compress_settings,
  346. )
  347. class ServerPerMessageDeflateFactory(ServerExtensionFactory):
  348. """
  349. Server-side extension factory for the Per-Message Deflate extension.
  350. Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
  351. ``True`` to include them in the negotiation offer without a value or to an
  352. integer value to include them with this value.
  353. .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
  354. :param server_no_context_takeover: defaults to ``False``
  355. :param client_no_context_takeover: defaults to ``False``
  356. :param server_max_window_bits: optional, defaults to ``None``
  357. :param client_max_window_bits: optional, defaults to ``None``
  358. :param compress_settings: optional, keyword arguments for
  359. :func:`zlib.compressobj`, excluding ``wbits``
  360. """
  361. name = "permessage-deflate"
  362. def __init__(
  363. self,
  364. server_no_context_takeover: bool = False,
  365. client_no_context_takeover: bool = False,
  366. server_max_window_bits: Optional[int] = None,
  367. client_max_window_bits: Optional[int] = None,
  368. compress_settings: Optional[Dict[str, Any]] = None,
  369. ) -> None:
  370. """
  371. Configure the Per-Message Deflate extension factory.
  372. """
  373. if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
  374. raise ValueError("server_max_window_bits must be between 8 and 15")
  375. if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
  376. raise ValueError("client_max_window_bits must be between 8 and 15")
  377. if compress_settings is not None and "wbits" in compress_settings:
  378. raise ValueError(
  379. "compress_settings must not include wbits, "
  380. "set server_max_window_bits instead"
  381. )
  382. self.server_no_context_takeover = server_no_context_takeover
  383. self.client_no_context_takeover = client_no_context_takeover
  384. self.server_max_window_bits = server_max_window_bits
  385. self.client_max_window_bits = client_max_window_bits
  386. self.compress_settings = compress_settings
  387. def process_request_params(
  388. self,
  389. params: Sequence[ExtensionParameter],
  390. accepted_extensions: Sequence["Extension"],
  391. ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]:
  392. """
  393. Process request parameters.
  394. Return response params and an extension instance.
  395. """
  396. if any(other.name == self.name for other in accepted_extensions):
  397. raise NegotiationError(f"skipped duplicate {self.name}")
  398. # Load request parameters in local variables.
  399. (
  400. server_no_context_takeover,
  401. client_no_context_takeover,
  402. server_max_window_bits,
  403. client_max_window_bits,
  404. ) = _extract_parameters(params, is_server=True)
  405. # Configuration parameters are available in instance variables.
  406. # After comparing the request and the configuration, the response must
  407. # be available in the local variables.
  408. # server_no_context_takeover
  409. #
  410. # Config Req. Resp.
  411. # ------ ------ --------------------------------------------------
  412. # False False False
  413. # False True True
  414. # True False True - must change value to True
  415. # True True True
  416. if self.server_no_context_takeover:
  417. if not server_no_context_takeover:
  418. server_no_context_takeover = True
  419. # client_no_context_takeover
  420. #
  421. # Config Req. Resp.
  422. # ------ ------ --------------------------------------------------
  423. # False False False
  424. # False True True (or False)
  425. # True False True - must change value to True
  426. # True True True (or False)
  427. if self.client_no_context_takeover:
  428. if not client_no_context_takeover:
  429. client_no_context_takeover = True
  430. # server_max_window_bits
  431. # Config Req. Resp.
  432. # ------ ------ --------------------------------------------------
  433. # None None None
  434. # None 8≤M≤15 M
  435. # 8≤N≤15 None N - must change value
  436. # 8≤N≤15 8≤M≤N M
  437. # 8≤N≤15 N<M≤15 N - must change value
  438. if self.server_max_window_bits is None:
  439. pass
  440. else:
  441. if server_max_window_bits is None:
  442. server_max_window_bits = self.server_max_window_bits
  443. elif server_max_window_bits > self.server_max_window_bits:
  444. server_max_window_bits = self.server_max_window_bits
  445. # client_max_window_bits
  446. # Config Req. Resp.
  447. # ------ ------ --------------------------------------------------
  448. # None None None
  449. # None True None - must change value
  450. # None 8≤M≤15 M (or None)
  451. # 8≤N≤15 None Error!
  452. # 8≤N≤15 True N - must change value
  453. # 8≤N≤15 8≤M≤N M (or None)
  454. # 8≤N≤15 N<M≤15 N
  455. if self.client_max_window_bits is None:
  456. if client_max_window_bits is True:
  457. client_max_window_bits = self.client_max_window_bits
  458. else:
  459. if client_max_window_bits is None:
  460. raise NegotiationError("required client_max_window_bits")
  461. elif client_max_window_bits is True:
  462. client_max_window_bits = self.client_max_window_bits
  463. elif self.client_max_window_bits < client_max_window_bits:
  464. client_max_window_bits = self.client_max_window_bits
  465. return (
  466. _build_parameters(
  467. server_no_context_takeover,
  468. client_no_context_takeover,
  469. server_max_window_bits,
  470. client_max_window_bits,
  471. ),
  472. PerMessageDeflate(
  473. client_no_context_takeover, # remote_no_context_takeover
  474. server_no_context_takeover, # local_no_context_takeover
  475. client_max_window_bits or 15, # remote_max_window_bits
  476. server_max_window_bits or 15, # local_max_window_bits
  477. self.compress_settings,
  478. ),
  479. )