您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

235 行
8.3 KiB

  1. import typing
  2. from enum import Enum
  3. from urllib.parse import unquote_plus
  4. from starlette.datastructures import FormData, Headers, UploadFile
  5. try:
  6. import multipart
  7. from multipart.multipart import parse_options_header
  8. except ImportError: # pragma: nocover
  9. parse_options_header = None
  10. multipart = None
  11. class FormMessage(Enum):
  12. FIELD_START = 1
  13. FIELD_NAME = 2
  14. FIELD_DATA = 3
  15. FIELD_END = 4
  16. END = 5
  17. class MultiPartMessage(Enum):
  18. PART_BEGIN = 1
  19. PART_DATA = 2
  20. PART_END = 3
  21. HEADER_FIELD = 4
  22. HEADER_VALUE = 5
  23. HEADER_END = 6
  24. HEADERS_FINISHED = 7
  25. END = 8
  26. def _user_safe_decode(src: bytes, codec: str) -> str:
  27. try:
  28. return src.decode(codec)
  29. except (UnicodeDecodeError, LookupError):
  30. return src.decode("latin-1")
  31. class FormParser:
  32. def __init__(
  33. self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
  34. ) -> None:
  35. assert (
  36. multipart is not None
  37. ), "The `python-multipart` library must be installed to use form parsing."
  38. self.headers = headers
  39. self.stream = stream
  40. self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []
  41. def on_field_start(self) -> None:
  42. message = (FormMessage.FIELD_START, b"")
  43. self.messages.append(message)
  44. def on_field_name(self, data: bytes, start: int, end: int) -> None:
  45. message = (FormMessage.FIELD_NAME, data[start:end])
  46. self.messages.append(message)
  47. def on_field_data(self, data: bytes, start: int, end: int) -> None:
  48. message = (FormMessage.FIELD_DATA, data[start:end])
  49. self.messages.append(message)
  50. def on_field_end(self) -> None:
  51. message = (FormMessage.FIELD_END, b"")
  52. self.messages.append(message)
  53. def on_end(self) -> None:
  54. message = (FormMessage.END, b"")
  55. self.messages.append(message)
  56. async def parse(self) -> FormData:
  57. # Callbacks dictionary.
  58. callbacks = {
  59. "on_field_start": self.on_field_start,
  60. "on_field_name": self.on_field_name,
  61. "on_field_data": self.on_field_data,
  62. "on_field_end": self.on_field_end,
  63. "on_end": self.on_end,
  64. }
  65. # Create the parser.
  66. parser = multipart.QuerystringParser(callbacks)
  67. field_name = b""
  68. field_value = b""
  69. items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
  70. # Feed the parser with data from the request.
  71. async for chunk in self.stream:
  72. if chunk:
  73. parser.write(chunk)
  74. else:
  75. parser.finalize()
  76. messages = list(self.messages)
  77. self.messages.clear()
  78. for message_type, message_bytes in messages:
  79. if message_type == FormMessage.FIELD_START:
  80. field_name = b""
  81. field_value = b""
  82. elif message_type == FormMessage.FIELD_NAME:
  83. field_name += message_bytes
  84. elif message_type == FormMessage.FIELD_DATA:
  85. field_value += message_bytes
  86. elif message_type == FormMessage.FIELD_END:
  87. name = unquote_plus(field_name.decode("latin-1"))
  88. value = unquote_plus(field_value.decode("latin-1"))
  89. items.append((name, value))
  90. return FormData(items)
  91. class MultiPartParser:
  92. def __init__(
  93. self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
  94. ) -> None:
  95. assert (
  96. multipart is not None
  97. ), "The `python-multipart` library must be installed to use form parsing."
  98. self.headers = headers
  99. self.stream = stream
  100. self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []
  101. def on_part_begin(self) -> None:
  102. message = (MultiPartMessage.PART_BEGIN, b"")
  103. self.messages.append(message)
  104. def on_part_data(self, data: bytes, start: int, end: int) -> None:
  105. message = (MultiPartMessage.PART_DATA, data[start:end])
  106. self.messages.append(message)
  107. def on_part_end(self) -> None:
  108. message = (MultiPartMessage.PART_END, b"")
  109. self.messages.append(message)
  110. def on_header_field(self, data: bytes, start: int, end: int) -> None:
  111. message = (MultiPartMessage.HEADER_FIELD, data[start:end])
  112. self.messages.append(message)
  113. def on_header_value(self, data: bytes, start: int, end: int) -> None:
  114. message = (MultiPartMessage.HEADER_VALUE, data[start:end])
  115. self.messages.append(message)
  116. def on_header_end(self) -> None:
  117. message = (MultiPartMessage.HEADER_END, b"")
  118. self.messages.append(message)
  119. def on_headers_finished(self) -> None:
  120. message = (MultiPartMessage.HEADERS_FINISHED, b"")
  121. self.messages.append(message)
  122. def on_end(self) -> None:
  123. message = (MultiPartMessage.END, b"")
  124. self.messages.append(message)
  125. async def parse(self) -> FormData:
  126. # Parse the Content-Type header to get the multipart boundary.
  127. content_type, params = parse_options_header(self.headers["Content-Type"])
  128. charset = params.get(b"charset", "utf-8")
  129. if type(charset) == bytes:
  130. charset = charset.decode("latin-1")
  131. boundary = params.get(b"boundary")
  132. # Callbacks dictionary.
  133. callbacks = {
  134. "on_part_begin": self.on_part_begin,
  135. "on_part_data": self.on_part_data,
  136. "on_part_end": self.on_part_end,
  137. "on_header_field": self.on_header_field,
  138. "on_header_value": self.on_header_value,
  139. "on_header_end": self.on_header_end,
  140. "on_headers_finished": self.on_headers_finished,
  141. "on_end": self.on_end,
  142. }
  143. # Create the parser.
  144. parser = multipart.MultipartParser(boundary, callbacks)
  145. header_field = b""
  146. header_value = b""
  147. content_disposition = None
  148. content_type = b""
  149. field_name = ""
  150. data = b""
  151. file: typing.Optional[UploadFile] = None
  152. items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
  153. # Feed the parser with data from the request.
  154. async for chunk in self.stream:
  155. parser.write(chunk)
  156. messages = list(self.messages)
  157. self.messages.clear()
  158. for message_type, message_bytes in messages:
  159. if message_type == MultiPartMessage.PART_BEGIN:
  160. content_disposition = None
  161. content_type = b""
  162. data = b""
  163. elif message_type == MultiPartMessage.HEADER_FIELD:
  164. header_field += message_bytes
  165. elif message_type == MultiPartMessage.HEADER_VALUE:
  166. header_value += message_bytes
  167. elif message_type == MultiPartMessage.HEADER_END:
  168. field = header_field.lower()
  169. if field == b"content-disposition":
  170. content_disposition = header_value
  171. elif field == b"content-type":
  172. content_type = header_value
  173. header_field = b""
  174. header_value = b""
  175. elif message_type == MultiPartMessage.HEADERS_FINISHED:
  176. disposition, options = parse_options_header(content_disposition)
  177. field_name = _user_safe_decode(options[b"name"], charset)
  178. if b"filename" in options:
  179. filename = _user_safe_decode(options[b"filename"], charset)
  180. file = UploadFile(
  181. filename=filename,
  182. content_type=content_type.decode("latin-1"),
  183. )
  184. else:
  185. file = None
  186. elif message_type == MultiPartMessage.PART_DATA:
  187. if file is None:
  188. data += message_bytes
  189. else:
  190. await file.write(message_bytes)
  191. elif message_type == MultiPartMessage.PART_END:
  192. if file is None:
  193. items.append((field_name, _user_safe_decode(data, charset)))
  194. else:
  195. await file.seek(0)
  196. items.append((field_name, file))
  197. parser.finalize()
  198. return FormData(items)