Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

553 řádky
15 KiB

  1. """Test utilities. Don't use outside of the uvloop project."""
  2. import asyncio
  3. import asyncio.events
  4. import collections
  5. import contextlib
  6. import gc
  7. import logging
  8. import os
  9. import pprint
  10. import re
  11. import select
  12. import socket
  13. import ssl
  14. import sys
  15. import tempfile
  16. import threading
  17. import time
  18. import unittest
  19. import uvloop
  20. class MockPattern(str):
  21. def __eq__(self, other):
  22. return bool(re.search(str(self), other, re.S))
  23. class TestCaseDict(collections.UserDict):
  24. def __init__(self, name):
  25. super().__init__()
  26. self.name = name
  27. def __setitem__(self, key, value):
  28. if key in self.data:
  29. raise RuntimeError('duplicate test {}.{}'.format(
  30. self.name, key))
  31. super().__setitem__(key, value)
  32. class BaseTestCaseMeta(type):
  33. @classmethod
  34. def __prepare__(mcls, name, bases):
  35. return TestCaseDict(name)
  36. def __new__(mcls, name, bases, dct):
  37. for test_name in dct:
  38. if not test_name.startswith('test_'):
  39. continue
  40. for base in bases:
  41. if hasattr(base, test_name):
  42. raise RuntimeError(
  43. 'duplicate test {}.{} (also defined in {} '
  44. 'parent class)'.format(
  45. name, test_name, base.__name__))
  46. return super().__new__(mcls, name, bases, dict(dct))
  47. class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta):
  48. def new_loop(self):
  49. raise NotImplementedError
  50. def new_policy(self):
  51. raise NotImplementedError
  52. def mock_pattern(self, str):
  53. return MockPattern(str)
  54. async def wait_closed(self, obj):
  55. if not isinstance(obj, asyncio.StreamWriter):
  56. return
  57. try:
  58. await obj.wait_closed()
  59. except (BrokenPipeError, ConnectionError):
  60. pass
  61. def is_asyncio_loop(self):
  62. return type(self.loop).__module__.startswith('asyncio.')
  63. def run_loop_briefly(self, *, delay=0.01):
  64. self.loop.run_until_complete(asyncio.sleep(delay))
  65. def loop_exception_handler(self, loop, context):
  66. self.__unhandled_exceptions.append(context)
  67. self.loop.default_exception_handler(context)
  68. def setUp(self):
  69. self.loop = self.new_loop()
  70. asyncio.set_event_loop_policy(self.new_policy())
  71. asyncio.set_event_loop(self.loop)
  72. self._check_unclosed_resources_in_debug = True
  73. self.loop.set_exception_handler(self.loop_exception_handler)
  74. self.__unhandled_exceptions = []
  75. def tearDown(self):
  76. self.loop.close()
  77. if self.__unhandled_exceptions:
  78. print('Unexpected calls to loop.call_exception_handler():')
  79. pprint.pprint(self.__unhandled_exceptions)
  80. self.fail('unexpected calls to loop.call_exception_handler()')
  81. return
  82. if not self._check_unclosed_resources_in_debug:
  83. return
  84. # GC to show any resource warnings as the test completes
  85. gc.collect()
  86. gc.collect()
  87. gc.collect()
  88. if getattr(self.loop, '_debug_cc', False):
  89. gc.collect()
  90. gc.collect()
  91. gc.collect()
  92. self.assertEqual(
  93. self.loop._debug_uv_handles_total,
  94. self.loop._debug_uv_handles_freed,
  95. 'not all uv_handle_t handles were freed')
  96. self.assertEqual(
  97. self.loop._debug_cb_handles_count, 0,
  98. 'not all callbacks (call_soon) are GCed')
  99. self.assertEqual(
  100. self.loop._debug_cb_timer_handles_count, 0,
  101. 'not all timer callbacks (call_later) are GCed')
  102. self.assertEqual(
  103. self.loop._debug_stream_write_ctx_cnt, 0,
  104. 'not all stream write contexts are GCed')
  105. for h_name, h_cnt in self.loop._debug_handles_current.items():
  106. with self.subTest('Alive handle after test',
  107. handle_name=h_name):
  108. self.assertEqual(
  109. h_cnt, 0,
  110. 'alive {} after test'.format(h_name))
  111. for h_name, h_cnt in self.loop._debug_handles_total.items():
  112. with self.subTest('Total/closed handles',
  113. handle_name=h_name):
  114. self.assertEqual(
  115. h_cnt, self.loop._debug_handles_closed[h_name],
  116. 'total != closed for {}'.format(h_name))
  117. asyncio.set_event_loop(None)
  118. asyncio.set_event_loop_policy(None)
  119. self.loop = None
  120. def skip_unclosed_handles_check(self):
  121. self._check_unclosed_resources_in_debug = False
  122. def tcp_server(self, server_prog, *,
  123. family=socket.AF_INET,
  124. addr=None,
  125. timeout=5,
  126. backlog=1,
  127. max_clients=10):
  128. if addr is None:
  129. if family == socket.AF_UNIX:
  130. with tempfile.NamedTemporaryFile() as tmp:
  131. addr = tmp.name
  132. else:
  133. addr = ('127.0.0.1', 0)
  134. sock = socket.socket(family, socket.SOCK_STREAM)
  135. if timeout is None:
  136. raise RuntimeError('timeout is required')
  137. if timeout <= 0:
  138. raise RuntimeError('only blocking sockets are supported')
  139. sock.settimeout(timeout)
  140. try:
  141. sock.bind(addr)
  142. sock.listen(backlog)
  143. except OSError as ex:
  144. sock.close()
  145. raise ex
  146. return TestThreadedServer(
  147. self, sock, server_prog, timeout, max_clients)
  148. def tcp_client(self, client_prog,
  149. family=socket.AF_INET,
  150. timeout=10):
  151. sock = socket.socket(family, socket.SOCK_STREAM)
  152. if timeout is None:
  153. raise RuntimeError('timeout is required')
  154. if timeout <= 0:
  155. raise RuntimeError('only blocking sockets are supported')
  156. sock.settimeout(timeout)
  157. return TestThreadedClient(
  158. self, sock, client_prog, timeout)
  159. def unix_server(self, *args, **kwargs):
  160. return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
  161. def unix_client(self, *args, **kwargs):
  162. return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
  163. @contextlib.contextmanager
  164. def unix_sock_name(self):
  165. with tempfile.TemporaryDirectory() as td:
  166. fn = os.path.join(td, 'sock')
  167. try:
  168. yield fn
  169. finally:
  170. try:
  171. os.unlink(fn)
  172. except OSError:
  173. pass
  174. def _abort_socket_test(self, ex):
  175. try:
  176. self.loop.stop()
  177. finally:
  178. self.fail(ex)
  179. def _cert_fullname(test_file_name, cert_file_name):
  180. fullname = os.path.abspath(os.path.join(
  181. os.path.dirname(test_file_name), 'certs', cert_file_name))
  182. assert os.path.isfile(fullname)
  183. return fullname
  184. @contextlib.contextmanager
  185. def silence_long_exec_warning():
  186. class Filter(logging.Filter):
  187. def filter(self, record):
  188. return not (record.msg.startswith('Executing') and
  189. record.msg.endswith('seconds'))
  190. logger = logging.getLogger('asyncio')
  191. filter = Filter()
  192. logger.addFilter(filter)
  193. try:
  194. yield
  195. finally:
  196. logger.removeFilter(filter)
  197. def find_free_port(start_from=50000):
  198. for port in range(start_from, start_from + 500):
  199. sock = socket.socket()
  200. with sock:
  201. try:
  202. sock.bind(('', port))
  203. except socket.error:
  204. continue
  205. else:
  206. return port
  207. raise RuntimeError('could not find a free port')
  208. class SSLTestCase:
  209. def _create_server_ssl_context(self, certfile, keyfile=None):
  210. if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
  211. sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  212. elif hasattr(ssl, 'PROTOCOL_TLS'):
  213. sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS)
  214. else:
  215. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
  216. sslcontext.options |= ssl.OP_NO_SSLv2
  217. sslcontext.load_cert_chain(certfile, keyfile)
  218. return sslcontext
  219. def _create_client_ssl_context(self, *, disable_verify=True):
  220. sslcontext = ssl.create_default_context()
  221. sslcontext.check_hostname = False
  222. if disable_verify:
  223. sslcontext.verify_mode = ssl.CERT_NONE
  224. return sslcontext
  225. @contextlib.contextmanager
  226. def _silence_eof_received_warning(self):
  227. # TODO This warning has to be fixed in asyncio.
  228. logger = logging.getLogger('asyncio')
  229. filter = logging.Filter('has no effect when using ssl')
  230. logger.addFilter(filter)
  231. try:
  232. yield
  233. finally:
  234. logger.removeFilter(filter)
  235. class UVTestCase(BaseTestCase):
  236. implementation = 'uvloop'
  237. def new_loop(self):
  238. return uvloop.new_event_loop()
  239. def new_policy(self):
  240. return uvloop.EventLoopPolicy()
  241. class AIOTestCase(BaseTestCase):
  242. implementation = 'asyncio'
  243. def setUp(self):
  244. super().setUp()
  245. if sys.version_info < (3, 12):
  246. watcher = asyncio.SafeChildWatcher()
  247. watcher.attach_loop(self.loop)
  248. asyncio.set_child_watcher(watcher)
  249. def tearDown(self):
  250. if sys.version_info < (3, 12):
  251. asyncio.set_child_watcher(None)
  252. super().tearDown()
  253. def new_loop(self):
  254. return asyncio.new_event_loop()
  255. def new_policy(self):
  256. return asyncio.DefaultEventLoopPolicy()
  257. def has_IPv6():
  258. server_sock = socket.socket(socket.AF_INET6)
  259. with server_sock:
  260. try:
  261. server_sock.bind(('::1', 0))
  262. except OSError:
  263. return False
  264. else:
  265. return True
  266. has_IPv6 = has_IPv6()
  267. ###############################################################################
  268. # Socket Testing Utilities
  269. ###############################################################################
  270. class TestSocketWrapper:
  271. def __init__(self, sock):
  272. self.__sock = sock
  273. def recv_all(self, n):
  274. buf = b''
  275. while len(buf) < n:
  276. data = self.recv(n - len(buf))
  277. if data == b'':
  278. raise ConnectionAbortedError
  279. buf += data
  280. return buf
  281. def starttls(self, ssl_context, *,
  282. server_side=False,
  283. server_hostname=None,
  284. do_handshake_on_connect=True):
  285. assert isinstance(ssl_context, ssl.SSLContext)
  286. ssl_sock = ssl_context.wrap_socket(
  287. self.__sock, server_side=server_side,
  288. server_hostname=server_hostname,
  289. do_handshake_on_connect=do_handshake_on_connect)
  290. if server_side:
  291. ssl_sock.do_handshake()
  292. self.__sock.close()
  293. self.__sock = ssl_sock
  294. def __getattr__(self, name):
  295. return getattr(self.__sock, name)
  296. def __repr__(self):
  297. return '<{} {!r}>'.format(type(self).__name__, self.__sock)
  298. class SocketThread(threading.Thread):
  299. def stop(self):
  300. self._active = False
  301. self.join()
  302. def __enter__(self):
  303. self.start()
  304. return self
  305. def __exit__(self, *exc):
  306. self.stop()
  307. class TestThreadedClient(SocketThread):
  308. def __init__(self, test, sock, prog, timeout):
  309. threading.Thread.__init__(self, None, None, 'test-client')
  310. self.daemon = True
  311. self._timeout = timeout
  312. self._sock = sock
  313. self._active = True
  314. self._prog = prog
  315. self._test = test
  316. def run(self):
  317. try:
  318. self._prog(TestSocketWrapper(self._sock))
  319. except (KeyboardInterrupt, SystemExit):
  320. raise
  321. except BaseException as ex:
  322. self._test._abort_socket_test(ex)
  323. class TestThreadedServer(SocketThread):
  324. def __init__(self, test, sock, prog, timeout, max_clients):
  325. threading.Thread.__init__(self, None, None, 'test-server')
  326. self.daemon = True
  327. self._clients = 0
  328. self._finished_clients = 0
  329. self._max_clients = max_clients
  330. self._timeout = timeout
  331. self._sock = sock
  332. self._active = True
  333. self._prog = prog
  334. self._s1, self._s2 = socket.socketpair()
  335. self._s1.setblocking(False)
  336. self._test = test
  337. def stop(self):
  338. try:
  339. if self._s2 and self._s2.fileno() != -1:
  340. try:
  341. self._s2.send(b'stop')
  342. except OSError:
  343. pass
  344. finally:
  345. super().stop()
  346. def run(self):
  347. try:
  348. with self._sock:
  349. self._sock.setblocking(0)
  350. self._run()
  351. finally:
  352. self._s1.close()
  353. self._s2.close()
  354. def _run(self):
  355. while self._active:
  356. if self._clients >= self._max_clients:
  357. return
  358. r, w, x = select.select(
  359. [self._sock, self._s1], [], [], self._timeout)
  360. if self._s1 in r:
  361. return
  362. if self._sock in r:
  363. try:
  364. conn, addr = self._sock.accept()
  365. except BlockingIOError:
  366. continue
  367. except socket.timeout:
  368. if not self._active:
  369. return
  370. else:
  371. raise
  372. else:
  373. self._clients += 1
  374. conn.settimeout(self._timeout)
  375. try:
  376. with conn:
  377. self._handle_client(conn)
  378. except (KeyboardInterrupt, SystemExit):
  379. raise
  380. except BaseException as ex:
  381. self._active = False
  382. try:
  383. raise
  384. finally:
  385. self._test._abort_socket_test(ex)
  386. def _handle_client(self, sock):
  387. self._prog(TestSocketWrapper(sock))
  388. @property
  389. def addr(self):
  390. return self._sock.getsockname()
  391. ###############################################################################
  392. # A few helpers from asyncio/tests/testutils.py
  393. ###############################################################################
  394. def run_briefly(loop):
  395. async def once():
  396. pass
  397. gen = once()
  398. t = loop.create_task(gen)
  399. # Don't log a warning if the task is not done after run_until_complete().
  400. # It occurs if the loop is stopped or if a task raises a BaseException.
  401. t._log_destroy_pending = False
  402. try:
  403. loop.run_until_complete(t)
  404. finally:
  405. gen.close()
  406. def run_until(loop, pred, timeout=30):
  407. deadline = time.time() + timeout
  408. while not pred():
  409. if timeout is not None:
  410. timeout = deadline - time.time()
  411. if timeout <= 0:
  412. raise asyncio.futures.TimeoutError()
  413. loop.run_until_complete(asyncio.tasks.sleep(0.001))
  414. @contextlib.contextmanager
  415. def disable_logger():
  416. """Context manager to disable asyncio logger.
  417. For example, it can be used to ignore warnings in debug mode.
  418. """
  419. old_level = asyncio.log.logger.level
  420. try:
  421. asyncio.log.logger.setLevel(logging.CRITICAL + 1)
  422. yield
  423. finally:
  424. asyncio.log.logger.setLevel(old_level)