You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

105 lines
4.0 KiB

  1. import gzip
  2. import io
  3. from starlette.datastructures import Headers, MutableHeaders
  4. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  5. class GZipMiddleware:
  6. def __init__(
  7. self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
  8. ) -> None:
  9. self.app = app
  10. self.minimum_size = minimum_size
  11. self.compresslevel = compresslevel
  12. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  13. if scope["type"] == "http":
  14. headers = Headers(scope=scope)
  15. if "gzip" in headers.get("Accept-Encoding", ""):
  16. responder = GZipResponder(
  17. self.app, self.minimum_size, compresslevel=self.compresslevel
  18. )
  19. await responder(scope, receive, send)
  20. return
  21. await self.app(scope, receive, send)
  22. class GZipResponder:
  23. def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
  24. self.app = app
  25. self.minimum_size = minimum_size
  26. self.send: Send = unattached_send
  27. self.initial_message: Message = {}
  28. self.started = False
  29. self.gzip_buffer = io.BytesIO()
  30. self.gzip_file = gzip.GzipFile(
  31. mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
  32. )
  33. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  34. self.send = send
  35. await self.app(scope, receive, self.send_with_gzip)
  36. async def send_with_gzip(self, message: Message) -> None:
  37. message_type = message["type"]
  38. if message_type == "http.response.start":
  39. # Don't send the initial message until we've determined how to
  40. # modify the outgoing headers correctly.
  41. self.initial_message = message
  42. elif message_type == "http.response.body" and not self.started:
  43. self.started = True
  44. body = message.get("body", b"")
  45. more_body = message.get("more_body", False)
  46. if len(body) < self.minimum_size and not more_body:
  47. # Don't apply GZip to small outgoing responses.
  48. await self.send(self.initial_message)
  49. await self.send(message)
  50. elif not more_body:
  51. # Standard GZip response.
  52. self.gzip_file.write(body)
  53. self.gzip_file.close()
  54. body = self.gzip_buffer.getvalue()
  55. headers = MutableHeaders(raw=self.initial_message["headers"])
  56. headers["Content-Encoding"] = "gzip"
  57. headers["Content-Length"] = str(len(body))
  58. headers.add_vary_header("Accept-Encoding")
  59. message["body"] = body
  60. await self.send(self.initial_message)
  61. await self.send(message)
  62. else:
  63. # Initial body in streaming GZip response.
  64. headers = MutableHeaders(raw=self.initial_message["headers"])
  65. headers["Content-Encoding"] = "gzip"
  66. headers.add_vary_header("Accept-Encoding")
  67. del headers["Content-Length"]
  68. self.gzip_file.write(body)
  69. message["body"] = self.gzip_buffer.getvalue()
  70. self.gzip_buffer.seek(0)
  71. self.gzip_buffer.truncate()
  72. await self.send(self.initial_message)
  73. await self.send(message)
  74. elif message_type == "http.response.body":
  75. # Remaining body in streaming GZip response.
  76. body = message.get("body", b"")
  77. more_body = message.get("more_body", False)
  78. self.gzip_file.write(body)
  79. if not more_body:
  80. self.gzip_file.close()
  81. message["body"] = self.gzip_buffer.getvalue()
  82. self.gzip_buffer.seek(0)
  83. self.gzip_buffer.truncate()
  84. await self.send(message)
  85. async def unattached_send(message: Message) -> None:
  86. raise RuntimeError("send awaitable not set") # pragma: no cover