|
- import typing
- from enum import Enum
- from urllib.parse import unquote_plus
-
- from starlette.datastructures import FormData, Headers, UploadFile
-
- try:
- import multipart
- from multipart.multipart import parse_options_header
- except ImportError: # pragma: nocover
- parse_options_header = None
- multipart = None
-
-
class FormMessage(Enum):
    """Event kinds queued by the ``FormParser`` callbacks.

    Each callback records one of these together with the bytes (possibly
    empty) it was handed, so the events can be replayed in order later.
    """

    # A new field begins; carries no payload.
    FIELD_START = 1
    # A fragment of the current field's name.
    FIELD_NAME = 2
    # A fragment of the current field's value.
    FIELD_DATA = 3
    # The current field is complete; carries no payload.
    FIELD_END = 4
    # The parser reported the end of input; carries no payload.
    END = 5
-
-
class MultiPartMessage(Enum):
    """Event kinds queued by the ``MultiPartParser`` callbacks.

    Each callback records one of these together with the bytes (possibly
    empty) it was handed, so the events can be replayed in order later.
    """

    # A new part begins; carries no payload.
    PART_BEGIN = 1
    # A fragment of the current part's body.
    PART_DATA = 2
    # The current part is complete; carries no payload.
    PART_END = 3
    # A fragment of a part header's name.
    HEADER_FIELD = 4
    # A fragment of a part header's value.
    HEADER_VALUE = 5
    # One header (name + value) is complete; carries no payload.
    HEADER_END = 6
    # All headers of the current part have been seen; carries no payload.
    HEADERS_FINISHED = 7
    # The parser reported the end of input; carries no payload.
    END = 8
-
-
def _user_safe_decode(src: bytes, codec: str) -> str:
    """Decode ``src`` using ``codec``, falling back to latin-1.

    The fallback is used both when the bytes are not valid in ``codec``
    (``UnicodeDecodeError``) and when ``codec`` is not a known encoding
    name (``LookupError``).
    """
    try:
        text = src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        # latin-1 assigns a character to every byte value, so this
        # second attempt cannot raise a decode error.
        text = src.decode("latin-1")
    return text
-
-
class FormParser:
    """Parse a query-string encoded form body into ``FormData``.

    The body is read from an async byte stream and fed to
    ``multipart.QuerystringParser``. That parser reports progress through
    the ``on_*`` callbacks below, which only queue ``(FormMessage, bytes)``
    tuples on ``self.messages``; ``parse`` drains the queue after every
    chunk and assembles the fields from it.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and body stream.

        Raises:
            AssertionError: if the optional ``multipart`` import failed.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the callbacks, consumed by ``parse``.
        self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue a FIELD_START event."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Queue a FIELD_END event."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Queue an END event."""
        self.messages.append((FormMessage.END, b""))

    async def parse(self) -> FormData:
        """Consume the stream and return the decoded form fields.

        Non-empty chunks are written to the parser; an empty chunk
        finalizes it. If the stream ends without ever yielding an empty
        chunk, the parser is finalized once after the loop so that a last
        field without a trailing separator is not lost.

        Returns:
            ``FormData`` built from ``(name, value)`` string pairs in the
            order the fields were completed.
        """
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }
        parser = multipart.QuerystringParser(callbacks)

        # Name/value of the field currently being assembled. They live
        # here (not inside ``drain``) because one field may be split
        # across several chunks.
        field_name = b""
        field_value = b""
        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        def drain() -> None:
            """Replay and clear the queued parser events."""
            nonlocal field_name, field_value
            pending = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in pending:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    # latin-1 keeps every raw byte as one character;
                    # unquote_plus then resolves '+' and %XX escapes.
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        finalized = False
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
                finalized = True
            drain()

        if not finalized:
            # The stream ended without an empty terminating chunk:
            # flush the parser so the trailing field is still emitted.
            parser.finalize()
            drain()

        return FormData(items)
-
-
class MultiPartParser:
    """Parse a multipart request body into ``FormData``.

    The body is read from an async byte stream and fed to
    ``multipart.MultipartParser``. That parser reports progress through
    the ``on_*`` callbacks below, which only queue
    ``(MultiPartMessage, bytes)`` tuples on ``self.messages``; ``parse``
    drains the queue after every chunk and assembles the parts from it.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and body stream.

        Raises:
            AssertionError: if the optional ``multipart`` import failed.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the callbacks, consumed by ``parse``.
        self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []

    def on_part_begin(self) -> None:
        """Queue a PART_BEGIN event."""
        self.messages.append((MultiPartMessage.PART_BEGIN, b""))

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the part body."""
        self.messages.append((MultiPartMessage.PART_DATA, data[start:end]))

    def on_part_end(self) -> None:
        """Queue a PART_END event."""
        self.messages.append((MultiPartMessage.PART_END, b""))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of a header name."""
        self.messages.append((MultiPartMessage.HEADER_FIELD, data[start:end]))

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of a header value."""
        self.messages.append((MultiPartMessage.HEADER_VALUE, data[start:end]))

    def on_header_end(self) -> None:
        """Queue a HEADER_END event."""
        self.messages.append((MultiPartMessage.HEADER_END, b""))

    def on_headers_finished(self) -> None:
        """Queue a HEADERS_FINISHED event."""
        self.messages.append((MultiPartMessage.HEADERS_FINISHED, b""))

    def on_end(self) -> None:
        """Queue an END event."""
        self.messages.append((MultiPartMessage.END, b""))

    async def parse(self) -> FormData:
        """Consume the stream and return the decoded form fields and files.

        Parts whose Content-Disposition carries a ``filename`` option are
        written to an ``UploadFile`` (rewound to position 0 when the part
        ends); all other parts are collected in memory and decoded to
        ``str`` with the request charset.

        Returns:
            ``FormData`` built from ``(name, value)`` pairs in the order
            the parts were completed, where ``value`` is a ``str`` or an
            ``UploadFile``.
        """
        # The request-level Content-Type supplies the boundary and,
        # optionally, a charset; the media type itself is not needed.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            # Option values arrive as bytes; only the default is a str.
            charset = charset.decode("latin-1")
        boundary = params.get(b"boundary")

        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }
        parser = multipart.MultipartParser(boundary, callbacks)

        # State of the header currently being assembled.
        header_field = b""
        header_value = b""
        # State of the part currently being assembled.
        content_disposition = None
        content_type = b""
        field_name = ""
        data = b""
        file: typing.Optional[UploadFile] = None

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        async for chunk in self.stream:
            parser.write(chunk)
            pending = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in pending:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = b""
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    # Only these two part headers are used; any other
                    # header is discarded when the buffers are reset.
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    _, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"], charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                        )
                    else:
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        await file.write(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append((field_name, _user_safe_decode(data, charset)))
                    else:
                        # Rewind so the consumer reads from the start.
                        await file.seek(0)
                        items.append((field_name, file))

        parser.finalize()
        return FormData(items)
|