Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.
 
 
 
 

145 wierszy
4.4 KiB

  1. import asyncio
  2. import functools
  3. import inspect
  4. import typing
  5. from starlette.exceptions import HTTPException
  6. from starlette.requests import HTTPConnection, Request
  7. from starlette.responses import RedirectResponse, Response
  8. from starlette.websockets import WebSocket
  9. def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
  10. for scope in scopes:
  11. if scope not in conn.auth.scopes:
  12. return False
  13. return True
  14. def requires(
  15. scopes: typing.Union[str, typing.Sequence[str]],
  16. status_code: int = 403,
  17. redirect: str = None,
  18. ) -> typing.Callable:
  19. scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
  20. def decorator(func: typing.Callable) -> typing.Callable:
  21. sig = inspect.signature(func)
  22. for idx, parameter in enumerate(sig.parameters.values()):
  23. if parameter.name == "request" or parameter.name == "websocket":
  24. type_ = parameter.name
  25. break
  26. else:
  27. raise Exception(
  28. f'No "request" or "websocket" argument on function "{func}"'
  29. )
  30. if type_ == "websocket":
  31. # Handle websocket functions. (Always async)
  32. @functools.wraps(func)
  33. async def websocket_wrapper(
  34. *args: typing.Any, **kwargs: typing.Any
  35. ) -> None:
  36. websocket = kwargs.get(
  37. "websocket", args[idx] if idx < len(args) else None
  38. )
  39. assert isinstance(websocket, WebSocket)
  40. if not has_required_scope(websocket, scopes_list):
  41. await websocket.close()
  42. else:
  43. await func(*args, **kwargs)
  44. return websocket_wrapper
  45. elif asyncio.iscoroutinefunction(func):
  46. # Handle async request/response functions.
  47. @functools.wraps(func)
  48. async def async_wrapper(
  49. *args: typing.Any, **kwargs: typing.Any
  50. ) -> Response:
  51. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  52. assert isinstance(request, Request)
  53. if not has_required_scope(request, scopes_list):
  54. if redirect is not None:
  55. return RedirectResponse(
  56. url=request.url_for(redirect), status_code=303
  57. )
  58. raise HTTPException(status_code=status_code)
  59. return await func(*args, **kwargs)
  60. return async_wrapper
  61. else:
  62. # Handle sync request/response functions.
  63. @functools.wraps(func)
  64. def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
  65. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  66. assert isinstance(request, Request)
  67. if not has_required_scope(request, scopes_list):
  68. if redirect is not None:
  69. return RedirectResponse(
  70. url=request.url_for(redirect), status_code=303
  71. )
  72. raise HTTPException(status_code=status_code)
  73. return func(*args, **kwargs)
  74. return sync_wrapper
  75. return decorator
  76. class AuthenticationError(Exception):
  77. pass
  78. class AuthenticationBackend:
  79. async def authenticate(
  80. self, conn: HTTPConnection
  81. ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
  82. raise NotImplementedError() # pragma: no cover
  83. class AuthCredentials:
  84. def __init__(self, scopes: typing.Sequence[str] = None):
  85. self.scopes = [] if scopes is None else list(scopes)
  86. class BaseUser:
  87. @property
  88. def is_authenticated(self) -> bool:
  89. raise NotImplementedError() # pragma: no cover
  90. @property
  91. def display_name(self) -> str:
  92. raise NotImplementedError() # pragma: no cover
  93. @property
  94. def identity(self) -> str:
  95. raise NotImplementedError() # pragma: no cover
  96. class SimpleUser(BaseUser):
  97. def __init__(self, username: str) -> None:
  98. self.username = username
  99. @property
  100. def is_authenticated(self) -> bool:
  101. return True
  102. @property
  103. def display_name(self) -> str:
  104. return self.username
  105. class UnauthenticatedUser(BaseUser):
  106. @property
  107. def is_authenticated(self) -> bool:
  108. return False
  109. @property
  110. def display_name(self) -> str:
  111. return ""