Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.
 
 
 
 

501 rinda
18 KiB

  1. """
  2. The main purpose is to enhance stdlib dataclasses by adding validation
  3. A pydantic dataclass can be generated from scratch or from a stdlib one.
  4. Behind the scene, a pydantic dataclass is just like a regular one on which we attach
  5. a `BaseModel` and magic methods to trigger the validation of the data.
  6. `__init__` and `__post_init__` are hence overridden and have extra logic to be
  7. able to validate input data.
  8. When a pydantic dataclass is generated from scratch, it's just a plain dataclass
  9. with validation triggered at initialization
  10. The tricky part is for stdlib dataclasses that are converted afterwards into pydantic ones e.g.
  11. ```py
  12. @dataclasses.dataclass
  13. class M:
  14. x: int
  15. ValidatedM = pydantic.dataclasses.dataclass(M)
  16. ```
  17. We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
  18. ```py
  19. assert isinstance(ValidatedM(x=1), M)
  20. assert ValidatedM(x=1) == M(x=1)
  21. ```
  22. This means we **don't want to create a new dataclass that inherits from it**
  23. The trick is to create a wrapper around `M` that will act as a proxy to trigger
  24. validation without altering default `M` behaviour.
  25. """
  26. import copy
  27. import dataclasses
  28. import sys
  29. from contextlib import contextmanager
  30. from functools import wraps
  31. try:
  32. from functools import cached_property
  33. except ImportError:
  34. # cached_property available only for python3.8+
  35. pass
  36. from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
  37. from typing_extensions import dataclass_transform
  38. from pydantic.v1.class_validators import gather_all_validators
  39. from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
  40. from pydantic.v1.error_wrappers import ValidationError
  41. from pydantic.v1.errors import DataclassTypeError
  42. from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
  43. from pydantic.v1.main import create_model, validate_model
  44. from pydantic.v1.utils import ClassAttribute
  45. if TYPE_CHECKING:
  46. from pydantic.v1.main import BaseModel
  47. from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable
# Typing helpers used throughout this module:
# - `DataclassT` is any class conforming to the `Dataclass` stub declared below.
# - a pydantic dataclass may be exposed either as the class itself or wrapped
#   in a `DataclassProxy`, hence the union alias.
DataclassT = TypeVar('DataclassT', bound='Dataclass')

DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
class Dataclass:
    """Static-typing stub describing the attributes pydantic attaches to a dataclass.

    The method bodies are empty (`pass`); this class exists so that
    `DataclassT` / `DataclassClassOrWrapper` have something to bind to —
    it is not meant to be instantiated directly.
    """

    # stdlib attributes
    __dataclass_fields__: ClassVar[Dict[str, Any]]
    __dataclass_params__: ClassVar[Any]  # in reality `dataclasses._DataclassParams`
    __post_init__: ClassVar[Callable[..., None]]

    # Added by pydantic
    __pydantic_run_validation__: ClassVar[bool]
    __post_init_post_parse__: ClassVar[Callable[..., None]]
    __pydantic_initialised__: ClassVar[bool]
    __pydantic_model__: ClassVar[Type[BaseModel]]
    __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
    __pydantic_has_field_info_default__: ClassVar[bool]  # whether a `pydantic.Field` is used as default value

    def __init__(self, *args: object, **kwargs: object) -> None:
        pass

    @classmethod
    def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
        pass

    @classmethod
    def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
        pass
# Public API of this module.
__all__ = [
    'dataclass',
    'set_validation',
    'create_pydantic_model_from_dataclass',
    'is_builtin_dataclass',
    'make_dataclass_validator',
]

# Generic placeholder for the class being decorated by `dataclass`.
_T = TypeVar('_T')
if sys.version_info >= (3, 10):
    # Python 3.10+ stdlib dataclasses accept `kw_only`, so the public overloads
    # expose it on those versions only.
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        # `@dataclass(...)` used with options only; the class is applied later
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> 'DataclassClassOrWrapper':
        # `@dataclass` applied directly to a class
        ...

