25'ten fazla konu seçemezsiniz. Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.
 
 
 
 

277 satır
11 KiB

  1. from __future__ import annotations
  2. from collections.abc import AsyncGenerator
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from tempfile import SpooledTemporaryFile
  6. from typing import TYPE_CHECKING
  7. from urllib.parse import unquote_plus
  8. from starlette.datastructures import FormData, Headers, UploadFile
  9. if TYPE_CHECKING:
  10. import python_multipart as multipart
  11. from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
  12. else:
  13. try:
  14. try:
  15. import python_multipart as multipart
  16. from python_multipart.multipart import parse_options_header
  17. except ModuleNotFoundError: # pragma: no cover
  18. import multipart
  19. from multipart.multipart import parse_options_header
  20. except ModuleNotFoundError: # pragma: no cover
  21. multipart = None
  22. parse_options_header = None
  23. class FormMessage(Enum):
  24. FIELD_START = 1
  25. FIELD_NAME = 2
  26. FIELD_DATA = 3
  27. FIELD_END = 4
  28. END = 5
  29. @dataclass
  30. class MultipartPart:
  31. content_disposition: bytes | None = None
  32. field_name: str = ""
  33. data: bytearray = field(default_factory=bytearray)
  34. file: UploadFile | None = None
  35. item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
  36. def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
  37. try:
  38. return src.decode(codec)
  39. except (UnicodeDecodeError, LookupError):
  40. return src.decode("latin-1")
  41. class MultiPartException(Exception):
  42. def __init__(self, message: str) -> None:
  43. self.message = message
  44. class FormParser:
  45. def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
  46. assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
  47. self.headers = headers
  48. self.stream = stream
  49. self.messages: list[tuple[FormMessage, bytes]] = []
  50. def on_field_start(self) -> None:
  51. message = (FormMessage.FIELD_START, b"")
  52. self.messages.append(message)
  53. def on_field_name(self, data: bytes, start: int, end: int) -> None:
  54. message = (FormMessage.FIELD_NAME, data[start:end])
  55. self.messages.append(message)
  56. def on_field_data(self, data: bytes, start: int, end: int) -> None:
  57. message = (FormMessage.FIELD_DATA, data[start:end])
  58. self.messages.append(message)
  59. def on_field_end(self) -> None:
  60. message = (FormMessage.FIELD_END, b"")
  61. self.messages.append(message)
  62. def on_end(self) -> None:
  63. message = (FormMessage.END, b"")
  64. self.messages.append(message)
  65. async def parse(self) -> FormData:
  66. # Callbacks dictionary.
  67. callbacks: QuerystringCallbacks = {
  68. "on_field_start": self.on_field_start,
  69. "on_field_name": self.on_field_name,
  70. "on_field_data": self.on_field_data,
  71. "on_field_end": self.on_field_end,
  72. "on_end": self.on_end,
  73. }
  74. # Create the parser.
  75. parser = multipart.QuerystringParser(callbacks)
  76. field_name = b""
  77. field_value = b""
  78. items: list[tuple[str, str | UploadFile]] = []
  79. # Feed the parser with data from the request.
  80. async for chunk in self.stream:
  81. if chunk:
  82. parser.write(chunk)
  83. else:
  84. parser.finalize()
  85. messages = list(self.messages)
  86. self.messages.clear()
  87. for message_type, message_bytes in messages:
  88. if message_type == FormMessage.FIELD_START:
  89. field_name = b""
  90. field_value = b""
  91. elif message_type == FormMessage.FIELD_NAME:
  92. field_name += message_bytes
  93. elif message_type == FormMessage.FIELD_DATA:
  94. field_value += message_bytes
  95. elif message_type == FormMessage.FIELD_END:
  96. name = unquote_plus(field_name.decode("latin-1"))
  97. value = unquote_plus(field_value.decode("latin-1"))
  98. items.append((name, value))
  99. return FormData(items)
  100. class MultiPartParser:
  101. spool_max_size = 1024 * 1024 # 1MB
  102. """The maximum size of the spooled temporary file used to store file data."""
  103. max_part_size = 1024 * 1024 # 1MB
  104. """The maximum size of a part in the multipart request."""
  105. def __init__(
  106. self,
  107. headers: Headers,
  108. stream: AsyncGenerator[bytes, None],
  109. *,
  110. max_files: int | float = 1000,
  111. max_fields: int | float = 1000,
  112. max_part_size: int = 1024 * 1024, # 1MB
  113. ) -> None:
  114. assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
  115. self.headers = headers
  116. self.stream = stream
  117. self.max_files = max_files
  118. self.max_fields = max_fields
  119. self.items: list[tuple[str, str | UploadFile]] = []
  120. self._current_files = 0
  121. self._current_fields = 0
  122. self._current_partial_header_name: bytes = b""
  123. self._current_partial_header_value: bytes = b""
  124. self._current_part = MultipartPart()
  125. self._charset = ""
  126. self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
  127. self._file_parts_to_finish: list[MultipartPart] = []
  128. self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
  129. self.max_part_size = max_part_size
  130. def on_part_begin(self) -> None:
  131. self._current_part = MultipartPart()
  132. def on_part_data(self, data: bytes, start: int, end: int) -> None:
  133. message_bytes = data[start:end]
  134. if self._current_part.file is None:
  135. if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
  136. raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
  137. self._current_part.data.extend(message_bytes)
  138. else:
  139. self._file_parts_to_write.append((self._current_part, message_bytes))
  140. def on_part_end(self) -> None:
  141. if self._current_part.file is None:
  142. self.items.append(
  143. (
  144. self._current_part.field_name,
  145. _user_safe_decode(self._current_part.data, self._charset),
  146. )
  147. )
  148. else:
  149. self._file_parts_to_finish.append(self._current_part)
  150. # The file can be added to the items right now even though it's not
  151. # finished yet, because it will be finished in the `parse()` method, before
  152. # self.items is used in the return value.
  153. self.items.append((self._current_part.field_name, self._current_part.file))
  154. def on_header_field(self, data: bytes, start: int, end: int) -> None:
  155. self._current_partial_header_name += data[start:end]
  156. def on_header_value(self, data: bytes, start: int, end: int) -> None:
  157. self._current_partial_header_value += data[start:end]
  158. def on_header_end(self) -> None:
  159. field = self._current_partial_header_name.lower()
  160. if field == b"content-disposition":
  161. self._current_part.content_disposition = self._current_partial_header_value
  162. self._current_part.item_headers.append((field, self._current_partial_header_value))
  163. self._current_partial_header_name = b""
  164. self._current_partial_header_value = b""
  165. def on_headers_finished(self) -> None:
  166. disposition, options = parse_options_header(self._current_part.content_disposition)
  167. try:
  168. self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
  169. except KeyError:
  170. raise MultiPartException('The Content-Disposition header field "name" must be provided.')
  171. if b"filename" in options:
  172. self._current_files += 1
  173. if self._current_files > self.max_files:
  174. raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
  175. filename = _user_safe_decode(options[b"filename"], self._charset)
  176. tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
  177. self._files_to_close_on_error.append(tempfile)
  178. self._current_part.file = UploadFile(
  179. file=tempfile, # type: ignore[arg-type]
  180. size=0,
  181. filename=filename,
  182. headers=Headers(raw=self._current_part.item_headers),
  183. )
  184. else:
  185. self._current_fields += 1
  186. if self._current_fields > self.max_fields:
  187. raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
  188. self._current_part.file = None
  189. def on_end(self) -> None:
  190. pass
  191. async def parse(self) -> FormData:
  192. # Parse the Content-Type header to get the multipart boundary.
  193. _, params = parse_options_header(self.headers["Content-Type"])
  194. charset = params.get(b"charset", "utf-8")
  195. if isinstance(charset, bytes):
  196. charset = charset.decode("latin-1")
  197. self._charset = charset
  198. try:
  199. boundary = params[b"boundary"]
  200. except KeyError:
  201. raise MultiPartException("Missing boundary in multipart.")
  202. # Callbacks dictionary.
  203. callbacks: MultipartCallbacks = {
  204. "on_part_begin": self.on_part_begin,
  205. "on_part_data": self.on_part_data,
  206. "on_part_end": self.on_part_end,
  207. "on_header_field": self.on_header_field,
  208. "on_header_value": self.on_header_value,
  209. "on_header_end": self.on_header_end,
  210. "on_headers_finished": self.on_headers_finished,
  211. "on_end": self.on_end,
  212. }
  213. # Create the parser.
  214. parser = multipart.MultipartParser(boundary, callbacks)
  215. try:
  216. # Feed the parser with data from the request.
  217. async for chunk in self.stream:
  218. parser.write(chunk)
  219. # Write file data, it needs to use await with the UploadFile methods
  220. # that call the corresponding file methods *in a threadpool*,
  221. # otherwise, if they were called directly in the callback methods above
  222. # (regular, non-async functions), that would block the event loop in
  223. # the main thread.
  224. for part, data in self._file_parts_to_write:
  225. assert part.file # for type checkers
  226. await part.file.write(data)
  227. for part in self._file_parts_to_finish:
  228. assert part.file # for type checkers
  229. await part.file.seek(0)
  230. self._file_parts_to_write.clear()
  231. self._file_parts_to_finish.clear()
  232. except MultiPartException as exc:
  233. # Close all the files if there was an error.
  234. for file in self._files_to_close_on_error:
  235. file.close()
  236. raise exc
  237. parser.finalize()
  238. return FormData(self.items)