You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

221 rivejä
7.8 KiB

  1. import re
  2. import warnings
  3. from dataclasses import is_dataclass
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Dict,
  8. MutableMapping,
  9. Optional,
  10. Set,
  11. Type,
  12. Union,
  13. cast,
  14. )
  15. from weakref import WeakKeyDictionary
  16. import fastapi
  17. from fastapi._compat import (
  18. PYDANTIC_V2,
  19. BaseConfig,
  20. ModelField,
  21. PydanticSchemaGenerationError,
  22. Undefined,
  23. UndefinedType,
  24. Validator,
  25. lenient_issubclass,
  26. )
  27. from fastapi.datastructures import DefaultPlaceholder, DefaultType
  28. from pydantic import BaseModel, create_model
  29. from pydantic.fields import FieldInfo
  30. from typing_extensions import Literal
  31. if TYPE_CHECKING: # pragma: nocover
  32. from .routing import APIRoute
  33. # Cache for `create_cloned_field`
  34. _CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
  35. WeakKeyDictionary()
  36. )
  37. def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
  38. if status_code is None:
  39. return True
  40. # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
  41. if status_code in {
  42. "default",
  43. "1XX",
  44. "2XX",
  45. "3XX",
  46. "4XX",
  47. "5XX",
  48. }:
  49. return True
  50. current_status_code = int(status_code)
  51. return not (current_status_code < 200 or current_status_code in {204, 205, 304})
  52. def get_path_param_names(path: str) -> Set[str]:
  53. return set(re.findall("{(.*?)}", path))
  54. def create_model_field(
  55. name: str,
  56. type_: Any,
  57. class_validators: Optional[Dict[str, Validator]] = None,
  58. default: Optional[Any] = Undefined,
  59. required: Union[bool, UndefinedType] = Undefined,
  60. model_config: Type[BaseConfig] = BaseConfig,
  61. field_info: Optional[FieldInfo] = None,
  62. alias: Optional[str] = None,
  63. mode: Literal["validation", "serialization"] = "validation",
  64. ) -> ModelField:
  65. class_validators = class_validators or {}
  66. if PYDANTIC_V2:
  67. field_info = field_info or FieldInfo(
  68. annotation=type_, default=default, alias=alias
  69. )
  70. else:
  71. field_info = field_info or FieldInfo()
  72. kwargs = {"name": name, "field_info": field_info}
  73. if PYDANTIC_V2:
  74. kwargs.update({"mode": mode})
  75. else:
  76. kwargs.update(
  77. {
  78. "type_": type_,
  79. "class_validators": class_validators,
  80. "default": default,
  81. "required": required,
  82. "model_config": model_config,
  83. "alias": alias,
  84. }
  85. )
  86. try:
  87. return ModelField(**kwargs) # type: ignore[arg-type]
  88. except (RuntimeError, PydanticSchemaGenerationError):
  89. raise fastapi.exceptions.FastAPIError(
  90. "Invalid args for response field! Hint: "
  91. f"check that {type_} is a valid Pydantic field type. "
  92. "If you are using a return type annotation that is not a valid Pydantic "
  93. "field (e.g. Union[Response, dict, None]) you can disable generating the "
  94. "response model from the type annotation with the path operation decorator "
  95. "parameter response_model=None. Read more: "
  96. "https://fastapi.tiangolo.com/tutorial/response-model/"
  97. ) from None
  98. def create_cloned_field(
  99. field: ModelField,
  100. *,
  101. cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
  102. ) -> ModelField:
  103. if PYDANTIC_V2:
  104. return field
  105. # cloned_types caches already cloned types to support recursive models and improve
  106. # performance by avoiding unnecessary cloning
  107. if cloned_types is None:
  108. cloned_types = _CLONED_TYPES_CACHE
  109. original_type = field.type_
  110. if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
  111. original_type = original_type.__pydantic_model__
  112. use_type = original_type
  113. if lenient_issubclass(original_type, BaseModel):
  114. original_type = cast(Type[BaseModel], original_type)
  115. use_type = cloned_types.get(original_type)
  116. if use_type is None:
  117. use_type = create_model(original_type.__name__, __base__=original_type)
  118. cloned_types[original_type] = use_type
  119. for f in original_type.__fields__.values():
  120. use_type.__fields__[f.name] = create_cloned_field(
  121. f, cloned_types=cloned_types
  122. )
  123. new_field = create_model_field(name=field.name, type_=use_type)
  124. new_field.has_alias = field.has_alias # type: ignore[attr-defined]
  125. new_field.alias = field.alias # type: ignore[misc]
  126. new_field.class_validators = field.class_validators # type: ignore[attr-defined]
  127. new_field.default = field.default # type: ignore[misc]
  128. new_field.required = field.required # type: ignore[misc]
  129. new_field.model_config = field.model_config # type: ignore[attr-defined]
  130. new_field.field_info = field.field_info
  131. new_field.allow_none = field.allow_none # type: ignore[attr-defined]
  132. new_field.validate_always = field.validate_always # type: ignore[attr-defined]
  133. if field.sub_fields: # type: ignore[attr-defined]
  134. new_field.sub_fields = [ # type: ignore[attr-defined]
  135. create_cloned_field(sub_field, cloned_types=cloned_types)
  136. for sub_field in field.sub_fields # type: ignore[attr-defined]
  137. ]
  138. if field.key_field: # type: ignore[attr-defined]
  139. new_field.key_field = create_cloned_field( # type: ignore[attr-defined]
  140. field.key_field, # type: ignore[attr-defined]
  141. cloned_types=cloned_types,
  142. )
  143. new_field.validators = field.validators # type: ignore[attr-defined]
  144. new_field.pre_validators = field.pre_validators # type: ignore[attr-defined]
  145. new_field.post_validators = field.post_validators # type: ignore[attr-defined]
  146. new_field.parse_json = field.parse_json # type: ignore[attr-defined]
  147. new_field.shape = field.shape # type: ignore[attr-defined]
  148. new_field.populate_validators() # type: ignore[attr-defined]
  149. return new_field
  150. def generate_operation_id_for_path(
  151. *, name: str, path: str, method: str
  152. ) -> str: # pragma: nocover
  153. warnings.warn(
  154. "fastapi.utils.generate_operation_id_for_path() was deprecated, "
  155. "it is not used internally, and will be removed soon",
  156. DeprecationWarning,
  157. stacklevel=2,
  158. )
  159. operation_id = f"{name}{path}"
  160. operation_id = re.sub(r"\W", "_", operation_id)
  161. operation_id = f"{operation_id}_{method.lower()}"
  162. return operation_id
  163. def generate_unique_id(route: "APIRoute") -> str:
  164. operation_id = f"{route.name}{route.path_format}"
  165. operation_id = re.sub(r"\W", "_", operation_id)
  166. assert route.methods
  167. operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
  168. return operation_id
  169. def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
  170. for key, value in update_dict.items():
  171. if (
  172. key in main_dict
  173. and isinstance(main_dict[key], dict)
  174. and isinstance(value, dict)
  175. ):
  176. deep_dict_update(main_dict[key], value)
  177. elif (
  178. key in main_dict
  179. and isinstance(main_dict[key], list)
  180. and isinstance(update_dict[key], list)
  181. ):
  182. main_dict[key] = main_dict[key] + update_dict[key]
  183. else:
  184. main_dict[key] = value
  185. def get_value_or_default(
  186. first_item: Union[DefaultPlaceholder, DefaultType],
  187. *extra_items: Union[DefaultPlaceholder, DefaultType],
  188. ) -> Union[DefaultPlaceholder, DefaultType]:
  189. """
  190. Pass items or `DefaultPlaceholder`s by descending priority.
  191. The first one to _not_ be a `DefaultPlaceholder` will be returned.
  192. Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
  193. """
  194. items = (first_item,) + extra_items
  195. for item in items:
  196. if not isinstance(item, DefaultPlaceholder):
  197. return item
  198. return first_item