Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

139 řádky
4.6 KiB

  1. from __future__ import annotations
  2. import inspect
  3. import sys
  4. from collections.abc import Callable, Iterable, Mapping
  5. from contextlib import AbstractContextManager
  6. from types import TracebackType
  7. from typing import TYPE_CHECKING, Any
  8. if sys.version_info < (3, 11):
  9. from ._exceptions import BaseExceptionGroup
if TYPE_CHECKING:
    # Alias that exists only for static type checkers (it is never bound at
    # runtime): a handler is a callable that receives one BaseExceptionGroup
    # argument; its return type is left unconstrained (Any).
    _Handler = Callable[[BaseExceptionGroup[Any]], Any]
  12. class _Catcher:
  13. def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]):
  14. self._handler_map = handler_map
  15. def __enter__(self) -> None:
  16. pass
  17. def __exit__(
  18. self,
  19. etype: type[BaseException] | None,
  20. exc: BaseException | None,
  21. tb: TracebackType | None,
  22. ) -> bool:
  23. if exc is not None:
  24. unhandled = self.handle_exception(exc)
  25. if unhandled is exc:
  26. return False
  27. elif unhandled is None:
  28. return True
  29. else:
  30. if isinstance(exc, BaseExceptionGroup):
  31. try:
  32. raise unhandled from exc.__cause__
  33. except BaseExceptionGroup:
  34. # Change __context__ to __cause__ because Python 3.11 does this
  35. # too
  36. unhandled.__context__ = exc.__cause__
  37. raise
  38. raise unhandled from exc
  39. return False
  40. def handle_exception(self, exc: BaseException) -> BaseException | None:
  41. excgroup: BaseExceptionGroup | None
  42. if isinstance(exc, BaseExceptionGroup):
  43. excgroup = exc
  44. else:
  45. excgroup = BaseExceptionGroup("", [exc])
  46. new_exceptions: list[BaseException] = []
  47. for exc_types, handler in self._handler_map.items():
  48. matched, excgroup = excgroup.split(exc_types)
  49. if matched:
  50. try:
  51. try:
  52. raise matched
  53. except BaseExceptionGroup:
  54. result = handler(matched)
  55. except BaseExceptionGroup as new_exc:
  56. if new_exc is matched:
  57. new_exceptions.append(new_exc)
  58. else:
  59. new_exceptions.extend(new_exc.exceptions)
  60. except BaseException as new_exc:
  61. new_exceptions.append(new_exc)
  62. else:
  63. if inspect.iscoroutine(result):
  64. raise TypeError(
  65. f"Error trying to handle {matched!r} with {handler!r}. "
  66. "Exception handler must be a sync function."
  67. ) from exc
  68. if not excgroup:
  69. break
  70. if new_exceptions:
  71. if len(new_exceptions) == 1:
  72. return new_exceptions[0]
  73. return BaseExceptionGroup("", new_exceptions)
  74. elif (
  75. excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
  76. ):
  77. return exc
  78. else:
  79. return excgroup
  80. def catch(
  81. __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler],
  82. ) -> AbstractContextManager[None]:
  83. if not isinstance(__handlers, Mapping):
  84. raise TypeError("the argument must be a mapping")
  85. handler_map: dict[
  86. tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]]
  87. ] = {}
  88. for type_or_iterable, handler in __handlers.items():
  89. iterable: tuple[type[BaseException]]
  90. if isinstance(type_or_iterable, type) and issubclass(
  91. type_or_iterable, BaseException
  92. ):
  93. iterable = (type_or_iterable,)
  94. elif isinstance(type_or_iterable, Iterable):
  95. iterable = tuple(type_or_iterable)
  96. else:
  97. raise TypeError(
  98. "each key must be either an exception classes or an iterable thereof"
  99. )
  100. if not callable(handler):
  101. raise TypeError("handlers must be callable")
  102. for exc_type in iterable:
  103. if not isinstance(exc_type, type) or not issubclass(
  104. exc_type, BaseException
  105. ):
  106. raise TypeError(
  107. "each key must be either an exception classes or an iterable "
  108. "thereof"
  109. )
  110. if issubclass(exc_type, BaseExceptionGroup):
  111. raise TypeError(
  112. "catching ExceptionGroup with catch() is not allowed. "
  113. "Use except instead."
  114. )
  115. handler_map[iterable] = handler
  116. return _Catcher(handler_map)