You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

146 lines
5.8 KiB

  1. import gzip
  2. import io
  3. from typing import NoReturn
  4. from starlette.datastructures import Headers, MutableHeaders
  5. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  6. DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
  7. class GZipMiddleware:
  8. def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
  9. self.app = app
  10. self.minimum_size = minimum_size
  11. self.compresslevel = compresslevel
  12. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  13. if scope["type"] != "http": # pragma: no cover
  14. await self.app(scope, receive, send)
  15. return
  16. headers = Headers(scope=scope)
  17. responder: ASGIApp
  18. if "gzip" in headers.get("Accept-Encoding", ""):
  19. responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
  20. else:
  21. responder = IdentityResponder(self.app, self.minimum_size)
  22. await responder(scope, receive, send)
  23. class IdentityResponder:
  24. content_encoding: str
  25. def __init__(self, app: ASGIApp, minimum_size: int) -> None:
  26. self.app = app
  27. self.minimum_size = minimum_size
  28. self.send: Send = unattached_send
  29. self.initial_message: Message = {}
  30. self.started = False
  31. self.content_encoding_set = False
  32. self.content_type_is_excluded = False
  33. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  34. self.send = send
  35. await self.app(scope, receive, self.send_with_compression)
  36. async def send_with_compression(self, message: Message) -> None:
  37. message_type = message["type"]
  38. if message_type == "http.response.start":
  39. # Don't send the initial message until we've determined how to
  40. # modify the outgoing headers correctly.
  41. self.initial_message = message
  42. headers = Headers(raw=self.initial_message["headers"])
  43. self.content_encoding_set = "content-encoding" in headers
  44. self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
  45. elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
  46. if not self.started:
  47. self.started = True
  48. await self.send(self.initial_message)
  49. await self.send(message)
  50. elif message_type == "http.response.body" and not self.started:
  51. self.started = True
  52. body = message.get("body", b"")
  53. more_body = message.get("more_body", False)
  54. if len(body) < self.minimum_size and not more_body:
  55. # Don't apply compression to small outgoing responses.
  56. await self.send(self.initial_message)
  57. await self.send(message)
  58. elif not more_body:
  59. # Standard response.
  60. body = self.apply_compression(body, more_body=False)
  61. headers = MutableHeaders(raw=self.initial_message["headers"])
  62. headers.add_vary_header("Accept-Encoding")
  63. if body != message["body"]:
  64. headers["Content-Encoding"] = self.content_encoding
  65. headers["Content-Length"] = str(len(body))
  66. message["body"] = body
  67. await self.send(self.initial_message)
  68. await self.send(message)
  69. else:
  70. # Initial body in streaming response.
  71. body = self.apply_compression(body, more_body=True)
  72. headers = MutableHeaders(raw=self.initial_message["headers"])
  73. headers.add_vary_header("Accept-Encoding")
  74. if body != message["body"]:
  75. headers["Content-Encoding"] = self.content_encoding
  76. del headers["Content-Length"]
  77. message["body"] = body
  78. await self.send(self.initial_message)
  79. await self.send(message)
  80. elif message_type == "http.response.body":
  81. # Remaining body in streaming response.
  82. body = message.get("body", b"")
  83. more_body = message.get("more_body", False)
  84. message["body"] = self.apply_compression(body, more_body=more_body)
  85. await self.send(message)
  86. elif message_type == "http.response.pathsend": # pragma: no branch
  87. # Don't apply GZip to pathsend responses
  88. await self.send(self.initial_message)
  89. await self.send(message)
  90. def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
  91. """Apply compression on the response body.
  92. If more_body is False, any compression file should be closed. If it
  93. isn't, it won't be closed automatically until all background tasks
  94. complete.
  95. """
  96. return body
  97. class GZipResponder(IdentityResponder):
  98. content_encoding = "gzip"
  99. def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
  100. super().__init__(app, minimum_size)
  101. self.gzip_buffer = io.BytesIO()
  102. self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
  103. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  104. with self.gzip_buffer, self.gzip_file:
  105. await super().__call__(scope, receive, send)
  106. def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
  107. self.gzip_file.write(body)
  108. if not more_body:
  109. self.gzip_file.close()
  110. body = self.gzip_buffer.getvalue()
  111. self.gzip_buffer.seek(0)
  112. self.gzip_buffer.truncate()
  113. return body
  114. async def unattached_send(message: Message) -> NoReturn:
  115. raise RuntimeError("send awaitable not set") # pragma: no cover