You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

360 lines
16 KiB

  1. import sys
  2. import typing
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. ClassVar,
  7. Dict,
  8. Generic,
  9. Iterator,
  10. List,
  11. Mapping,
  12. Optional,
  13. Tuple,
  14. Type,
  15. TypeVar,
  16. Union,
  17. cast,
  18. )
  19. from typing_extensions import Annotated
  20. from .class_validators import gather_all_validators
  21. from .fields import DeferredType
  22. from .main import BaseModel, create_model
  23. from .types import JsonWrapper
  24. from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
  25. from .utils import all_identical, lenient_issubclass
  26. _generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
  27. GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
  28. TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
  29. Parametrization = Mapping[TypeVarType, Type[Any]]
  30. # _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
  31. # as captured during construction of the class (not instances).
  32. # E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
  33. # `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
  34. # (This information is only otherwise available after creation from the class name string).
  35. _assigned_parameters: Dict[Type[Any], Parametrization] = {}
class GenericModel(BaseModel):
    """Base class for pydantic models that can be parametrized with `typing.Generic` type variables."""

    __slots__ = ()
    # True once a parametrization leaves no type variables behind (set in `__class_getitem__`).
    __concrete__: ClassVar[bool] = False

    if TYPE_CHECKING:
        # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
        # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
        # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
        __parameters__: ClassVar[Tuple[TypeVarType, ...]]

    # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
    def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
        """Instantiates a new class from a generic class `cls` and type variables `params`.

        :param params: Tuple of types the class is parametrized with. Given a generic class
            `Model` with 2 type variables and a concrete model `Model[str, int]`,
            the value `(str, int)` would be passed to `params`. A single type may be passed
            on its own instead of a 1-tuple.
        :return: New model class inheriting from `cls` with instantiated
            types described by `params`. A previously created class is returned from the
            cache when available. If `params` are exactly the type variables of `cls`,
            `cls` itself is returned as is.
        :raises TypeError: if `cls` is already concrete, if type variables are placed on
            `GenericModel` directly, if `cls` does not inherit from `typing.Generic`, or if
            the number of `params` does not match the number of type variables of `cls`.
        """
        # The cache is keyed on the raw `params` (single type or tuple); both forms are stored below.
        cached = _generic_types_cache.get((cls, params))
        if cached is not None:
            return cached
        if cls.__concrete__ and Generic not in cls.__bases__:
            raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
        if not isinstance(params, tuple):
            params = (params,)
        if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
            raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
        if not hasattr(cls, '__parameters__'):
            raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')

        check_parameters_count(cls, params)
        # Build map from generic typevars to passed params
        typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
        if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
            return cls  # if arguments are equal to parameters it's the same object

        # Create new model with original model as parent inserting fields with DeferredType.
        model_name = cls.__concrete_name__(params)
        validators = gather_all_validators(cls)

        type_hints = get_all_type_hints(cls).items()
        # ClassVar annotations are not model fields, so they are excluded here.
        instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}

        # Each field keeps its original FieldInfo; its type is filled in later by `_prepare_model_fields`.
        fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}

        model_module, called_globally = get_caller_frame_info()
        created_model = cast(
            Type[GenericModel],  # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
            create_model(
                model_name,
                __module__=model_module or cls.__module__,
                __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
                __config__=None,
                __validators__=validators,
                **fields,
            ),
        )

        # Remember which typevars were bound to what, for use by `__parameterized_bases__` of subclasses.
        _assigned_parameters[created_model] = typevars_map

        if called_globally:  # create global reference and therefore allow pickling
            object_by_reference = None
            reference_name = model_name
            reference_module_globals = sys.modules[created_model.__module__].__dict__
            # If the name is already taken by another object, keep appending '_' until a free
            # (or already-ours) slot in the module globals is found.
            while object_by_reference is not created_model:
                object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
                reference_name += '_'

        created_model.Config = cls.Config

        # Find any typevars that are still present in the model.
        # If none are left, the model is fully "concrete", otherwise the new
        # class is a generic class as well taking the found typevars as
        # parameters.
        new_params = tuple(
            {param: None for param in iter_contained_typevars(typevars_map.values())}
        )  # use dict as ordered set
        created_model.__concrete__ = not new_params
        if new_params:
            created_model.__parameters__ = new_params

        # Save created model in cache so we don't end up creating duplicate
        # models that should be identical.
        _generic_types_cache[(cls, params)] = created_model
        if len(params) == 1:
            # Also cache under the bare type so `Model[int]` and `Model[(int,)]` both hit the cache.
            _generic_types_cache[(cls, params[0])] = created_model

        # Recursively walk class type hints and replace generic typevars
        # with concrete types that were passed.
        _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
        return created_model

    @classmethod
    def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
        """Compute class name for child classes.

        :param params: Tuple of types the class is parametrized with. Given a generic class
            `Model` with 2 type variables and a concrete model `Model[str, int]`,
            the value `(str, int)` would be passed to `params`.
        :return: String representing the new class where `params` are
            passed to `cls` as type variables, e.g. `'Model[str, int]'`.

        This method can be overridden to achieve a custom naming scheme for GenericModels.
        """
        param_names = [display_as_type(param) for param in params]
        params_component = ', '.join(param_names)
        return f'{cls.__name__}[{params_component}]'

    @classmethod
    def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
        """
        Returns unbound bases of cls parameterised to given type variables

        :param typevars_map: Dictionary of type applications for binding subclasses.
            Given a generic class `Model` with 2 type variables [S, T]
            and a concrete model `Model[str, int]`,
            the value `{S: str, T: int}` would be passed to `typevars_map`.
        :return: an iterator of generic sub classes, parameterised by `typevars_map`
            and other assigned parameters of `cls`

        e.g.:
        ```
        class A(GenericModel, Generic[T]):
            ...

        class B(A[V], Generic[V]):
            ...

        assert A[int] in B.__parameterized_bases__({V: int})
        ```
        """

        def build_base_model(
            base_model: Type[GenericModel], mapped_types: Parametrization
        ) -> Iterator[Type[GenericModel]]:
            # Look up the concrete type for each of base_model's own typevars, then parametrize it.
            base_parameters = tuple([mapped_types[param] for param in base_model.__parameters__])
            parameterized_base = base_model.__class_getitem__(base_parameters)
            if parameterized_base is base_model or parameterized_base is cls:
                # Avoid duplication in MRO
                return
            yield parameterized_base

        for base_model in cls.__bases__:
            if not issubclass(base_model, GenericModel):
                # not a class that can be meaningfully parameterized
                continue
            elif not getattr(base_model, '__parameters__', None):
                # base_model is "GenericModel"  (and has no __parameters__)
                # or
                # base_model is already concrete, and will be included transitively via cls.
                continue
            elif cls in _assigned_parameters:
                if base_model in _assigned_parameters:
                    # cls is partially parameterised but not from base_model
                    # e.g. cls = B[S], base_model = A[S]
                    # B[S][int] should subclass A[int],  (and will be transitively via B[int])
                    # but it's not viable to consistently subclass types with arbitrary construction
                    # So don't attempt to include A[S][int]
                    continue
                else:  # base_model not in _assigned_parameters:
                    # cls is partially parameterized, base_model is original generic
                    # e.g.  cls = B[str, T], base_model = B[S, T]
                    # Need to determine the mapping for the base_model parameters
                    mapped_types: Parametrization = {
                        key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
                    }
                    yield from build_base_model(base_model, mapped_types)
            else:
                # cls is base generic, so base_class has a distinct base
                # can construct the Parameterised base model using typevars_map directly
                yield from build_base_model(base_model, typevars_map)
  186. def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
  187. """Return type with all occurrences of `type_map` keys recursively replaced with their values.
  188. :param type_: Any type, class or generic alias
  189. :param type_map: Mapping from `TypeVar` instance to concrete types.
  190. :return: New type representing the basic structure of `type_` with all
  191. `typevar_map` keys recursively replaced.
  192. >>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
  193. Tuple[int, Union[List[int], float]]
  194. """
  195. if not type_map:
  196. return type_
  197. type_args = get_args(type_)
  198. origin_type = get_origin(type_)
  199. if origin_type is Annotated:
  200. annotated_type, *annotations = type_args
  201. return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
  202. # Having type args is a good indicator that this is a typing module
  203. # class instantiation or a generic alias of some sort.
  204. if type_args:
  205. resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
  206. if all_identical(type_args, resolved_type_args):
  207. # If all arguments are the same, there is no need to modify the
  208. # type or create a new object at all
  209. return type_
  210. if (
  211. origin_type is not None
  212. and isinstance(type_, typing_base)
  213. and not isinstance(origin_type, typing_base)
  214. and getattr(type_, '_name', None) is not None
  215. ):
  216. # In python < 3.9 generic aliases don't exist so any of these like `list`,
  217. # `type` or `collections.abc.Callable` need to be translated.
  218. # See: https://www.python.org/dev/peps/pep-0585
  219. origin_type = getattr(typing, type_._name)
  220. assert origin_type is not None
  221. return origin_type[resolved_type_args]
  222. # We handle pydantic generic models separately as they don't have the same
  223. # semantics as "typing" classes or generic aliases
  224. if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
  225. type_args = type_.__parameters__
  226. resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
  227. if all_identical(type_args, resolved_type_args):
  228. return type_
  229. return type_[resolved_type_args]
  230. # Handle special case for typehints that can have lists as arguments.
  231. # `typing.Callable[[int, str], int]` is an example for this.
  232. if isinstance(type_, (List, list)):
  233. resolved_list = list(replace_types(element, type_map) for element in type_)
  234. if all_identical(type_, resolved_list):
  235. return type_
  236. return resolved_list
  237. # For JsonWrapperValue, need to handle its inner type to allow correct parsing
  238. # of generic Json arguments like Json[T]
  239. if not origin_type and lenient_issubclass(type_, JsonWrapper):
  240. type_.inner_type = replace_types(type_.inner_type, type_map)
  241. return type_
  242. # If all else fails, we try to resolve the type directly and otherwise just
  243. # return the input with no modifications.
  244. return type_map.get(type_, type_)
  245. def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
  246. actual = len(parameters)
  247. expected = len(cls.__parameters__)
  248. if actual != expected:
  249. description = 'many' if actual > expected else 'few'
  250. raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')
  251. DictValues: Type[Any] = {}.values().__class__
  252. def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
  253. """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
  254. if isinstance(v, TypeVar):
  255. yield v
  256. elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
  257. yield from v.__parameters__
  258. elif isinstance(v, (DictValues, list)):
  259. for var in v:
  260. yield from iter_contained_typevars(var)
  261. else:
  262. args = get_args(v)
  263. for arg in args:
  264. yield from iter_contained_typevars(arg)
  265. def get_caller_frame_info() -> Tuple[Optional[str], bool]:
  266. """
  267. Used inside a function to check whether it was called globally
  268. Will only work against non-compiled code, therefore used only in pydantic.generics
  269. :returns Tuple[module_name, called_globally]
  270. """
  271. try:
  272. previous_caller_frame = sys._getframe(2)
  273. except ValueError as e:
  274. raise RuntimeError('This function must be used inside another function') from e
  275. except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
  276. return None, False
  277. frame_globals = previous_caller_frame.f_globals
  278. return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
  279. def _prepare_model_fields(
  280. created_model: Type[GenericModel],
  281. fields: Mapping[str, Any],
  282. instance_type_hints: Mapping[str, type],
  283. typevars_map: Mapping[Any, type],
  284. ) -> None:
  285. """
  286. Replace DeferredType fields with concrete type hints and prepare them.
  287. """
  288. for key, field in created_model.__fields__.items():
  289. if key not in fields:
  290. assert field.type_.__class__ is not DeferredType
  291. # https://github.com/nedbat/coveragepy/issues/198
  292. continue # pragma: no cover
  293. assert field.type_.__class__ is DeferredType, field.type_.__class__
  294. field_type_hint = instance_type_hints[key]
  295. concrete_type = replace_types(field_type_hint, typevars_map)
  296. field.type_ = concrete_type
  297. field.outer_type_ = concrete_type
  298. field.prepare()
  299. created_model.__annotations__[key] = concrete_type