else:
    # Same two overloads, without `kw_only` (not supported by stdlib
    # dataclasses before python 3.10).
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> 'DataclassClassOrWrapper':
        ...
  144. @dataclass_transform(field_specifiers=(dataclasses.field, Field))
  145. def dataclass(
  146. _cls: Optional[Type[_T]] = None,
  147. *,
  148. init: bool = True,
  149. repr: bool = True,
  150. eq: bool = True,
  151. order: bool = False,
  152. unsafe_hash: bool = False,
  153. frozen: bool = False,
  154. config: Union[ConfigDict, Type[object], None] = None,
  155. validate_on_init: Optional[bool] = None,
  156. use_proxy: Optional[bool] = None,
  157. kw_only: bool = False,
  158. ) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
  159. """
  160. Like the python standard lib dataclasses but with type validation.
  161. The result is either a pydantic dataclass that will validate input data
  162. or a wrapper that will trigger validation around a stdlib dataclass
  163. to avoid modifying it directly
  164. """
  165. the_config = get_config(config)
  166. def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
  167. should_use_proxy = (
  168. use_proxy
  169. if use_proxy is not None
  170. else (
  171. is_builtin_dataclass(cls)
  172. and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
  173. )
  174. )
  175. if should_use_proxy:
  176. dc_cls_doc = ''
  177. dc_cls = DataclassProxy(cls)
  178. default_validate_on_init = False
  179. else:
  180. dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
  181. if sys.version_info >= (3, 10):
  182. dc_cls = dataclasses.dataclass(
  183. cls,
  184. init=init,
  185. repr=repr,
  186. eq=eq,
  187. order=order,
  188. unsafe_hash=unsafe_hash,
  189. frozen=frozen,
  190. kw_only=kw_only,
  191. )
  192. else:
  193. dc_cls = dataclasses.dataclass( # type: ignore
  194. cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
  195. )
  196. default_validate_on_init = True
  197. should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
  198. _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
  199. dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
  200. return dc_cls
  201. if _cls is None:
  202. return wrap
  203. return wrap(_cls)
  204. @contextmanager
  205. def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
  206. original_run_validation = cls.__pydantic_run_validation__
  207. try:
  208. cls.__pydantic_run_validation__ = value
  209. yield cls
  210. finally:
  211. cls.__pydantic_run_validation__ = original_run_validation
