You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

157 lines
5.4 KiB

  1. import functools
  2. import re
  3. from dataclasses import is_dataclass
  4. from enum import Enum
  5. from typing import Any, Dict, Optional, Set, Type, Union, cast
  6. import fastapi
  7. from fastapi.datastructures import DefaultPlaceholder, DefaultType
  8. from fastapi.openapi.constants import REF_PREFIX
  9. from pydantic import BaseConfig, BaseModel, create_model
  10. from pydantic.class_validators import Validator
  11. from pydantic.fields import FieldInfo, ModelField, UndefinedType
  12. from pydantic.schema import model_process_schema
  13. from pydantic.utils import lenient_issubclass
  14. def get_model_definitions(
  15. *,
  16. flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
  17. model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
  18. ) -> Dict[str, Any]:
  19. definitions: Dict[str, Dict[str, Any]] = {}
  20. for model in flat_models:
  21. m_schema, m_definitions, m_nested_models = model_process_schema(
  22. model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
  23. )
  24. definitions.update(m_definitions)
  25. model_name = model_name_map[model]
  26. definitions[model_name] = m_schema
  27. return definitions
  28. def get_path_param_names(path: str) -> Set[str]:
  29. return set(re.findall("{(.*?)}", path))
  30. def create_response_field(
  31. name: str,
  32. type_: Type[Any],
  33. class_validators: Optional[Dict[str, Validator]] = None,
  34. default: Optional[Any] = None,
  35. required: Union[bool, UndefinedType] = False,
  36. model_config: Type[BaseConfig] = BaseConfig,
  37. field_info: Optional[FieldInfo] = None,
  38. alias: Optional[str] = None,
  39. ) -> ModelField:
  40. """
  41. Create a new response field. Raises if type_ is invalid.
  42. """
  43. class_validators = class_validators or {}
  44. field_info = field_info or FieldInfo(None)
  45. response_field = functools.partial(
  46. ModelField,
  47. name=name,
  48. type_=type_,
  49. class_validators=class_validators,
  50. default=default,
  51. required=required,
  52. model_config=model_config,
  53. alias=alias,
  54. )
  55. try:
  56. return response_field(field_info=field_info)
  57. except RuntimeError:
  58. raise fastapi.exceptions.FastAPIError(
  59. f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
  60. )
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """
    Return a copy of ``field`` whose pydantic model type is replaced by a clone.

    If ``field.type_`` is a pydantic model class (or a dataclass that exposes a
    ``__pydantic_model__``), a new model is created with
    ``create_model(<same name>, __base__=<original model>)`` and each of the
    original model's fields is cloned recursively into it. Any other type is
    used unchanged. The remaining ``ModelField`` attributes are then copied from
    ``field`` onto the new field, and ``sub_fields`` / ``key_field`` are cloned
    recursively.

    Args:
        field: the field to clone.
        cloned_types: maps each original model to the clone already created
            for it. It is shared across the recursive calls so each model is
            cloned only once and self-referencing models terminate. A new dict
            is used when ``None``.

    Returns:
        The newly built ``ModelField``.
    """
    # _cloned_types has already cloned types, to support recursive models
    if cloned_types is None:
        cloned_types = dict()
    original_type = field.type_
    # For dataclasses carrying a ``__pydantic_model__``, clone that model
    # instead of the dataclass itself.
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        use_type = cloned_types.get(original_type)
        if use_type is None:
            use_type = create_model(original_type.__name__, __base__=original_type)
            # Register the clone BEFORE recursing into its fields: a model that
            # refers back to itself then finds this entry instead of recursing
            # forever.
            cloned_types[original_type] = use_type
            for f in original_type.__fields__.values():
                use_type.__fields__[f.name] = create_cloned_field(
                    f, cloned_types=cloned_types
                )
    new_field = create_response_field(name=field.name, type_=use_type)
    # Copy the original field's configuration onto the new field.
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    new_field.field_info = field.field_info
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    # Container element fields (and mapping key fields) are cloned too, so
    # models nested inside them also get cloned types.
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(sub_field, cloned_types=cloned_types)
            for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(
            field.key_field, cloned_types=cloned_types
        )
    new_field.validators = field.validators
    new_field.pre_validators = field.pre_validators
    new_field.post_validators = field.post_validators
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    # Called last, after every attribute above is set. Presumably this rebuilds
    # the validator chain from those attributes; confirm against the pydantic
    # version in use.
    new_field.populate_validators()
    return new_field
  109. def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
  110. operation_id = name + path
  111. operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
  112. operation_id = operation_id + "_" + method.lower()
  113. return operation_id
  114. def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
  115. for key in update_dict:
  116. if (
  117. key in main_dict
  118. and isinstance(main_dict[key], dict)
  119. and isinstance(update_dict[key], dict)
  120. ):
  121. deep_dict_update(main_dict[key], update_dict[key])
  122. else:
  123. main_dict[key] = update_dict[key]
  124. def get_value_or_default(
  125. first_item: Union[DefaultPlaceholder, DefaultType],
  126. *extra_items: Union[DefaultPlaceholder, DefaultType],
  127. ) -> Union[DefaultPlaceholder, DefaultType]:
  128. """
  129. Pass items or `DefaultPlaceholder`s by descending priority.
  130. The first one to _not_ be a `DefaultPlaceholder` will be returned.
  131. Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
  132. """
  133. items = (first_item,) + extra_items
  134. for item in items:
  135. if not isinstance(item, DefaultPlaceholder):
  136. return item
  137. return first_item