from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union, overload

from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute

if TYPE_CHECKING:
    from .config import BaseConfig
    from .main import BaseModel
    from .typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    class Dataclass:
        __pydantic_model__: Type[BaseModel]
        __initialised__: bool
        __post_init_original__: Optional[Callable[..., None]]
        __processed__: Optional[ClassAttribute]
        __has_field_info_default__: bool  # whether or not a `pydantic.Field` is used as default value

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass

        def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT':
            pass


def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
    if isinstance(v, cls):
        return v
    elif isinstance(v, (list, tuple)):
        return cls(*v)
    elif isinstance(v, dict):
        return cls(**v)
    # In nested dataclasses, v can be of type `dataclasses.dataclass`.
    # But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`,
    # which inherits directly from the class of `v`.
    elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
        import dataclasses

        return cls(**dataclasses.asdict(v))
    else:
        raise DataclassTypeError(class_name=cls.__name__)
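

# A minimal, hypothetical sketch of the inputs `_validate_dataclass` accepts
# (`Point` is an illustrative class, not part of this module):
# ```
# from pydantic.dataclasses import dataclass
#
# @dataclass
# class Point:
#     x: int
#     y: int
#
# Point.__validate__(Point(x=1, y=2))   # already an instance -> returned as is
# Point.__validate__((1, 2))            # list/tuple -> cls(*v)
# Point.__validate__({'x': 1, 'y': 2})  # dict -> cls(**v)
# Point.__validate__('1,2')             # anything else -> DataclassTypeError
# ```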


def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator':
    yield cls.__validate__


def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None:
    if self.__initialised__:
        d = dict(self.__dict__)
        d.pop(name, None)
        known_field = self.__pydantic_model__.__fields__.get(name, None)
        if known_field:
            value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
            if error_:
                raise ValidationError([error_], self.__class__)

    object.__setattr__(self, name, value)
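

# Sketch of the effect when `Config.validate_assignment` is enabled, which makes
# `_process_class` install this function as `__setattr__` (the `User` class and
# `Config` below are illustrative, not part of this module):
# ```
# from pydantic.dataclasses import dataclass
#
# class Config:
#     validate_assignment = True
#
# @dataclass(config=Config)
# class User:
#     age: int
#
# u = User(age=3)
# u.age = '7'     # coerced to int 7 by the field validator
# u.age = 'boom'  # raises ValidationError
# ```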


def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    `dataclasses.is_dataclass` is True if one of the class parents is a `dataclass`.
    This is why we also add a class attribute `__processed__` to only consider 'direct' built-in dataclasses
    """
    import dataclasses

    return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls)
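

# `__processed__` is set by `_process_class` below, so the check distinguishes a
# plain stdlib dataclass from one already wrapped by pydantic, e.g. (class names
# are illustrative):
# ```
# import dataclasses
# import pydantic.dataclasses
#
# @dataclasses.dataclass
# class A:
#     x: int
#
# @pydantic.dataclasses.dataclass
# class B:
#     x: int
#
# is_builtin_dataclass(A)  # True
# is_builtin_dataclass(B)  # False, `B` carries `__processed__`
# ```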


def _generate_pydantic_post_init(
    post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
) -> Callable[..., None]:
    def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
        if post_init_original is not None:
            post_init_original(self, *initvars)

        if getattr(self, '__has_field_info_default__', False):
            # We need to remove `FieldInfo` values since they are not valid as input
            # It's ok to do that because they are obviously the default values!
            input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
        else:
            input_data = self.__dict__
        d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
        if validation_error:
            raise validation_error
        object.__setattr__(self, '__dict__', {**getattr(self, '__dict__', {}), **d})
        object.__setattr__(self, '__initialised__', True)
        if post_init_post_parse is not None:
            post_init_post_parse(self, *initvars)

    return _pydantic_post_init
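

# The generated hook runs in three steps: the original `__post_init__` (on raw,
# unvalidated input), then `validate_model`, then `__post_init_post_parse__`
# (on validated data). A hedged sketch (`Birth` is a hypothetical user class):
# ```
# from datetime import datetime
# from pydantic.dataclasses import dataclass
#
# @dataclass
# class Birth:
#     year: int
#     month: int
#     day: int
#
#     def __post_init__(self):
#         print(type(self.year))  # may still be the raw input type, e.g. str
#
#     def __post_init_post_parse__(self):
#         self.dt = datetime(self.year, self.month, self.day)  # fields coerced to int here
#
# Birth(year='1995', month=3, day=2)  # prints "<class 'str'>"
# ```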


def _process_class(
    _cls: Type[Any],
    init: bool,
    repr: bool,
    eq: bool,
    order: bool,
    unsafe_hash: bool,
    frozen: bool,
    config: Optional[Type[Any]],
) -> Type['Dataclass']:
    import dataclasses

    post_init_original = getattr(_cls, '__post_init__', None)
    if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
        post_init_original = None

    if not post_init_original:
        post_init_original = getattr(_cls, '__post_init_original__', None)

    post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

    _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

    # If the class is already a dataclass, __post_init__ will not be called automatically
    # so no validation will be added.
    # We hence create dynamically a new dataclass:
    # ```
    # @dataclasses.dataclass
    # class NewClass(_cls):
    #   __post_init__ = _pydantic_post_init
    # ```
    # with the exact same fields as the base dataclass
    # and register it on module level to address pickle problem:
    # https://github.com/samuelcolvin/pydantic/issues/2111
    if is_builtin_dataclass(_cls):
        uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}'
        _cls = type(
            # for pretty output new class will have the name as original
            _cls.__name__,
            (_cls,),
            {
                '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
                '__post_init__': _pydantic_post_init,
                # attrs for pickle to find this class
                '__module__': __name__,
                '__qualname__': uniq_class_name,
            },
        )
        globals()[uniq_class_name] = _cls
    else:
        _cls.__post_init__ = _pydantic_post_init
    cls: Type['Dataclass'] = dataclasses.dataclass(  # type: ignore
        _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    cls.__processed__ = ClassAttribute('__processed__', True)

    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default_factory = field.default_factory
        else:
            default = Required

        if isinstance(default, FieldInfo):
            field_info = default
            cls.__has_field_info_default__ = True
        else:
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(cls)
    cls.__pydantic_model__ = create_model(
        cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
    )

    cls.__initialised__ = False
    cls.__validate__ = classmethod(_validate_dataclass)  # type: ignore[assignment]
    cls.__get_validators__ = classmethod(_get_validators)  # type: ignore[assignment]
    if post_init_original:
        cls.__post_init_original__ = post_init_original

    if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
        cls.__setattr__ = setattr_validate_assignment  # type: ignore[assignment]

    return cls
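

# Net effect of `_process_class`: the decorated class keeps stdlib dataclass
# behaviour but gains a shadow model in `__pydantic_model__` that performs the
# validation. Hypothetical example (`Item` is not part of this module):
# ```
# from pydantic.dataclasses import dataclass
#
# @dataclass
# class Item:
#     name: str
#     price: float = 0.0
#
# Item(name='pen', price='1.5').price        # -> 1.5, coerced to float
# Item.__pydantic_model__.schema()['title']  # -> 'Item'
# ```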


@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Callable[[Type[Any]], Type['Dataclass']]:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Type['Dataclass']:
    ...


def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
    """
    Like the python standard lib dataclasses but with type validation.

    Arguments are the same as for standard dataclasses, except for validate_assignment which has the same meaning
    as Config.validate_assignment.
    """

    def wrap(cls: Type[Any]) -> Type['Dataclass']:
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

    if _cls is None:
        return wrap

    return wrap(_cls)
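

# Hedged usage sketch of the public decorator, covering both overloads above
# (`User` and `Pair` are illustrative, not part of this module):
# ```
# from pydantic import ValidationError
# from pydantic.dataclasses import dataclass
#
# @dataclass                 # bare form: `_cls` is the class, `wrap(_cls)` is returned
# class User:
#     id: int
#     name: str = 'John Doe'
#
# @dataclass(frozen=True)    # parameterised form: `wrap` is returned and applied later
# class Pair:
#     a: int
#     b: int
#
# User(id='42').id           # -> 42, coerced to int
# try:
#     User(id='not an int')
# except ValidationError:
#     pass                   # invalid input raises a pydantic ValidationError
# ```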


def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.
    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass.
    """
    dataclass_params = _cls.__dataclass_params__
    stdlib_dataclass_parameters = {param: getattr(dataclass_params, param) for param in dataclass_params.__slots__}
    cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters)
    yield from _get_validators(cls)
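

# This path is what lets a plain stdlib dataclass be validated when it appears as
# a field type on a `BaseModel`; a hedged, illustrative example (`File` and
# `Folder` are not part of this module):
# ```
# import dataclasses
# from typing import List
# from pydantic import BaseModel
#
# @dataclasses.dataclass
# class File:
#     name: str
#     size: int
#
# class Folder(BaseModel):
#     files: List[File] = []
#
# Folder(files=[{'name': 'a.txt', 'size': '10'}]).files[0].size  # -> 10, coerced to int
# ```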