您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

1874 行
75 KiB

  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import shutil
  5. import sys
  6. import tempfile
  7. from email.message import Message
  8. from enum import IntEnum
  9. from io import BufferedRandom, BytesIO
  10. from numbers import Number
  11. from typing import TYPE_CHECKING, cast
  12. from .decoders import Base64Decoder, QuotedPrintableDecoder
  13. from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
# Type-only declarations: everything in this branch is skipped at runtime and
# exists solely for static type checkers.
if TYPE_CHECKING:  # pragma: no cover
    from typing import Any, Callable, Literal, Protocol, TypedDict

    from typing_extensions import TypeAlias

    class SupportsRead(Protocol):
        """Structural type for any object exposing ``read(n) -> bytes``."""

        def read(self, __n: int) -> bytes: ...

    class QuerystringCallbacks(TypedDict, total=False):
        """Callbacks accepted by the querystring parser."""

        on_field_start: Callable[[], None]  # a new field begins
        on_field_name: Callable[[bytes, int, int], None]  # chunk of a field name
        on_field_data: Callable[[bytes, int, int], None]  # chunk of a field value
        on_field_end: Callable[[], None]  # current field is complete
        on_end: Callable[[], None]  # all data parsed

    class OctetStreamCallbacks(TypedDict, total=False):
        """Callbacks accepted by the octet-stream parser."""

        on_start: Callable[[], None]  # first data seen
        on_data: Callable[[bytes, int, int], None]  # chunk of body data
        on_end: Callable[[], None]  # parsing finished

    class MultipartCallbacks(TypedDict, total=False):
        """Callbacks accepted by the multipart parser."""

        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    class FormParserConfig(TypedDict):
        """Configuration keys for the form parser (all keys required)."""

        UPLOAD_DIR: str | None
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_MEMORY_FILE_SIZE: int
        MAX_BODY_SIZE: float

    class FileConfig(TypedDict, total=False):
        """Configuration keys understood by :class:`File` (all optional)."""

        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    class _FormProtocol(Protocol):
        """Streaming interface shared by form fields and files."""

        def write(self, data: bytes) -> int: ...

        def finalize(self) -> None: ...

        def close(self) -> None: ...

    class FieldProtocol(_FormProtocol, Protocol):
        """Interface a Field-like object must implement."""

        def __init__(self, name: bytes | None) -> None: ...

        def set_none(self) -> None: ...

    class FileProtocol(_FormProtocol, Protocol):
        """Interface a File-like object must implement."""

        def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...

    OnFieldCallback = Callable[[FieldProtocol], None]
    OnFileCallback = Callable[[FileProtocol], None]

    # All callback names that may be registered on a parser, without the
    # "on_" prefix that the callback dicts use.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
# Unique sentinel meaning "no cached value yet". Distinct from None, which is
# itself a legal cached value (see Field.set_none()).
_missing = object()
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    BEFORE_FIELD = 0  # between fields; separators ("&" / ";") are skipped here
    FIELD_NAME = 1  # reading the name portion, up to the "=" (or a separator)
    FIELD_DATA = 2  # reading the value portion, after the "="
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # NOTE(review): the state machine driving these transitions lives in the
    # multipart parser (not shown in this chunk); the names track the
    # boundary -> headers -> part-data phases of a multipart body.
    START = 0
    START_BOUNDARY = 1
    HEADER_FIELD_START = 2
    HEADER_FIELD = 3
    HEADER_VALUE_START = 4
    HEADER_VALUE = 5
    HEADER_VALUE_ALMOST_DONE = 6
    HEADERS_ALMOST_DONE = 7
    PART_DATA_START = 8
    PART_DATA = 9
    PART_DATA_END = 10
    END_BOUNDARY = 11
    END = 12
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1
FLAG_LAST_BOUNDARY = 2

# Get constants. Since iterating over a str on Python 2 gives you a 1-length
# string, but iterating over a bytes object on Python 3 gives you an integer,
# we need to save these constants. Each is the integer byte value of the
# corresponding single-character bytes literal.
CR = b"\r"[0]
LF = b"\n"[0]
COLON = b":"[0]
SPACE = b" "[0]
HYPHEN = b"-"[0]
AMPERSAND = b"&"[0]
SEMICOLON = b";"[0]
LOWER_A = b"a"[0]
LOWER_Z = b"z"[0]
NULL = b"\x00"[0]

# fmt: off
# Mask for ASCII characters that can be http tokens.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"0123456789"
    b"!#$%&'*+-.^_`|~")
