|
- from __future__ import annotations
-
- import inspect
- import sys
- from collections.abc import Callable, Iterable, Mapping
- from contextlib import AbstractContextManager
- from types import TracebackType
- from typing import TYPE_CHECKING, Any
-
- if sys.version_info < (3, 11):
- from ._exceptions import BaseExceptionGroup
-
- if TYPE_CHECKING:
- _Handler = Callable[[BaseExceptionGroup[Any]], Any]
-
-
- class _Catcher:
- def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]):
- self._handler_map = handler_map
-
- def __enter__(self) -> None:
- pass
-
- def __exit__(
- self,
- etype: type[BaseException] | None,
- exc: BaseException | None,
- tb: TracebackType | None,
- ) -> bool:
- if exc is not None:
- unhandled = self.handle_exception(exc)
- if unhandled is exc:
- return False
- elif unhandled is None:
- return True
- else:
- if isinstance(exc, BaseExceptionGroup):
- try:
- raise unhandled from exc.__cause__
- except BaseExceptionGroup:
- # Change __context__ to __cause__ because Python 3.11 does this
- # too
- unhandled.__context__ = exc.__cause__
- raise
-
- raise unhandled from exc
-
- return False
-
- def handle_exception(self, exc: BaseException) -> BaseException | None:
- excgroup: BaseExceptionGroup | None
- if isinstance(exc, BaseExceptionGroup):
- excgroup = exc
- else:
- excgroup = BaseExceptionGroup("", [exc])
-
- new_exceptions: list[BaseException] = []
- for exc_types, handler in self._handler_map.items():
- matched, excgroup = excgroup.split(exc_types)
- if matched:
- try:
- try:
- raise matched
- except BaseExceptionGroup:
- result = handler(matched)
- except BaseExceptionGroup as new_exc:
- if new_exc is matched:
- new_exceptions.append(new_exc)
- else:
- new_exceptions.extend(new_exc.exceptions)
- except BaseException as new_exc:
- new_exceptions.append(new_exc)
- else:
- if inspect.iscoroutine(result):
- raise TypeError(
- f"Error trying to handle {matched!r} with {handler!r}. "
- "Exception handler must be a sync function."
- ) from exc
-
- if not excgroup:
- break
-
- if new_exceptions:
- if len(new_exceptions) == 1:
- return new_exceptions[0]
-
- return BaseExceptionGroup("", new_exceptions)
- elif (
- excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
- ):
- return exc
- else:
- return excgroup
-
-
- def catch(
- __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler],
- ) -> AbstractContextManager[None]:
- if not isinstance(__handlers, Mapping):
- raise TypeError("the argument must be a mapping")
-
- handler_map: dict[
- tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]]
- ] = {}
- for type_or_iterable, handler in __handlers.items():
- iterable: tuple[type[BaseException]]
- if isinstance(type_or_iterable, type) and issubclass(
- type_or_iterable, BaseException
- ):
- iterable = (type_or_iterable,)
- elif isinstance(type_or_iterable, Iterable):
- iterable = tuple(type_or_iterable)
- else:
- raise TypeError(
- "each key must be either an exception classes or an iterable thereof"
- )
-
- if not callable(handler):
- raise TypeError("handlers must be callable")
-
- for exc_type in iterable:
- if not isinstance(exc_type, type) or not issubclass(
- exc_type, BaseException
- ):
- raise TypeError(
- "each key must be either an exception classes or an iterable "
- "thereof"
- )
-
- if issubclass(exc_type, BaseExceptionGroup):
- raise TypeError(
- "catching ExceptionGroup with catch() is not allowed. "
- "Use except instead."
- )
-
- handler_map[iterable] = handler
-
- return _Catcher(handler_map)
|