You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

282 regels
12 KiB

  1. import logging
  2. import re
  3. import ssl
  4. from dataclasses import dataclass
  5. from functools import wraps
  6. from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
  7. from .. import BrokenResourceError, EndOfStream, aclose_forcefully, get_cancelled_exc_class
  8. from .._core._typedattr import TypedAttributeSet, typed_attribute
  9. from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
  10. T_Retval = TypeVar('T_Retval')
  11. _PCTRTT = Tuple[Tuple[str, str], ...]
  12. _PCTRTTT = Tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    Each class attribute below is a typed-attribute key created with :func:`typed_attribute`.
    ``TLSStream.extra_attributes`` maps every one of these keys to a callable that produces the
    current value for a given stream.
    """

    #: the selected ALPN protocol (annotated ``Optional``, so it may be ``None``)
    alpn_protocol: Optional[str] = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher, as a 3-tuple of ``(str, str, int)``
    cipher: Tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` for more
    #: information)
    peer_certificate: Optional[Dict[str, Union[str, _PCTRTTT, _PCTRTT]]] = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: Optional[bytes] = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared between both ends of the TLS connection
    shared_ciphers: List[Tuple[str, str, int]] = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
    #: closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
  37. @dataclass(eq=False)
  38. class TLSStream(ByteStream):
  39. """
  40. A stream wrapper that encrypts all sent data and decrypts received data.
  41. This class has no public initializer; use :meth:`wrap` instead.
  42. All extra attributes from :class:`~TLSAttribute` are supported.
  43. :var AnyByteStream transport_stream: the wrapped stream
  44. """
  45. transport_stream: AnyByteStream
  46. standard_compatible: bool
  47. _ssl_object: ssl.SSLObject
  48. _read_bio: ssl.MemoryBIO
  49. _write_bio: ssl.MemoryBIO
  50. @classmethod
  51. async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bool] = None,
  52. hostname: Optional[str] = None, ssl_context: Optional[ssl.SSLContext] = None,
  53. standard_compatible: bool = True) -> 'TLSStream':
  54. """
  55. Wrap an existing stream with Transport Layer Security.
  56. This performs a TLS handshake with the peer.
  57. :param transport_stream: a bytes-transporting stream to wrap
  58. :param server_side: ``True`` if this is the server side of the connection, ``False`` if
  59. this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
  60. provided, ``False`` otherwise). Used only to create a default context when an explicit
  61. context has not been provided.
  62. :param hostname: host name of the peer (if host name checking is desired)
  63. :param ssl_context: the SSLContext object to use (if not provided, a secure default will be
  64. created)
  65. :param standard_compatible: if ``False``, skip the closing handshake when closing the
  66. connection, and don't raise an exception if the peer does the same
  67. :raises ~ssl.SSLError: if the TLS handshake fails
  68. """
  69. if server_side is None:
  70. server_side = not hostname
  71. if not ssl_context:
  72. purpose = ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
  73. ssl_context = ssl.create_default_context(purpose)
  74. # Re-enable detection of unexpected EOFs if it was disabled by Python
  75. if hasattr(ssl, 'OP_IGNORE_UNEXPECTED_EOF'):
  76. ssl_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined]
  77. bio_in = ssl.MemoryBIO()
  78. bio_out = ssl.MemoryBIO()
  79. ssl_object = ssl_context.wrap_bio(bio_in, bio_out, server_side=server_side,
  80. server_hostname=hostname)
  81. wrapper = cls(transport_stream=transport_stream,
  82. standard_compatible=standard_compatible, _ssl_object=ssl_object,
  83. _read_bio=bio_in, _write_bio=bio_out)
  84. await wrapper._call_sslobject_method(ssl_object.do_handshake)
  85. return wrapper
  86. async def _call_sslobject_method(
  87. self, func: Callable[..., T_Retval], *args: object
  88. ) -> T_Retval:
  89. while True:
  90. try:
  91. result = func(*args)
  92. except ssl.SSLWantReadError:
  93. try:
  94. # Flush any pending writes first
  95. if self._write_bio.pending:
  96. await self.transport_stream.send(self._write_bio.read())
  97. data = await self.transport_stream.receive()
  98. except EndOfStream:
  99. self._read_bio.write_eof()
  100. except OSError as exc:
  101. self._read_bio.write_eof()
  102. self._write_bio.write_eof()
  103. raise BrokenResourceError from exc
  104. else:
  105. self._read_bio.write(data)
  106. except ssl.SSLWantWriteError:
  107. await self.transport_stream.send(self._write_bio.read())
  108. except ssl.SSLSyscallError as exc:
  109. self._read_bio.write_eof()
  110. self._write_bio.write_eof()
  111. raise BrokenResourceError from exc
  112. except ssl.SSLError as exc:
  113. self._read_bio.write_eof()
  114. self._write_bio.write_eof()
  115. if (isinstance(exc, ssl.SSLEOFError)
  116. or 'UNEXPECTED_EOF_WHILE_READING' in exc.strerror):
  117. if self.standard_compatible:
  118. raise BrokenResourceError from exc
  119. else:
  120. raise EndOfStream from None
  121. raise
  122. else:
  123. # Flush any pending writes first
  124. if self._write_bio.pending:
  125. await self.transport_stream.send(self._write_bio.read())
  126. return result
  127. async def unwrap(self) -> Tuple[AnyByteStream, bytes]:
  128. """
  129. Does the TLS closing handshake.
  130. :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
  131. """
  132. await self._call_sslobject_method(self._ssl_object.unwrap)
  133. self._read_bio.write_eof()
  134. self._write_bio.write_eof()
  135. return self.transport_stream, self._read_bio.read()
  136. async def aclose(self) -> None:
  137. if self.standard_compatible:
  138. try:
  139. await self.unwrap()
  140. except BaseException:
  141. await aclose_forcefully(self.transport_stream)
  142. raise
  143. await self.transport_stream.aclose()
  144. async def receive(self, max_bytes: int = 65536) -> bytes:
  145. data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
  146. if not data:
  147. raise EndOfStream
  148. return data
  149. async def send(self, item: bytes) -> None:
  150. await self._call_sslobject_method(self._ssl_object.write, item)
  151. async def send_eof(self) -> None:
  152. tls_version = self.extra(TLSAttribute.tls_version)
  153. match = re.match(r'TLSv(\d+)(?:\.(\d+))?', tls_version)
  154. if match:
  155. major, minor = int(match.group(1)), int(match.group(2) or 0)
  156. if (major, minor) < (1, 3):
  157. raise NotImplementedError(f'send_eof() requires at least TLSv1.3; current '
  158. f'session uses {tls_version}')
  159. raise NotImplementedError('send_eof() has not yet been implemented for TLS streams')
  160. @property
  161. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  162. return {
  163. **self.transport_stream.extra_attributes,
  164. TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
  165. TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
  166. TLSAttribute.cipher: self._ssl_object.cipher,
  167. TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
  168. TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(True),
  169. TLSAttribute.server_side: lambda: self._ssl_object.server_side,
  170. TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
  171. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  172. TLSAttribute.ssl_object: lambda: self._ssl_object,
  173. TLSAttribute.tls_version: self._ssl_object.version
  174. }
  175. @dataclass(eq=False)
  176. class TLSListener(Listener[TLSStream]):
  177. """
  178. A convenience listener that wraps another listener and auto-negotiates a TLS session on every
  179. accepted connection.
  180. If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
  181. called to do whatever post-mortem processing is deemed necessary.
  182. Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
  183. :param Listener listener: the listener to wrap
  184. :param ssl_context: the SSL context object
  185. :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
  186. :param handshake_timeout: time limit for the TLS handshake
  187. (passed to :func:`~anyio.fail_after`)
  188. """
  189. listener: Listener[Any]
  190. ssl_context: ssl.SSLContext
  191. standard_compatible: bool = True
  192. handshake_timeout: float = 30
  193. @staticmethod
  194. async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
  195. f"""
  196. Handle an exception raised during the TLS handshake.
  197. This method does 3 things:
  198. #. Forcefully closes the original stream
  199. #. Logs the exception (unless it was a cancellation exception) using the ``{__name__}``
  200. logger
  201. #. Reraises the exception if it was a base exception or a cancellation exception
  202. :param exc: the exception
  203. :param stream: the original stream
  204. """
  205. await aclose_forcefully(stream)
  206. # Log all except cancellation exceptions
  207. if not isinstance(exc, get_cancelled_exc_class()):
  208. logging.getLogger(__name__).exception('Error during TLS handshake')
  209. # Only reraise base exceptions and cancellation exceptions
  210. if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
  211. raise
  212. async def serve(self, handler: Callable[[TLSStream], Any],
  213. task_group: Optional[TaskGroup] = None) -> None:
  214. @wraps(handler)
  215. async def handler_wrapper(stream: AnyByteStream) -> None:
  216. from .. import fail_after
  217. try:
  218. with fail_after(self.handshake_timeout):
  219. wrapped_stream = await TLSStream.wrap(
  220. stream, ssl_context=self.ssl_context,
  221. standard_compatible=self.standard_compatible)
  222. except BaseException as exc:
  223. await self.handle_handshake_error(exc, stream)
  224. else:
  225. await handler(wrapped_stream)
  226. await self.listener.serve(handler_wrapper, task_group)
  227. async def aclose(self) -> None:
  228. await self.listener.aclose()
  229. @property
  230. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  231. return {
  232. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  233. }