您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

149 行
4.8 KiB

  1. from __future__ import annotations
  2. import functools
  3. import inspect
  4. import sys
  5. from collections.abc import Sequence
  6. from typing import Any, Callable
  7. from urllib.parse import urlencode
  8. if sys.version_info >= (3, 10): # pragma: no cover
  9. from typing import ParamSpec
  10. else: # pragma: no cover
  11. from typing_extensions import ParamSpec
  12. from starlette._utils import is_async_callable
  13. from starlette.exceptions import HTTPException
  14. from starlette.requests import HTTPConnection, Request
  15. from starlette.responses import RedirectResponse
  16. from starlette.websockets import WebSocket
# Parameter specification shared by ``requires`` and its wrappers, so the
# wrapped endpoint keeps the same call signature (``_P.args`` / ``_P.kwargs``)
# as the function it decorates.
_P = ParamSpec("_P")
  18. def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
  19. for scope in scopes:
  20. if scope not in conn.auth.scopes:
  21. return False
  22. return True
  23. def requires(
  24. scopes: str | Sequence[str],
  25. status_code: int = 403,
  26. redirect: str | None = None,
  27. ) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
  28. scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
  29. def decorator(
  30. func: Callable[_P, Any],
  31. ) -> Callable[_P, Any]:
  32. sig = inspect.signature(func)
  33. for idx, parameter in enumerate(sig.parameters.values()):
  34. if parameter.name == "request" or parameter.name == "websocket":
  35. type_ = parameter.name
  36. break
  37. else:
  38. raise Exception(f'No "request" or "websocket" argument on function "{func}"')
  39. if type_ == "websocket":
  40. # Handle websocket functions. (Always async)
  41. @functools.wraps(func)
  42. async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
  43. websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
  44. assert isinstance(websocket, WebSocket)
  45. if not has_required_scope(websocket, scopes_list):
  46. await websocket.close()
  47. else:
  48. await func(*args, **kwargs)
  49. return websocket_wrapper
  50. elif is_async_callable(func):
  51. # Handle async request/response functions.
  52. @functools.wraps(func)
  53. async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
  54. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  55. assert isinstance(request, Request)
  56. if not has_required_scope(request, scopes_list):
  57. if redirect is not None:
  58. orig_request_qparam = urlencode({"next": str(request.url)})
  59. next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
  60. return RedirectResponse(url=next_url, status_code=303)
  61. raise HTTPException(status_code=status_code)
  62. return await func(*args, **kwargs)
  63. return async_wrapper
  64. else:
  65. # Handle sync request/response functions.
  66. @functools.wraps(func)
  67. def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
  68. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  69. assert isinstance(request, Request)
  70. if not has_required_scope(request, scopes_list):
  71. if redirect is not None:
  72. orig_request_qparam = urlencode({"next": str(request.url)})
  73. next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
  74. return RedirectResponse(url=next_url, status_code=303)
  75. raise HTTPException(status_code=status_code)
  76. return func(*args, **kwargs)
  77. return sync_wrapper
  78. return decorator
  79. class AuthenticationError(Exception):
  80. pass
  81. class AuthenticationBackend:
  82. async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
  83. raise NotImplementedError() # pragma: no cover
  84. class AuthCredentials:
  85. def __init__(self, scopes: Sequence[str] | None = None):
  86. self.scopes = [] if scopes is None else list(scopes)
  87. class BaseUser:
  88. @property
  89. def is_authenticated(self) -> bool:
  90. raise NotImplementedError() # pragma: no cover
  91. @property
  92. def display_name(self) -> str:
  93. raise NotImplementedError() # pragma: no cover
  94. @property
  95. def identity(self) -> str:
  96. raise NotImplementedError() # pragma: no cover
  97. class SimpleUser(BaseUser):
  98. def __init__(self, username: str) -> None:
  99. self.username = username
  100. @property
  101. def is_authenticated(self) -> bool:
  102. return True
  103. @property
  104. def display_name(self) -> str:
  105. return self.username
  106. class UnauthenticatedUser(BaseUser):
  107. @property
  108. def is_authenticated(self) -> bool:
  109. return False
  110. @property
  111. def display_name(self) -> str:
  112. return ""