You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

86 lines
3.5 KiB

  1. from __future__ import annotations
  2. import json
  3. from base64 import b64decode, b64encode
  4. from typing import Literal
  5. import itsdangerous
  6. from itsdangerous.exc import BadSignature
  7. from starlette.datastructures import MutableHeaders, Secret
  8. from starlette.requests import HTTPConnection
  9. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  10. class SessionMiddleware:
  11. def __init__(
  12. self,
  13. app: ASGIApp,
  14. secret_key: str | Secret,
  15. session_cookie: str = "session",
  16. max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
  17. path: str = "/",
  18. same_site: Literal["lax", "strict", "none"] = "lax",
  19. https_only: bool = False,
  20. domain: str | None = None,
  21. ) -> None:
  22. self.app = app
  23. self.signer = itsdangerous.TimestampSigner(str(secret_key))
  24. self.session_cookie = session_cookie
  25. self.max_age = max_age
  26. self.path = path
  27. self.security_flags = "httponly; samesite=" + same_site
  28. if https_only: # Secure flag can be used with HTTPS only
  29. self.security_flags += "; secure"
  30. if domain is not None:
  31. self.security_flags += f"; domain={domain}"
  32. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  33. if scope["type"] not in ("http", "websocket"): # pragma: no cover
  34. await self.app(scope, receive, send)
  35. return
  36. connection = HTTPConnection(scope)
  37. initial_session_was_empty = True
  38. if self.session_cookie in connection.cookies:
  39. data = connection.cookies[self.session_cookie].encode("utf-8")
  40. try:
  41. data = self.signer.unsign(data, max_age=self.max_age)
  42. scope["session"] = json.loads(b64decode(data))
  43. initial_session_was_empty = False
  44. except BadSignature:
  45. scope["session"] = {}
  46. else:
  47. scope["session"] = {}
  48. async def send_wrapper(message: Message) -> None:
  49. if message["type"] == "http.response.start":
  50. if scope["session"]:
  51. # We have session data to persist.
  52. data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
  53. data = self.signer.sign(data)
  54. headers = MutableHeaders(scope=message)
  55. header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
  56. session_cookie=self.session_cookie,
  57. data=data.decode("utf-8"),
  58. path=self.path,
  59. max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
  60. security_flags=self.security_flags,
  61. )
  62. headers.append("Set-Cookie", header_value)
  63. elif not initial_session_was_empty:
  64. # The session has been cleared.
  65. headers = MutableHeaders(scope=message)
  66. header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
  67. session_cookie=self.session_cookie,
  68. data="null",
  69. path=self.path,
  70. expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
  71. security_flags=self.security_flags,
  72. )
  73. headers.append("Set-Cookie", header_value)
  74. await send(message)
  75. await self.app(scope, receive, send_wrapper)