You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

161 lines
4.5 KiB

  1. import enum
  2. import sys
  3. from dataclasses import dataclass
  4. from typing import Any, Dict, Generic, Set, TypeVar, Union, overload
  5. from weakref import WeakKeyDictionary
  6. from ._core._eventloop import get_asynclib
  7. if sys.version_info >= (3, 8):
  8. from typing import Literal
  9. else:
  10. from typing_extensions import Literal
  # Type of the value held by a run variable (parameterizes RunVar and RunvarToken).
  T = TypeVar('T')
  # Type of the caller-supplied fallback accepted by ``RunVar.get()``.
  D = TypeVar('D')
  async def checkpoint() -> None:
      """
      Check whether this task has been cancelled, then let the scheduler run another task.

      The effect is the same as the following two calls, but this is more efficient::

          await checkpoint_if_cancelled()
          await cancel_shielded_checkpoint()

      .. versionadded:: 3.0

      """
      # Delegate to whichever async backend is currently running
      backend = get_asynclib()
      await backend.checkpoint()
  async def checkpoint_if_cancelled() -> None:
      """
      Enter a checkpoint only when the enclosing cancel scope has been cancelled.

      The scheduler is not given the opportunity to switch to a different task.

      .. versionadded:: 3.0

      """
      # Delegate to whichever async backend is currently running
      backend = get_asynclib()
      await backend.checkpoint_if_cancelled()
  async def cancel_shielded_checkpoint() -> None:
      """
      Let the scheduler switch to another task, but do not check for cancellation.

      The effect is the same as the following, but this is potentially more efficient::

          with CancelScope(shield=True):
              await checkpoint()

      .. versionadded:: 3.0

      """
      # Delegate to whichever async backend is currently running
      backend = get_asynclib()
      await backend.cancel_shielded_checkpoint()
  def current_token() -> object:
      """Fetch the backend specific token object that can be used to get back to the event loop."""
      backend = get_asynclib()
      return backend.current_token()
  # Storage for run variables: maps an event loop token (or a _TokenWrapper standing in for
  # one) to a dict of variable name -> value. Keys are weakly referenced.
  _run_vars = WeakKeyDictionary()  # type: WeakKeyDictionary[Any, Dict[str, Any]]
  # NOTE(review): nothing in the visible code reads this dict (RunVar keeps its own class-level
  # ``_token_wrappers`` set) -- confirm whether it is still used elsewhere.
  _token_wrappers: Dict[Any, '_TokenWrapper'] = {}
  @dataclass(frozen=True)
  class _TokenWrapper:
      """Weak-referable stand-in for a token that cannot itself be weakly referenced."""

      # ``__weakref__`` must be listed explicitly: a class with ``__slots__`` has no weak
      # reference support otherwise, and instances are used as WeakKeyDictionary keys.
      __slots__ = '_token', '__weakref__'

      # The wrapped backend token; equality and hashing of the wrapper are based on it
      _token: object
  class _NoValueSet(enum.Enum):
      """Single-member enum whose member serves as the "no value" sentinel for run variables."""

      NO_VALUE_SET = enum.auto()
  class RunvarToken(Generic[T]):
      """
      Record of the state a run variable had before a ``set()`` call.

      Stores the owning variable, the value it had previously (or the "no value" sentinel) and
      whether the token has already been used.
      """

      __slots__ = '_var', '_value', '_redeemed'

      def __init__(self, var: 'RunVar[T]', value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]):
          # A freshly created token has not been used yet
          self._redeemed = False
          self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value
          self._var = var
  class RunVar(Generic[T]):
      """Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop."""

      __slots__ = '_name', '_default'

      NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET

      # Strong references to the wrappers created for tokens that are not weak referable;
      # shared by all RunVar instances.
      _token_wrappers: Set[_TokenWrapper] = set()

      def __init__(self, name: str,
                   default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET):
          """
          Create a run variable.

          :param name: key under which the value is stored in the per-event-loop dict
          :param default: value returned by :meth:`get` when nothing has been set and the
              caller passes no fallback of its own
          """
          self._name = name
          self._default = default

      @property
      def _current_vars(self) -> Dict[str, T]:
          """Return the variable dict of the current event loop, creating it when missing."""
          key = current_token()
          while True:
              try:
                  run_vars = _run_vars[key]
              except KeyError:
                  # First access for this event loop: start with an empty mapping
                  run_vars = _run_vars[key] = {}
              except TypeError:
                  # Happens when the token isn't weak referable (TrioToken).
                  # This workaround does mean that some memory will leak on Trio until the
                  # problem is fixed on their end.
                  key = _TokenWrapper(key)
                  self._token_wrappers.add(key)
                  continue

              return run_vars

      @overload
      def get(self, default: D) -> Union[T, D]: ...

      @overload
      def get(self) -> T: ...

      def get(
          self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET
      ) -> Union[T, D]:
          """
          Return the value of this variable in the current event loop.

          When no value has been set, fall back first to ``default`` and then to the default
          given to the constructor.

          :raises LookupError: if there is no value and neither fallback was provided
          """
          try:
              return self._current_vars[self._name]
          except KeyError:
              pass

          # The fallback given to this call takes precedence over the constructor's default
          for fallback in (default, self._default):
              if fallback is not RunVar.NO_VALUE_SET:
                  return fallback

          raise LookupError(f'Run variable "{self._name}" has no value and no default set')

      def set(self, value: T) -> RunvarToken[T]:
          """
          Set the value of this variable in the current event loop.

          :return: a token that can be passed to :meth:`reset` to restore the previous state
          """
          run_vars = self._current_vars
          previous = run_vars.get(self._name, RunVar.NO_VALUE_SET)
          run_vars[self._name] = value
          return RunvarToken(self, previous)

      def reset(self, token: RunvarToken[T]) -> None:
          """
          Restore the state this variable had before the ``set()`` call that produced ``token``.

          :raises ValueError: if the token came from another variable or was already used
          """
          if token._var is not self:
              raise ValueError('This token does not belong to this RunVar')

          if token._redeemed:
              raise ValueError('This token has already been used')

          if token._value is not _NoValueSet.NO_VALUE_SET:
              self._current_vars[self._name] = token._value
          else:
              # There was no value before set(): remove the entry, if it is still there
              try:
                  del self._current_vars[self._name]
              except KeyError:
                  pass

          token._redeemed = True

      def __repr__(self) -> str:
          return f'<RunVar name={self._name!r}>'