class DataclassProxy:
    """Transparent wrapper around a stdlib dataclass.

    Attribute reads/writes and `isinstance` checks are forwarded to the wrapped
    class, so the original dataclass keeps its default behaviour; calling the
    proxy, however, runs the constructor with pydantic validation enabled.
    """

    # single slot so that `__getattr__` only triggers for everything else
    __slots__ = '__dataclass__'

    def __init__(self, dc_cls: Type['Dataclass']) -> None:
        # bypass our own `__setattr__`, which forwards to the wrapped class
        object.__setattr__(self, '__dataclass__', dc_cls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # construct through the wrapped class with validation switched on
        with set_validation(self.__dataclass__, True):
            return self.__dataclass__(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # anything not found on the proxy is looked up on the wrapped class
        return getattr(self.__dataclass__, name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return setattr(self.__dataclass__, __name, __value)

    def __instancecheck__(self, instance: Any) -> bool:
        # `isinstance(x, proxy)` behaves like `isinstance(x, wrapped_class)`
        return isinstance(instance, self.__dataclass__)

    def __copy__(self) -> 'DataclassProxy':
        return DataclassProxy(copy.copy(self.__dataclass__))

    def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
        return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))
def _add_pydantic_validation_attributes(  # noqa: C901 (ignore complexity)
    dc_cls: Type['Dataclass'],
    config: Type[BaseConfig],
    validate_on_init: bool,
    dc_cls_doc: str,
) -> None:
    """
    We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
    it won't even exist (code is generated on the fly by `dataclasses`)
    By default, we run validation after `__init__` or `__post_init__` if defined
    """
    init = dc_cls.__init__

    # Wrapper around the generated `__init__` that applies the `Config.extra`
    # policy to unexpected keyword arguments before delegating.
    @wraps(init)
    def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
        if config.extra == Extra.ignore:
            # drop unknown kwargs silently
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        elif config.extra == Extra.allow:
            # stash unknown kwargs on the instance, then init with known ones only
            for k, v in kwargs.items():
                self.__dict__.setdefault(k, v)
            init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})

        else:
            # Extra.forbid: pass everything through and let validation complain
            init(self, *args, **kwargs)

    if hasattr(dc_cls, '__post_init__'):
        # unwrap a previously-wrapped `__post_init__` so we don't stack wrappers
        try:
            post_init = dc_cls.__post_init__.__wrapped__  # type: ignore[attr-defined]
        except AttributeError:
            post_init = dc_cls.__post_init__

        @wraps(post_init)
        def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            # `Config.post_init_call` decides whether the user's `__post_init__`
            # runs before or after validation.
            if config.post_init_call == 'before_validation':
                post_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()
                if hasattr(self, '__post_init_post_parse__'):
                    self.__post_init_post_parse__(*args, **kwargs)

            if config.post_init_call == 'after_validation':
                post_init(self, *args, **kwargs)

        setattr(dc_cls, '__init__', handle_extra_init)
        setattr(dc_cls, '__post_init__', new_post_init)

    else:
        # No `__post_init__`: validation has to be triggered from `__init__` itself.
        @wraps(init)
        def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
            handle_extra_init(self, *args, **kwargs)

            if self.__class__.__pydantic_run_validation__:
                self.__pydantic_validate_values__()
                if hasattr(self, '__post_init_post_parse__'):
                    # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
                    # public method `dataclasses.fields`

                    # get all initvars and their default values
                    initvars_and_values: Dict[str, Any] = {}
                    for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
                        if f._field_type is dataclasses._FIELD_INITVAR:  # type: ignore[attr-defined]
                            try:
                                # set arg value by default
                                initvars_and_values[f.name] = args[i]
                            except IndexError:
                                initvars_and_values[f.name] = kwargs.get(f.name, f.default)

                    self.__post_init_post_parse__(**initvars_and_values)

        setattr(dc_cls, '__init__', new_init)

    # Attach the pydantic machinery to the (possibly proxied) dataclass.
    setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
    setattr(dc_cls, '__pydantic_initialised__', False)
    setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
    setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
    setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
    setattr(dc_cls, '__get_validators__', classmethod(_get_validators))

    # assignment validation only makes sense on a mutable dataclass
    if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
        setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
  296. def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
  297. yield cls.__validate__
  298. def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
  299. with set_validation(cls, True):
  300. if isinstance(v, cls):
  301. v.__pydantic_validate_values__()
  302. return v
  303. elif isinstance(v, (list, tuple)):
  304. return cls(*v)
  305. elif isinstance(v, dict):
  306. return cls(**v)
  307. else:
  308. raise DataclassTypeError(class_name=cls.__name__)
  309. def create_pydantic_model_from_dataclass(
  310. dc_cls: Type['Dataclass'],
  311. config: Type[Any] = BaseConfig,
  312. dc_cls_doc: Optional[str] = None,
  313. ) -> Type['BaseModel']:
  314. field_definitions: Dict[str, Any] = {}
  315. for field in dataclasses.fields(dc_cls):
  316. default: Any = Undefined
  317. default_factory: Optional['NoArgAnyCallable'] = None
  318. field_info: FieldInfo
  319. if field.default is not dataclasses.MISSING:
  320. default = field.default
  321. elif field.default_factory is not dataclasses.MISSING:
  322. default_factory = field.default_factory
  323. else:
  324. default = Required
  325. if isinstance(default, FieldInfo):
  326. field_info = default
  327. dc_cls.__pydantic_has_field_info_default__ = True
  328. else:
  329. field_info = Field(default=default, default_factory=default_factory, **field.metadata)
  330. field_definitions[field.name] = (field.type, field_info)
  331. validators = gather_all_validators(dc_cls)
  332. model: Type['BaseModel'] = create_model(
  333. dc_cls.__name__,
  334. __config__=config,
  335. __module__=dc_cls.__module__,
  336. __validators__=validators,
  337. __cls_kwargs__={'__resolve_forward_refs__': False},
  338. **field_definitions,
  339. )
  340. model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
  341. return model
  342. if sys.version_info >= (3, 8):
  343. def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
  344. return isinstance(getattr(type(obj), k, None), cached_property)
  345. else:
  346. def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
  347. return False
  348. def _dataclass_validate_values(self: 'Dataclass') -> None:
  349. # validation errors can occur if this function is called twice on an already initialised dataclass.
  350. # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
  351. if getattr(self, '__pydantic_initialised__'):
  352. return
  353. if getattr(self, '__pydantic_has_field_info_default__', False):
  354. # We need to remove `FieldInfo` values since they are not valid as input
  355. # It's ok to do that because they are obviously the default values!
  356. input_data = {
  357. k: v
  358. for k, v in self.__dict__.items()
  359. if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
  360. }
  361. else:
  362. input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
  363. d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
  364. if validation_error:
  365. raise validation_error
  366. self.__dict__.update(d)
  367. object.__setattr__(self, '__pydantic_initialised__', True)
  368. def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
  369. if self.__pydantic_initialised__:
  370. d = dict(self.__dict__)
  371. d.pop(name, None)
  372. known_field = self.__pydantic_model__.__fields__.get(name, None)
  373. if known_field:
  374. value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
  375. if error_:
  376. raise ValidationError([error_], self.__class__)
  377. object.__setattr__(self, name, value)
  378. def is_builtin_dataclass(_cls: Type[Any]) -> bool:
  379. """
  380. Whether a class is a stdlib dataclass
  381. (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
  382. we check that
  383. - `_cls` is a dataclass
  384. - `_cls` is not a processed pydantic dataclass (with a basemodel attached)
  385. - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
  386. e.g.
  387. ```
  388. @dataclasses.dataclass
  389. class A:
  390. x: int
  391. @pydantic.dataclasses.dataclass
  392. class B(A):
  393. y: int
  394. ```
  395. In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
  396. which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
  397. """
  398. return (
  399. dataclasses.is_dataclass(_cls)
  400. and not hasattr(_cls, '__pydantic_model__')
  401. and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
  402. )
  403. def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
  404. """
  405. Create a pydantic.dataclass from a builtin dataclass to add type validation
  406. and yield the validators
  407. It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
  408. """
  409. yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))