# fmt: on
  134. def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
  135. """Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
  136. # Uses email.message.Message to parse the header as described in PEP 594.
  137. # Ref: https://peps.python.org/pep-0594/#cgi
  138. if not value:
  139. return (b"", {})
  140. # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1.
  141. if isinstance(value, bytes): # pragma: no cover
  142. value = value.decode("latin-1")
  143. # For types
  144. assert isinstance(value, str), "Value should be a string by now"
  145. # If we have no options, return the string as-is.
  146. if ";" not in value:
  147. return (value.lower().strip().encode("latin-1"), {})
  148. # Split at the first semicolon, to get our value and then options.
  149. # ctype, rest = value.split(b';', 1)
  150. message = Message()
  151. message["content-type"] = value
  152. params = message.get_params()
  153. # If there were no parameters, this would have already returned above
  154. assert params, "At least the content type value should be present"
  155. ctype = params.pop(0)[0].encode("latin-1")
  156. options: dict[bytes, bytes] = {}
  157. for param in params:
  158. key, value = param
  159. # If the value returned from get_params() is a 3-tuple, the last
  160. # element corresponds to the value.
  161. # See: https://docs.python.org/3/library/email.compat32-message.html
  162. if isinstance(value, tuple):
  163. value = value[-1]
  164. # If the value is a filename, we need to fix a bug on IE6 that sends
  165. # the full file path instead of the filename.
  166. if key == "filename":
  167. if value[1:3] == ":\\" or value[:2] == "\\\\":
  168. value = value.split("\\")[-1]
  169. options[key.encode("latin-1")] = value.encode("latin-1")
  170. return ctype, options
  171. class Field:
  172. """A Field object represents a (parsed) form field. It represents a single
  173. field with a corresponding name and value.
  174. The name that a :class:`Field` will be instantiated with is the same name
  175. that would be found in the following HTML::
  176. <input name="name_goes_here" type="text"/>
  177. This class defines two methods, :meth:`on_data` and :meth:`on_end`, that
  178. will be called when data is written to the Field, and when the Field is
  179. finalized, respectively.
  180. Args:
  181. name: The name of the form field.
  182. """
  183. def __init__(self, name: bytes | None) -> None:
  184. self._name = name
  185. self._value: list[bytes] = []
  186. # We cache the joined version of _value for speed.
  187. self._cache = _missing
  188. @classmethod
  189. def from_value(cls, name: bytes, value: bytes | None) -> Field:
  190. """Create an instance of a :class:`Field`, and set the corresponding
  191. value - either None or an actual value. This method will also
  192. finalize the Field itself.
  193. Args:
  194. name: the name of the form field.
  195. value: the value of the form field - either a bytestring or None.
  196. Returns:
  197. A new instance of a [`Field`][python_multipart.Field].
  198. """
  199. f = cls(name)
  200. if value is None:
  201. f.set_none()
  202. else:
  203. f.write(value)
  204. f.finalize()
  205. return f
  206. def write(self, data: bytes) -> int:
  207. """Write some data into the form field.
  208. Args:
  209. data: The data to write to the field.
  210. Returns:
  211. The number of bytes written.
  212. """
  213. return self.on_data(data)
  214. def on_data(self, data: bytes) -> int:
  215. """This method is a callback that will be called whenever data is
  216. written to the Field.
  217. Args:
  218. data: The data to write to the field.
  219. Returns:
  220. The number of bytes written.
  221. """
  222. self._value.append(data)
  223. self._cache = _missing
  224. return len(data)
  225. def on_end(self) -> None:
  226. """This method is called whenever the Field is finalized."""
  227. if self._cache is _missing:
  228. self._cache = b"".join(self._value)
  229. def finalize(self) -> None:
  230. """Finalize the form field."""
  231. self.on_end()
  232. def close(self) -> None:
  233. """Close the Field object. This will free any underlying cache."""
  234. # Free our value array.
  235. if self._cache is _missing:
  236. self._cache = b"".join(self._value)
  237. del self._value
  238. def set_none(self) -> None:
  239. """Some fields in a querystring can possibly have a value of None - for
  240. example, the string "foo&bar=&baz=asdf" will have a field with the
  241. name "foo" and value None, one with name "bar" and value "", and one
  242. with name "baz" and value "asdf". Since the write() interface doesn't
  243. support writing None, this function will set the field value to None.
  244. """
  245. self._cache = None
  246. @property
  247. def field_name(self) -> bytes | None:
  248. """This property returns the name of the field."""
  249. return self._name
  250. @property
  251. def value(self) -> bytes | None:
  252. """This property returns the value of the form field."""
  253. if self._cache is _missing:
  254. self._cache = b"".join(self._value)
  255. assert isinstance(self._cache, bytes) or self._cache is None
  256. return self._cache
  257. def __eq__(self, other: object) -> bool:
  258. if isinstance(other, Field):
  259. return self.field_name == other.field_name and self.value == other.value
  260. else:
  261. return NotImplemented
  262. def __repr__(self) -> str:
  263. if self.value is not None and len(self.value) > 97:
  264. # We get the repr, and then insert three dots before the final
  265. # quote.
  266. v = repr(self.value[:97])[:-1] + "...'"
  267. else:
  268. v = repr(self.value)
  269. return "{}(field_name={!r}, value={})".format(self.__class__.__name__, self.field_name, v)
  270. class File:
  271. """This class represents an uploaded file. It handles writing file data to
  272. either an in-memory file or a temporary file on-disk, if the optional
  273. threshold is passed.
  274. There are some options that can be passed to the File to change behavior
  275. of the class. Valid options are as follows:
  276. | Name | Type | Default | Description |
  277. |-----------------------|-------|---------|-------------|
  278. | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
  279. | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
  280. | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
  281. | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
  282. | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. |
  283. Args:
  284. file_name: The name of the file that this [`File`][python_multipart.File] represents.
  285. field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
  286. the file was uploaded with Content-Type application/octet-stream.
  287. config: The configuration for this File. See above for valid configuration keys and their corresponding values.
  288. """ # noqa: E501
  289. def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
  290. # Save configuration, set other variables default.
  291. self.logger = logging.getLogger(__name__)
  292. self._config = config
  293. self._in_memory = True
  294. self._bytes_written = 0
  295. self._fileobj: BytesIO | BufferedRandom = BytesIO()
  296. # Save the provided field/file name.
  297. self._field_name = field_name
  298. self._file_name = file_name
  299. # Our actual file name is None by default, since, depending on our
  300. # config, we may not actually use the provided name.
  301. self._actual_file_name: bytes | None = None
  302. # Split the extension from the filename.
  303. if file_name is not None:
  304. base, ext = os.path.splitext(file_name)
  305. self._file_base = base
  306. self._ext = ext
  307. @property
  308. def field_name(self) -> bytes | None:
  309. """The form field associated with this file. May be None if there isn't
  310. one, for example when we have an application/octet-stream upload.
  311. """
  312. return self._field_name
  313. @property
  314. def file_name(self) -> bytes | None:
  315. """The file name given in the upload request."""
  316. return self._file_name
  317. @property
  318. def actual_file_name(self) -> bytes | None:
  319. """The file name that this file is saved as. Will be None if it's not
  320. currently saved on disk.
  321. """
  322. return self._actual_file_name
  323. @property
  324. def file_object(self) -> BytesIO | BufferedRandom:
  325. """The file object that we're currently writing to. Note that this
  326. will either be an instance of a :class:`io.BytesIO`, or a regular file
  327. object.
  328. """
  329. return self._fileobj
  330. @property
  331. def size(self) -> int:
  332. """The total size of this file, counted as the number of bytes that
  333. currently have been written to the file.
  334. """
  335. return self._bytes_written
  336. @property
  337. def in_memory(self) -> bool:
  338. """A boolean representing whether or not this file object is currently
  339. stored in-memory or on-disk.
  340. """
  341. return self._in_memory
  342. def flush_to_disk(self) -> None:
  343. """If the file is already on-disk, do nothing. Otherwise, copy from
  344. the in-memory buffer to a disk file, and then reassign our internal
  345. file object to this new disk file.
  346. Note that if you attempt to flush a file that is already on-disk, a
  347. warning will be logged to this module's logger.
  348. """
  349. if not self._in_memory:
  350. self.logger.warning("Trying to flush to disk when we're not in memory")
  351. return
  352. # Go back to the start of our file.
  353. self._fileobj.seek(0)
  354. # Open a new file.
  355. new_file = self._get_disk_file()
  356. # Copy the file objects.
  357. shutil.copyfileobj(self._fileobj, new_file)
  358. # Seek to the new position in our new file.
  359. new_file.seek(self._bytes_written)
  360. # Reassign the fileobject.
  361. old_fileobj = self._fileobj
  362. self._fileobj = new_file
  363. # We're no longer in memory.
  364. self._in_memory = False
  365. # Close the old file object.
  366. old_fileobj.close()
  367. def _get_disk_file(self) -> BufferedRandom:
  368. """This function is responsible for getting a file object on-disk for us."""
  369. self.logger.info("Opening a file on disk")
  370. file_dir = self._config.get("UPLOAD_DIR")
  371. keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
  372. keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
  373. delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
  374. tmp_file: None | BufferedRandom = None
  375. # If we have a directory and are to keep the filename...
  376. if file_dir is not None and keep_filename:
  377. self.logger.info("Saving with filename in: %r", file_dir)
  378. # Build our filename.
  379. # TODO: what happens if we don't have a filename?
  380. fname = self._file_base + self._ext if keep_extensions else self._file_base
  381. path = os.path.join(file_dir, fname) # type: ignore[arg-type]
  382. try:
  383. self.logger.info("Opening file: %r", path)
  384. tmp_file = open(path, "w+b")
  385. except OSError:
  386. tmp_file = None
  387. self.logger.exception("Error opening temporary file")
  388. raise FileError("Error opening temporary file: %r" % path)
  389. else:
  390. # Build options array.
  391. # Note that on Python 3, tempfile doesn't support byte names. We
  392. # encode our paths using the default filesystem encoding.
  393. suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None
  394. if file_dir is None:
  395. dir = None
  396. elif isinstance(file_dir, bytes):
  397. dir = file_dir.decode(sys.getfilesystemencoding())
  398. else:
  399. dir = file_dir # pragma: no cover
  400. # Create a temporary (named) file with the appropriate settings.
  401. self.logger.info(
  402. "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
  403. )
  404. try:
  405. tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir))
  406. except OSError:
  407. self.logger.exception("Error creating named temporary file")
  408. raise FileError("Error creating named temporary file")
  409. assert tmp_file is not None
  410. # Encode filename as bytes.
  411. if isinstance(tmp_file.name, str):
  412. fname = tmp_file.name.encode(sys.getfilesystemencoding())
  413. else:
  414. fname = cast(bytes, tmp_file.name) # pragma: no cover
  415. self._actual_file_name = fname
  416. return tmp_file
  417. def write(self, data: bytes) -> int:
  418. """Write some data to the File.
  419. :param data: a bytestring
  420. """
  421. return self.on_data(data)
  422. def on_data(self, data: bytes) -> int:
  423. """This method is a callback that will be called whenever data is
  424. written to the File.
  425. Args:
  426. data: The data to write to the file.
  427. Returns:
  428. The number of bytes written.
  429. """
  430. bwritten = self._fileobj.write(data)
  431. # If the bytes written isn't the same as the length, just return.
  432. if bwritten != len(data):
  433. self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
  434. return bwritten
  435. # Keep track of how many bytes we've written.
  436. self._bytes_written += bwritten
  437. # If we're in-memory and are over our limit, we create a file.
  438. max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
  439. if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
  440. self.logger.info("Flushing to disk")
  441. self.flush_to_disk()
  442. # Return the number of bytes written.
  443. return bwritten
  444. def on_end(self) -> None:
  445. """This method is called whenever the Field is finalized."""
  446. # Flush the underlying file object
  447. self._fileobj.flush()
  448. def finalize(self) -> None:
  449. """Finalize the form file. This will not close the underlying file,
  450. but simply signal that we are finished writing to the File.
  451. """
  452. self.on_end()
  453. def close(self) -> None:
  454. """Close the File object. This will actually close the underlying
  455. file object (whether it's a :class:`io.BytesIO` or an actual file
  456. object).
  457. """
  458. self._fileobj.close()
  459. def __repr__(self) -> str:
  460. return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name)
  461. class BaseParser:
  462. """This class is the base class for all parsers. It contains the logic for
  463. calling and adding callbacks.
  464. A callback can be one of two different forms. "Notification callbacks" are
  465. callbacks that are called when something happens - for example, when a new
  466. part of a multipart message is encountered by the parser. "Data callbacks"
  467. are called when we get some sort of data - for example, part of the body of
  468. a multipart chunk. Notification callbacks are called with no parameters,
  469. whereas data callbacks are called with three, as follows::
  470. data_callback(data, start, end)
  471. The "data" parameter is a bytestring (i.e. "foo" on Python 2, or b"foo" on
  472. Python 3). "start" and "end" are integer indexes into the "data" string
  473. that represent the data of interest. Thus, in a data callback, the slice
  474. `data[start:end]` represents the data that the callback is "interested in".
  475. The callback is not passed a copy of the data, since copying severely hurts
  476. performance.
  477. """
  478. def __init__(self) -> None:
  479. self.logger = logging.getLogger(__name__)
  480. self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}
  481. def callback(
  482. self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
  483. ) -> None:
  484. """This function calls a provided callback with some data. If the
  485. callback is not set, will do nothing.
  486. Args:
  487. name: The name of the callback to call (as a string).
  488. data: Data to pass to the callback. If None, then it is assumed that the callback is a notification
  489. callback, and no parameters are given.
  490. end: An integer that is passed to the data callback.
  491. start: An integer that is passed to the data callback.
  492. """
  493. on_name = "on_" + name
  494. func = self.callbacks.get(on_name)
  495. if func is None:
  496. return
  497. func = cast("Callable[..., Any]", func)
  498. # Depending on whether we're given a buffer...
  499. if data is not None:
  500. # Don't do anything if we have start == end.
  501. if start is not None and start == end:
  502. return
  503. self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end)
  504. func(data, start, end)
  505. else:
  506. self.logger.debug("Calling %s with no data", on_name)
  507. func()
  508. def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
  509. """Update the function for a callback. Removes from the callbacks dict
  510. if new_func is None.
  511. :param name: The name of the callback to call (as a string).
  512. :param new_func: The new function for the callback. If None, then the
  513. callback will be removed (with no error if it does not
  514. exist).
  515. """
  516. if new_func is None:
  517. self.callbacks.pop("on_" + name, None) # type: ignore[misc]
  518. else:
  519. self.callbacks["on_" + name] = new_func # type: ignore[literal-required]
  520. def close(self) -> None:
  521. pass # pragma: no cover
  522. def finalize(self) -> None:
  523. pass # pragma: no cover
  524. def __repr__(self) -> str:
  525. return "%s()" % self.__class__.__name__
  526. class OctetStreamParser(BaseParser):
  527. """This parser parses an octet-stream request body and calls callbacks when
  528. incoming data is received. Callbacks are as follows:
  529. | Callback Name | Parameters | Description |
  530. |----------------|-----------------|-----------------------------------------------------|
  531. | on_start | None | Called when the first data is parsed. |
  532. | on_data | data, start, end| Called for each data chunk that is parsed. |
  533. | on_end | None | Called when the parser is finished parsing all data.|
  534. Args:
  535. callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
  536. max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
  537. """
  538. def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")):
  539. super().__init__()
  540. self.callbacks = callbacks
  541. self._started = False
  542. if not isinstance(max_size, Number) or max_size < 1:
  543. raise ValueError("max_size must be a positive number, not %r" % max_size)
  544. self.max_size: int | float = max_size
  545. self._current_size = 0
  546. def write(self, data: bytes) -> int:
  547. """Write some data to the parser, which will perform size verification,
  548. and then pass the data to the underlying callback.
  549. Args:
  550. data: The data to write to the parser.
  551. Returns:
  552. The number of bytes written.
  553. """
  554. if not self._started:
  555. self.callback("start")
  556. self._started = True
  557. # Truncate data length.
  558. data_len = len(data)
  559. if (self._current_size + data_len) > self.max_size:
  560. # We truncate the length of data that we are to process.
  561. new_size = int(self.max_size - self._current_size)
  562. self.logger.warning(
  563. "Current size is %d (max %d), so truncating data length from %d to %d",
  564. self._current_size,
  565. self.max_size,
  566. data_len,
  567. new_size,
  568. )
  569. data_len = new_size
  570. # Increment size, then callback, in case there's an exception.
  571. self._current_size += data_len
  572. self.callback("data", data, 0, data_len)
  573. return data_len
  574. def finalize(self) -> None:
  575. """Finalize this parser, which signals to that we are finished parsing,
  576. and sends the on_end callback.
  577. """
  578. self.callback("end")
  579. def __repr__(self) -> str:
  580. return "%s()" % self.__class__.__name__
  581. class QuerystringParser(BaseParser):
  582. """This is a streaming querystring parser. It will consume data, and call
  583. the callbacks given when it has data.
  584. | Callback Name | Parameters | Description |
  585. |----------------|-----------------|-----------------------------------------------------|
  586. | on_field_start | None | Called when a new field is encountered. |
  587. | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
  588. | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
  589. | on_field_end | None | Called when the end of a field is encountered. |
  590. | on_end | None | Called when the parser is finished parsing all data.|
  591. Args:
  592. callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
  593. strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
  594. behavior of the parser changes as the following: if a field has a value with an equal sign
  595. (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
  596. it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
  597. then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
  598. max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
  599. """ # noqa: E501
  600. state: QuerystringState
  601. def __init__(
  602. self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf")
  603. ) -> None:
  604. super().__init__()
  605. self.state = QuerystringState.BEFORE_FIELD
  606. self._found_sep = False
  607. self.callbacks = callbacks
  608. # Max-size stuff
  609. if not isinstance(max_size, Number) or max_size < 1:
  610. raise ValueError("max_size must be a positive number, not %r" % max_size)
  611. self.max_size: int | float = max_size
  612. self._current_size = 0
  613. # Should parsing be strict?
  614. self.strict_parsing = strict_parsing
  615. def write(self, data: bytes) -> int:
  616. """Write some data to the parser, which will perform size verification,
  617. parse into either a field name or value, and then pass the
  618. corresponding data to the underlying callback. If an error is
  619. encountered while parsing, a QuerystringParseError will be raised. The
  620. "offset" attribute of the raised exception will be set to the offset in
  621. the input data chunk (NOT the overall stream) that caused the error.
  622. Args:
  623. data: The data to write to the parser.
  624. Returns:
  625. The number of bytes written.
  626. """
  627. # Handle sizing.
  628. data_len = len(data)
  629. if (self._current_size + data_len) > self.max_size:
  630. # We truncate the length of data that we are to process.
  631. new_size = int(self.max_size - self._current_size)
  632. self.logger.warning(
  633. "Current size is %d (max %d), so truncating data length from %d to %d",
  634. self._current_size,
  635. self.max_size,
  636. data_len,
  637. new_size,
  638. )
  639. data_len = new_size
  640. l = 0
  641. try:
  642. l = self._internal_write(data, data_len)
  643. finally:
  644. self._current_size += l
  645. return l
  646. def _internal_write(self, data: bytes, length: int) -> int:
  647. state = self.state
  648. strict_parsing = self.strict_parsing
  649. found_sep = self._found_sep
  650. i = 0
  651. while i < length:
  652. ch = data[i]
  653. # Depending on our state...
  654. if state == QuerystringState.BEFORE_FIELD:
  655. # If the 'found_sep' flag is set, we've already encountered
  656. # and skipped a single separator. If so, we check our strict
  657. # parsing flag and decide what to do. Otherwise, we haven't
  658. # yet reached a separator, and thus, if we do, we need to skip
  659. # it as it will be the boundary between fields that's supposed
  660. # to be there.
  661. if ch == AMPERSAND or ch == SEMICOLON:
  662. if found_sep:
  663. # If we're parsing strictly, we disallow blank chunks.
  664. if strict_parsing:
  665. e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i)
  666. e.offset = i
  667. raise e
  668. else:
  669. self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
  670. else:
  671. # This case is when we're skipping the (first)
  672. # separator between fields, so we just set our flag
  673. # and continue on.
  674. found_sep = True
  675. else:
  676. # Emit a field-start event, and go to that state. Also,
  677. # reset the "found_sep" flag, for the next time we get to
  678. # this state.
  679. self.callback("field_start")
  680. i -= 1
  681. state = QuerystringState.FIELD_NAME
  682. found_sep = False
  683. elif state == QuerystringState.FIELD_NAME:
  684. # Try and find a separator - we ensure that, if we do, we only
  685. # look for the equal sign before it.
  686. sep_pos = data.find(b"&", i)
  687. if sep_pos == -1:
  688. sep_pos = data.find(b";", i)
  689. # See if we can find an equals sign in the remaining data. If
  690. # so, we can immediately emit the field name and jump to the
  691. # data state.
  692. if sep_pos != -1:
  693. equals_pos = data.find(b"=", i, sep_pos)
  694. else:
  695. equals_pos = data.find(b"=", i)
  696. if equals_pos != -1:
  697. # Emit this name.
  698. self.callback("field_name", data, i, equals_pos)
  699. # Jump i to this position. Note that it will then have 1
  700. # added to it below, which means the next iteration of this
  701. # loop will inspect the character after the equals sign.
  702. i = equals_pos
  703. state = QuerystringState.FIELD_DATA
  704. else:
  705. # No equals sign found.
  706. if not strict_parsing:
  707. # See also comments in the QuerystringState.FIELD_DATA case below.
  708. # If we found the separator, we emit the name and just
  709. # end - there's no data callback at all (not even with
  710. # a blank value).
  711. if sep_pos != -1:
  712. self.callback("field_name", data, i, sep_pos)
  713. self.callback("field_end")
  714. i = sep_pos - 1
  715. state = QuerystringState.BEFORE_FIELD
  716. else:
  717. # Otherwise, no separator in this block, so the
  718. # rest of this chunk must be a name.
  719. self.callback("field_name", data, i, length)
  720. i = length
  721. else:
  722. # We're parsing strictly. If we find a separator,
  723. # this is an error - we require an equals sign.
  724. if sep_pos != -1:
  725. e = QuerystringParseError(
  726. "When strict_parsing is True, we require an "
  727. "equals sign in all field chunks. Did not "
  728. "find one in the chunk that starts at %d" % (i,)
  729. )
  730. e.offset = i
  731. raise e
  732. # No separator in the rest of this chunk, so it's just
  733. # a field name.
  734. self.callback("field_name", data, i, length)
  735. i = length
  736. elif state == QuerystringState.FIELD_DATA:
  737. # Try finding either an ampersand or a semicolon after this
  738. # position.
  739. sep_pos = data.find(b"&", i)
  740. if sep_pos == -1:
  741. sep_pos = data.find(b";", i)
  742. # If we found it, callback this bit as data and then go back
  743. # to expecting to find a field.
  744. if sep_pos != -1:
  745. self.callback("field_data", data, i, sep_pos)
  746. self.callback("field_end")
  747. # Note that we go to the separator, which brings us to the
  748. # "before field" state. This allows us to properly emit
  749. # "field_start" events only when we actually have data for
  750. # a field of some sort.
  751. i = sep_pos - 1
  752. state = QuerystringState.BEFORE_FIELD
  753. # Otherwise, emit the rest as data and finish.
  754. else:
  755. self.callback("field_data", data, i, length)
  756. i = length
  757. else: # pragma: no cover (error case)
  758. msg = "Reached an unknown state %d at %d" % (state, i)
  759. self.logger.warning(msg)
  760. e = QuerystringParseError(msg)
  761. e.offset = i
  762. raise e
  763. i += 1
  764. self.state = state
  765. self._found_sep = found_sep
  766. return len(data)
  767. def finalize(self) -> None:
  768. """Finalize this parser, which signals to that we are finished parsing,
  769. if we're still in the middle of a field, an on_field_end callback, and
  770. then the on_end callback.
  771. """
  772. # If we're currently in the middle of a field, we finish it.
  773. if self.state == QuerystringState.FIELD_DATA:
  774. self.callback("field_end")
  775. self.callback("end")
  776. def __repr__(self) -> str:
  777. return "{}(strict_parsing={!r}, max_size={!r})".format(
  778. self.__class__.__name__, self.strict_parsing, self.max_size
  779. )
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name      | Parameters      | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin      | None            | Called when a new part of the multipart message is encountered. |
    | on_part_data       | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end        | None            | Called when the end of a part is reached. |
    | on_header_begin    | None            | Called when we've found a new header in a part of a multipart message |
    | on_header_field    | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value    | data, start, end| Called when we get data for a header. |
    | on_header_end      | None            | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None            | Called when all headers are finished, and before the part data starts. |
    | on_end             | None            | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    def __init__(
        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
    ) -> None:
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        self.callbacks = callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Setup marks. These are used to track the state of data received.
        # A negative mark means the marked run began inside the boundary
        # lookbehind of a previous chunk (see data_callback below).
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        # The stored boundary includes the leading CRLF + "--" so a single
        # search matches the delimiter exactly as it appears mid-stream.
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            # Account for the processed bytes even if an exception escaped.
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the multipart state machine over ``data[:length]``.

        Returns ``length`` on success. Raises MultipartParseError (with an
        ``offset`` attribute set to the failing byte) on malformed input.
        """
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags

        # Our index defaults to 0.
        i = 0

        # Set a mark.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        def delete_mark(name: str, reset: bool = False) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it. This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= len(boundary):
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else:  # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)

            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    i += 1
                    continue

                # index is used as in index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1
            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    index += 1
                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START
                else:
                    # Check to ensure our boundary matches
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Increment index into boundary and continue.
                    index += 1
            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1
            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Call our callback with the header field.
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START
                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e
            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1
            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers. Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value", i)
                    self.callback("header_end")
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE
            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = "Did not find LF character at end of header " "(found %r)" % (c,)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START
            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START
            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1
            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                data_length = length

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, data_length)
                    if i0 >= 0:
                        # We matched the whole boundary string.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                    else:
                        # No match found for whole string.
                        # There may be a partial boundary at the end of the
                        # data, which the find will not match.
                        # Since the length should to be searched is limited to
                        # the boundary length, just perform a naive search.
                        i = max(i, data_length - boundary_length)

                        # Search forward until we either hit the end of our buffer,
                        # or reach a potential start of the boundary.
                        while i < data_length - 1 and data[i] != boundary[0]:
                            i += 1

                    c = data[i]

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                if index == 0 and prev_index > 0:
                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1
            elif state == MultipartState.END_BOUNDARY:
                if index == len(boundary) - 2 + 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e
                    index += 1
                    self.callback("end")
                    state = MultipartState.END
            elif state == MultipartState.END:
                # Don't do anything if chunk ends with CRLF.
                if c == CR and i + 1 < length and data[i + 1] == LF:
                    i += 2
                    continue
                # Skip data after the last boundary.
                self.logger.warning("Skipping data after last boundary")
                i = length
                break
            else:  # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = MultipartParseError(msg)
                e.offset = i
                raise e

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which sets the mark back to 0 instead of
        # deleting it, if it's found. This is because, if the mark is found
        # at this point, we assume that there's data for one of these things
        # that has been parsed, but not yet emitted. And, as such, it implies
        # that we haven't yet reached the end of this 'thing'. So, by setting
        # the mark to 0, we cause any data callbacks that take place in future
        # calls to this function to start from the beginning of that buffer.
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
  1214. class FormParser:
  1215. """This class is the all-in-one form parser. Given all the information
  1216. necessary to parse a form, it will instantiate the correct parser, create
  1217. the proper :class:`Field` and :class:`File` classes to store the data that
  1218. is parsed, and call the two given callbacks with each field and file as
  1219. they become available.
  1220. Args:
  1221. content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
  1222. on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
  1223. on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
  1224. on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None.
  1225. boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
  1226. in the Content-Type header, as a bytestring.
  1227. file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
  1228. information about the uploaded file. In such cases, you can provide the file name of the uploaded file
  1229. manually.
  1230. FileClass: The class to use for uploaded files. Defaults to :class:`File`, but you can provide your own class
  1231. if you wish to customize behaviour. The class will be instantiated as FileClass(file_name, field_name), and
  1232. it must provide the following functions::
  1233. - file_instance.write(data)
  1234. - file_instance.finalize()
  1235. - file_instance.close()
  1236. FieldClass: The class to use for uploaded fields. Defaults to :class:`Field`, but you can provide your own
  1237. class if you wish to customize behaviour. The class will be instantiated as FieldClass(field_name), and it
  1238. must provide the following functions::
  1239. - field_instance.write(data)
  1240. - field_instance.finalize()
  1241. - field_instance.close()
  1242. - field_instance.set_none()
  1243. config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
  1244. and then any keys present in this dictionary will overwrite the default values.
  1245. """
  1246. #: This is the default configuration for our form parser.
  1247. #: Note: all file sizes should be in bytes.
  1248. DEFAULT_CONFIG: FormParserConfig = {
  1249. "MAX_BODY_SIZE": float("inf"),
  1250. "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
  1251. "UPLOAD_DIR": None,
  1252. "UPLOAD_KEEP_FILENAME": False,
  1253. "UPLOAD_KEEP_EXTENSIONS": False,
  1254. # Error on invalid Content-Transfer-Encoding?
  1255. "UPLOAD_ERROR_ON_BAD_CTE": False,
  1256. }
  1257. def __init__(
  1258. self,
  1259. content_type: str,
  1260. on_field: OnFieldCallback | None,
  1261. on_file: OnFileCallback | None,
  1262. on_end: Callable[[], None] | None = None,
  1263. boundary: bytes | str | None = None,
  1264. file_name: bytes | None = None,
  1265. FileClass: type[FileProtocol] = File,
  1266. FieldClass: type[FieldProtocol] = Field,
  1267. config: dict[Any, Any] = {},
  1268. ) -> None:
  1269. self.logger = logging.getLogger(__name__)
  1270. # Save variables.
  1271. self.content_type = content_type
  1272. self.boundary = boundary
  1273. self.bytes_received = 0
  1274. self.parser = None
  1275. # Save callbacks.
  1276. self.on_field = on_field
  1277. self.on_file = on_file
  1278. self.on_end = on_end
  1279. # Save classes.
  1280. self.FileClass = File
  1281. self.FieldClass = Field
  1282. # Set configuration options.
  1283. self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
  1284. self.config.update(config) # type: ignore[typeddict-item]
  1285. parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None
  1286. # Depending on the Content-Type, we instantiate the correct parser.
  1287. if content_type == "application/octet-stream":
  1288. file: FileProtocol = None # type: ignore
  1289. def on_start() -> None:
  1290. nonlocal file
  1291. file = FileClass(file_name, None, config=cast("FileConfig", self.config))
  1292. def on_data(data: bytes, start: int, end: int) -> None:
  1293. nonlocal file
  1294. file.write(data[start:end])
  1295. def _on_end() -> None:
  1296. nonlocal file
  1297. # Finalize the file itself.
  1298. file.finalize()
  1299. # Call our callback.
  1300. if on_file:
  1301. on_file(file)
  1302. # Call the on-end callback.
  1303. if self.on_end is not None:
  1304. self.on_end()
  1305. # Instantiate an octet-stream parser
  1306. parser = OctetStreamParser(
  1307. callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
  1308. max_size=self.config["MAX_BODY_SIZE"],
  1309. )
  1310. elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
  1311. name_buffer: list[bytes] = []
  1312. f: FieldProtocol | None = None
  1313. def on_field_start() -> None:
  1314. pass
  1315. def on_field_name(data: bytes, start: int, end: int) -> None:
  1316. name_buffer.append(data[start:end])
  1317. def on_field_data(data: bytes, start: int, end: int) -> None:
  1318. nonlocal f
  1319. if f is None:
  1320. f = FieldClass(b"".join(name_buffer))
  1321. del name_buffer[:]
  1322. f.write(data[start:end])
  1323. def on_field_end() -> None:
  1324. nonlocal f
  1325. # Finalize and call callback.
  1326. if f is None:
  1327. # If we get here, it's because there was no field data.
  1328. # We create a field, set it to None, and then continue.
  1329. f = FieldClass(b"".join(name_buffer))
  1330. del name_buffer[:]
  1331. f.set_none()
  1332. f.finalize()
  1333. if on_field:
  1334. on_field(f)
  1335. f = None
  1336. def _on_end() -> None:
  1337. if self.on_end is not None:
  1338. self.on_end()
  1339. # Instantiate parser.
  1340. parser = QuerystringParser(
  1341. callbacks={
  1342. "on_field_start": on_field_start,
  1343. "on_field_name": on_field_name,
  1344. "on_field_data": on_field_data,
  1345. "on_field_end": on_field_end,
  1346. "on_end": _on_end,
  1347. },
  1348. max_size=self.config["MAX_BODY_SIZE"],
  1349. )
  1350. elif content_type == "multipart/form-data":
  1351. if boundary is None:
  1352. self.logger.error("No boundary given")
  1353. raise FormParserError("No boundary given")
  1354. header_name: list[bytes] = []
  1355. header_value: list[bytes] = []
  1356. headers: dict[bytes, bytes] = {}
  1357. f_multi: FileProtocol | FieldProtocol | None = None
  1358. writer = None
  1359. is_file = False
  1360. def on_part_begin() -> None:
  1361. # Reset headers in case this isn't the first part.
  1362. nonlocal headers
  1363. headers = {}
  1364. def on_part_data(data: bytes, start: int, end: int) -> None:
  1365. nonlocal writer
  1366. assert writer is not None
  1367. writer.write(data[start:end])
  1368. # TODO: check for error here.
  1369. def on_part_end() -> None:
  1370. nonlocal f_multi, is_file
  1371. assert f_multi is not None
  1372. f_multi.finalize()
  1373. if is_file:
  1374. if on_file:
  1375. on_file(f_multi)
  1376. else:
  1377. if on_field:
  1378. on_field(cast("FieldProtocol", f_multi))
  1379. def on_header_field(data: bytes, start: int, end: int) -> None:
  1380. header_name.append(data[start:end])
  1381. def on_header_value(data: bytes, start: int, end: int) -> None:
  1382. header_value.append(data[start:end])
  1383. def on_header_end() -> None:
  1384. headers[b"".join(header_name)] = b"".join(header_value)
  1385. del header_name[:]
  1386. del header_value[:]
  1387. def on_headers_finished() -> None:
  1388. nonlocal is_file, f_multi, writer
  1389. # Reset the 'is file' flag.
  1390. is_file = False
  1391. # Parse the content-disposition header.
  1392. # TODO: handle mixed case
  1393. content_disp = headers.get(b"Content-Disposition")
  1394. disp, options = parse_options_header(content_disp)
  1395. # Get the field and filename.
  1396. field_name = options.get(b"name")
  1397. file_name = options.get(b"filename")
  1398. # TODO: check for errors
  1399. # Create the proper class.
  1400. if file_name is None:
  1401. f_multi = FieldClass(field_name)
  1402. else:
  1403. f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
  1404. is_file = True
  1405. # Parse the given Content-Transfer-Encoding to determine what
  1406. # we need to do with the incoming data.
  1407. # TODO: check that we properly handle 8bit / 7bit encoding.
  1408. transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")
  1409. if transfer_encoding in (b"binary", b"8bit", b"7bit"):
  1410. writer = f_multi
  1411. elif transfer_encoding == b"base64":
  1412. writer = Base64Decoder(f_multi)
  1413. elif transfer_encoding == b"quoted-printable":
  1414. writer = QuotedPrintableDecoder(f_multi)
  1415. else:
  1416. self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
  1417. if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
  1418. raise FormParserError('Unknown Content-Transfer-Encoding "{!r}"'.format(transfer_encoding))
  1419. else:
  1420. # If we aren't erroring, then we just treat this as an
  1421. # unencoded Content-Transfer-Encoding.
  1422. writer = f_multi
  1423. def _on_end() -> None:
  1424. nonlocal writer
  1425. if writer is not None:
  1426. writer.finalize()
  1427. if self.on_end is not None:
  1428. self.on_end()
  1429. # Instantiate a multipart parser.
  1430. parser = MultipartParser(
  1431. boundary,
  1432. callbacks={
  1433. "on_part_begin": on_part_begin,
  1434. "on_part_data": on_part_data,
  1435. "on_part_end": on_part_end,
  1436. "on_header_field": on_header_field,
  1437. "on_header_value": on_header_value,
  1438. "on_header_end": on_header_end,
  1439. "on_headers_finished": on_headers_finished,
  1440. "on_end": _on_end,
  1441. },
  1442. max_size=self.config["MAX_BODY_SIZE"],
  1443. )
  1444. else:
  1445. self.logger.warning("Unknown Content-Type: %r", content_type)
  1446. raise FormParserError("Unknown Content-Type: {}".format(content_type))
  1447. self.parser = parser
  1448. def write(self, data: bytes) -> int:
  1449. """Write some data. The parser will forward this to the appropriate
  1450. underlying parser.
  1451. Args:
  1452. data: The data to write.
  1453. Returns:
  1454. The number of bytes processed.
  1455. """
  1456. self.bytes_received += len(data)
  1457. # TODO: check the parser's return value for errors?
  1458. assert self.parser is not None
  1459. return self.parser.write(data)
  1460. def finalize(self) -> None:
  1461. """Finalize the parser."""
  1462. if self.parser is not None and hasattr(self.parser, "finalize"):
  1463. self.parser.finalize()
  1464. def close(self) -> None:
  1465. """Close the parser."""
  1466. if self.parser is not None and hasattr(self.parser, "close"):
  1467. self.parser.close()
  1468. def __repr__(self) -> str:
  1469. return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)
  1470. def create_form_parser(
  1471. headers: dict[str, bytes],
  1472. on_field: OnFieldCallback | None,
  1473. on_file: OnFileCallback | None,
  1474. trust_x_headers: bool = False,
  1475. config: dict[Any, Any] = {},
  1476. ) -> FormParser:
  1477. """This function is a helper function to aid in creating a FormParser
  1478. instances. Given a dictionary-like headers object, it will determine
  1479. the correct information needed, instantiate a FormParser with the
  1480. appropriate values and given callbacks, and then return the corresponding
  1481. parser.
  1482. Args:
  1483. headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
  1484. on_field: Callback to call with each parsed field.
  1485. on_file: Callback to call with each parsed file.
  1486. trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file
  1487. name from X-File-Name.
  1488. config: Configuration variables to pass to the FormParser.
  1489. """
  1490. content_type: str | bytes | None = headers.get("Content-Type")
  1491. if content_type is None:
  1492. logging.getLogger(__name__).warning("No Content-Type header given")
  1493. raise ValueError("No Content-Type header given!")
  1494. # Boundaries are optional (the FormParser will raise if one is needed
  1495. # but not given).
  1496. content_type, params = parse_options_header(content_type)
  1497. boundary = params.get(b"boundary")
  1498. # We need content_type to be a string, not a bytes object.
  1499. content_type = content_type.decode("latin-1")
  1500. # File names are optional.
  1501. file_name = headers.get("X-File-Name")
  1502. # Instantiate a form parser.
  1503. form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config)
  1504. # Return our parser.
  1505. return form_parser
  1506. def parse_form(
  1507. headers: dict[str, bytes],
  1508. input_stream: SupportsRead,
  1509. on_field: OnFieldCallback | None,
  1510. on_file: OnFileCallback | None,
  1511. chunk_size: int = 1048576,
  1512. ) -> None:
  1513. """This function is useful if you just want to parse a request body,
  1514. without too much work. Pass it a dictionary-like object of the request's
  1515. headers, and a file-like object for the input stream, along with two
  1516. callbacks that will get called whenever a field or file is parsed.
  1517. Args:
  1518. headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
  1519. input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
  1520. on_field: Callback to call with each parsed field.
  1521. on_file: Callback to call with each parsed file.
  1522. chunk_size: The maximum size to read from the input stream and write to the parser at one time.
  1523. Defaults to 1 MiB.
  1524. """
  1525. # Create our form parser.
  1526. parser = create_form_parser(headers, on_field, on_file)
  1527. # Read chunks of 1MiB and write to the parser, but never read more than
  1528. # the given Content-Length, if any.
  1529. content_length: int | float | bytes | None = headers.get("Content-Length")
  1530. if content_length is not None:
  1531. content_length = int(content_length)
  1532. else:
  1533. content_length = float("inf")
  1534. bytes_read = 0
  1535. while True:
  1536. # Read only up to the Content-Length given.
  1537. max_readable = int(min(content_length - bytes_read, chunk_size))
  1538. buff = input_stream.read(max_readable)
  1539. # Write to the parser and update our length.
  1540. parser.write(buff)
  1541. bytes_read += len(buff)
  1542. # If we get a buffer that's smaller than the size requested, or if we
  1543. # have read up to our content length, we're done.
  1544. if len(buff) != max_readable or bytes_read == content_length:
  1545. break
  1546. # Tell our parser that we're done writing data.
  1547. parser.finalize()