Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.
 
 
 
 

538 lignes
15 KiB

  1. """Test utilities. Don't use outside of the uvloop project."""
  2. import asyncio
  3. import asyncio.events
  4. import collections
  5. import contextlib
  6. import gc
  7. import logging
  8. import os
  9. import pprint
  10. import re
  11. import select
  12. import socket
  13. import ssl
  14. import sys
  15. import tempfile
  16. import threading
  17. import time
  18. import unittest
  19. import uvloop
  20. class MockPattern(str):
  21. def __eq__(self, other):
  22. return bool(re.search(str(self), other, re.S))
  23. class TestCaseDict(collections.UserDict):
  24. def __init__(self, name):
  25. super().__init__()
  26. self.name = name
  27. def __setitem__(self, key, value):
  28. if key in self.data:
  29. raise RuntimeError('duplicate test {}.{}'.format(
  30. self.name, key))
  31. super().__setitem__(key, value)
  32. class BaseTestCaseMeta(type):
  33. @classmethod
  34. def __prepare__(mcls, name, bases):
  35. return TestCaseDict(name)
  36. def __new__(mcls, name, bases, dct):
  37. for test_name in dct:
  38. if not test_name.startswith('test_'):
  39. continue
  40. for base in bases:
  41. if hasattr(base, test_name):
  42. raise RuntimeError(
  43. 'duplicate test {}.{} (also defined in {} '
  44. 'parent class)'.format(
  45. name, test_name, base.__name__))
  46. return super().__new__(mcls, name, bases, dict(dct))
  47. class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta):
  48. def new_loop(self):
  49. raise NotImplementedError
  50. def mock_pattern(self, str):
  51. return MockPattern(str)
  52. def has_start_serving(self):
  53. return not (self.is_asyncio_loop() and
  54. sys.version_info[:2] in [(3, 5), (3, 6)])
  55. def is_asyncio_loop(self):
  56. return type(self.loop).__module__.startswith('asyncio.')
  57. def run_loop_briefly(self, *, delay=0.01):
  58. self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop))
  59. def loop_exception_handler(self, loop, context):
  60. self.__unhandled_exceptions.append(context)
  61. self.loop.default_exception_handler(context)
  62. def setUp(self):
  63. self.loop = self.new_loop()
  64. asyncio.set_event_loop(None)
  65. self._check_unclosed_resources_in_debug = True
  66. self.loop.set_exception_handler(self.loop_exception_handler)
  67. self.__unhandled_exceptions = []
  68. if hasattr(asyncio, '_get_running_loop'):
  69. # Disable `_get_running_loop`.
  70. self._get_running_loop = asyncio.events._get_running_loop
  71. asyncio.events._get_running_loop = lambda: None
  72. self.PY37 = sys.version_info[:2] >= (3, 7)
  73. self.PY36 = sys.version_info[:2] >= (3, 6)
  74. def tearDown(self):
  75. self.loop.close()
  76. if self.__unhandled_exceptions:
  77. print('Unexpected calls to loop.call_exception_handler():')
  78. pprint.pprint(self.__unhandled_exceptions)
  79. self.fail('unexpected calls to loop.call_exception_handler()')
  80. return
  81. if hasattr(asyncio, '_get_running_loop'):
  82. asyncio.events._get_running_loop = self._get_running_loop
  83. if not self._check_unclosed_resources_in_debug:
  84. return
  85. # GC to show any resource warnings as the test completes
  86. gc.collect()
  87. gc.collect()
  88. gc.collect()
  89. if getattr(self.loop, '_debug_cc', False):
  90. gc.collect()
  91. gc.collect()
  92. gc.collect()
  93. self.assertEqual(
  94. self.loop._debug_uv_handles_total,
  95. self.loop._debug_uv_handles_freed,
  96. 'not all uv_handle_t handles were freed')
  97. self.assertEqual(
  98. self.loop._debug_cb_handles_count, 0,
  99. 'not all callbacks (call_soon) are GCed')
  100. self.assertEqual(
  101. self.loop._debug_cb_timer_handles_count, 0,
  102. 'not all timer callbacks (call_later) are GCed')
  103. self.assertEqual(
  104. self.loop._debug_stream_write_ctx_cnt, 0,
  105. 'not all stream write contexts are GCed')
  106. for h_name, h_cnt in self.loop._debug_handles_current.items():
  107. with self.subTest('Alive handle after test',
  108. handle_name=h_name):
  109. self.assertEqual(
  110. h_cnt, 0,
  111. 'alive {} after test'.format(h_name))
  112. for h_name, h_cnt in self.loop._debug_handles_total.items():
  113. with self.subTest('Total/closed handles',
  114. handle_name=h_name):
  115. self.assertEqual(
  116. h_cnt, self.loop._debug_handles_closed[h_name],
  117. 'total != closed for {}'.format(h_name))
  118. asyncio.set_event_loop(None)
  119. self.loop = None
  120. def skip_unclosed_handles_check(self):
  121. self._check_unclosed_resources_in_debug = False
  122. def tcp_server(self, server_prog, *,
  123. family=socket.AF_INET,
  124. addr=None,
  125. timeout=5,
  126. backlog=1,
  127. max_clients=10):
  128. if addr is None:
  129. if family == socket.AF_UNIX:
  130. with tempfile.NamedTemporaryFile() as tmp:
  131. addr = tmp.name
  132. else:
  133. addr = ('127.0.0.1', 0)
  134. sock = socket.socket(family, socket.SOCK_STREAM)
  135. if timeout is None:
  136. raise RuntimeError('timeout is required')
  137. if timeout <= 0:
  138. raise RuntimeError('only blocking sockets are supported')
  139. sock.settimeout(timeout)
  140. try:
  141. sock.bind(addr)
  142. sock.listen(backlog)
  143. except OSError as ex:
  144. sock.close()
  145. raise ex
  146. return TestThreadedServer(
  147. self, sock, server_prog, timeout, max_clients)
  148. def tcp_client(self, client_prog,
  149. family=socket.AF_INET,
  150. timeout=10):
  151. sock = socket.socket(family, socket.SOCK_STREAM)
  152. if timeout is None:
  153. raise RuntimeError('timeout is required')
  154. if timeout <= 0:
  155. raise RuntimeError('only blocking sockets are supported')
  156. sock.settimeout(timeout)
  157. return TestThreadedClient(
  158. self, sock, client_prog, timeout)
  159. def unix_server(self, *args, **kwargs):
  160. return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
  161. def unix_client(self, *args, **kwargs):
  162. return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
  163. @contextlib.contextmanager
  164. def unix_sock_name(self):
  165. with tempfile.TemporaryDirectory() as td:
  166. fn = os.path.join(td, 'sock')
  167. try:
  168. yield fn
  169. finally:
  170. try:
  171. os.unlink(fn)
  172. except OSError:
  173. pass
  174. def _abort_socket_test(self, ex):
  175. try:
  176. self.loop.stop()
  177. finally:
  178. self.fail(ex)
  179. def _cert_fullname(test_file_name, cert_file_name):
  180. fullname = os.path.abspath(os.path.join(
  181. os.path.dirname(test_file_name), 'certs', cert_file_name))
  182. assert os.path.isfile(fullname)
  183. return fullname
  184. @contextlib.contextmanager
  185. def silence_long_exec_warning():
  186. class Filter(logging.Filter):
  187. def filter(self, record):
  188. return not (record.msg.startswith('Executing') and
  189. record.msg.endswith('seconds'))
  190. logger = logging.getLogger('asyncio')
  191. filter = Filter()
  192. logger.addFilter(filter)
  193. try:
  194. yield
  195. finally:
  196. logger.removeFilter(filter)
  197. def find_free_port(start_from=50000):
  198. for port in range(start_from, start_from + 500):
  199. sock = socket.socket()
  200. with sock:
  201. try:
  202. sock.bind(('', port))
  203. except socket.error:
  204. continue
  205. else:
  206. return port
  207. raise RuntimeError('could not find a free port')
  208. class SSLTestCase:
  209. def _create_server_ssl_context(self, certfile, keyfile=None):
  210. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
  211. sslcontext.options |= ssl.OP_NO_SSLv2
  212. sslcontext.load_cert_chain(certfile, keyfile)
  213. return sslcontext
  214. def _create_client_ssl_context(self, *, disable_verify=True):
  215. sslcontext = ssl.create_default_context()
  216. sslcontext.check_hostname = False
  217. if disable_verify:
  218. sslcontext.verify_mode = ssl.CERT_NONE
  219. return sslcontext
  220. @contextlib.contextmanager
  221. def _silence_eof_received_warning(self):
  222. # TODO This warning has to be fixed in asyncio.
  223. logger = logging.getLogger('asyncio')
  224. filter = logging.Filter('has no effect when using ssl')
  225. logger.addFilter(filter)
  226. try:
  227. yield
  228. finally:
  229. logger.removeFilter(filter)
  230. class UVTestCase(BaseTestCase):
  231. implementation = 'uvloop'
  232. def new_loop(self):
  233. return uvloop.new_event_loop()
  234. class AIOTestCase(BaseTestCase):
  235. implementation = 'asyncio'
  236. def setUp(self):
  237. super().setUp()
  238. watcher = asyncio.SafeChildWatcher()
  239. watcher.attach_loop(self.loop)
  240. asyncio.set_child_watcher(watcher)
  241. def tearDown(self):
  242. asyncio.set_child_watcher(None)
  243. super().tearDown()
  244. def new_loop(self):
  245. return asyncio.new_event_loop()
  246. def has_IPv6():
  247. server_sock = socket.socket(socket.AF_INET6)
  248. with server_sock:
  249. try:
  250. server_sock.bind(('::1', 0))
  251. except OSError:
  252. return False
  253. else:
  254. return True
  255. has_IPv6 = has_IPv6()
  256. ###############################################################################
  257. # Socket Testing Utilities
  258. ###############################################################################
  259. class TestSocketWrapper:
  260. def __init__(self, sock):
  261. self.__sock = sock
  262. def recv_all(self, n):
  263. buf = b''
  264. while len(buf) < n:
  265. data = self.recv(n - len(buf))
  266. if data == b'':
  267. raise ConnectionAbortedError
  268. buf += data
  269. return buf
  270. def starttls(self, ssl_context, *,
  271. server_side=False,
  272. server_hostname=None,
  273. do_handshake_on_connect=True):
  274. assert isinstance(ssl_context, ssl.SSLContext)
  275. ssl_sock = ssl_context.wrap_socket(
  276. self.__sock, server_side=server_side,
  277. server_hostname=server_hostname,
  278. do_handshake_on_connect=do_handshake_on_connect)
  279. if server_side:
  280. ssl_sock.do_handshake()
  281. self.__sock.close()
  282. self.__sock = ssl_sock
  283. def __getattr__(self, name):
  284. return getattr(self.__sock, name)
  285. def __repr__(self):
  286. return '<{} {!r}>'.format(type(self).__name__, self.__sock)
  287. class SocketThread(threading.Thread):
  288. def stop(self):
  289. self._active = False
  290. self.join()
  291. def __enter__(self):
  292. self.start()
  293. return self
  294. def __exit__(self, *exc):
  295. self.stop()
  296. class TestThreadedClient(SocketThread):
  297. def __init__(self, test, sock, prog, timeout):
  298. threading.Thread.__init__(self, None, None, 'test-client')
  299. self.daemon = True
  300. self._timeout = timeout
  301. self._sock = sock
  302. self._active = True
  303. self._prog = prog
  304. self._test = test
  305. def run(self):
  306. try:
  307. self._prog(TestSocketWrapper(self._sock))
  308. except Exception as ex:
  309. self._test._abort_socket_test(ex)
  310. class TestThreadedServer(SocketThread):
  311. def __init__(self, test, sock, prog, timeout, max_clients):
  312. threading.Thread.__init__(self, None, None, 'test-server')
  313. self.daemon = True
  314. self._clients = 0
  315. self._finished_clients = 0
  316. self._max_clients = max_clients
  317. self._timeout = timeout
  318. self._sock = sock
  319. self._active = True
  320. self._prog = prog
  321. self._s1, self._s2 = socket.socketpair()
  322. self._s1.setblocking(False)
  323. self._test = test
  324. def stop(self):
  325. try:
  326. if self._s2 and self._s2.fileno() != -1:
  327. try:
  328. self._s2.send(b'stop')
  329. except OSError:
  330. pass
  331. finally:
  332. super().stop()
  333. def run(self):
  334. try:
  335. with self._sock:
  336. self._sock.setblocking(0)
  337. self._run()
  338. finally:
  339. self._s1.close()
  340. self._s2.close()
  341. def _run(self):
  342. while self._active:
  343. if self._clients >= self._max_clients:
  344. return
  345. r, w, x = select.select(
  346. [self._sock, self._s1], [], [], self._timeout)
  347. if self._s1 in r:
  348. return
  349. if self._sock in r:
  350. try:
  351. conn, addr = self._sock.accept()
  352. except BlockingIOError:
  353. continue
  354. except socket.timeout:
  355. if not self._active:
  356. return
  357. else:
  358. raise
  359. else:
  360. self._clients += 1
  361. conn.settimeout(self._timeout)
  362. try:
  363. with conn:
  364. self._handle_client(conn)
  365. except Exception as ex:
  366. self._active = False
  367. try:
  368. raise
  369. finally:
  370. self._test._abort_socket_test(ex)
  371. def _handle_client(self, sock):
  372. self._prog(TestSocketWrapper(sock))
  373. @property
  374. def addr(self):
  375. return self._sock.getsockname()
  376. ###############################################################################
  377. # A few helpers from asyncio/tests/testutils.py
  378. ###############################################################################
  379. def run_briefly(loop):
  380. async def once():
  381. pass
  382. gen = once()
  383. t = loop.create_task(gen)
  384. # Don't log a warning if the task is not done after run_until_complete().
  385. # It occurs if the loop is stopped or if a task raises a BaseException.
  386. t._log_destroy_pending = False
  387. try:
  388. loop.run_until_complete(t)
  389. finally:
  390. gen.close()
  391. def run_until(loop, pred, timeout=30):
  392. deadline = time.time() + timeout
  393. while not pred():
  394. if timeout is not None:
  395. timeout = deadline - time.time()
  396. if timeout <= 0:
  397. raise asyncio.futures.TimeoutError()
  398. loop.run_until_complete(asyncio.tasks.sleep(0.001, loop=loop))
  399. @contextlib.contextmanager
  400. def disable_logger():
  401. """Context manager to disable asyncio logger.
  402. For example, it can be used to ignore warnings in debug mode.
  403. """
  404. old_level = asyncio.log.logger.level
  405. try:
  406. asyncio.log.logger.setLevel(logging.CRITICAL + 1)
  407. yield
  408. finally:
  409. asyncio.log.logger.setLevel(old_level)