|
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
-
- from .class_validators import gather_all_validators
- from .error_wrappers import ValidationError
- from .errors import DataclassTypeError
- from .fields import Field, FieldInfo, Required, Undefined
- from .main import create_model, validate_model
- from .typing import resolve_annotations
- from .utils import ClassAttribute
-
if TYPE_CHECKING:
    # Names below are only needed for static type checking.
    from .config import BaseConfig
    from .main import BaseModel
    from .typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    class Dataclass:
        """
        Static description of a class produced by `pydantic.dataclasses.dataclass`.

        It is only defined while type checking and lists the attributes and hooks that
        `_process_class` attaches to a decorated class.
        """

        # pydantic model built from the dataclass fields; used to validate instances
        __pydantic_model__: Type[BaseModel]
        # flipped to True on an instance once `__post_init__` validation has finished
        __initialised__: bool
        # the user's own `__post_init__`, chained to by the generated hook
        __post_init_original__: Optional[Callable[..., None]]
        # marker telling pydantic this class has already been processed
        __processed__: Optional[ClassAttribute]
        # whether or not a `pydantic.Field` is used as default value
        __has_field_info_default__: bool

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            ...

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            ...

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            ...

        def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT':
            ...
-
-
- def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
- if isinstance(v, cls):
- return v
- elif isinstance(v, (list, tuple)):
- return cls(*v)
- elif isinstance(v, dict):
- return cls(**v)
- # In nested dataclasses, v can be of type `dataclasses.dataclass`.
- # But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`,
- # which inherits directly from the class of `v`.
- elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
- import dataclasses
-
- return cls(**dataclasses.asdict(v))
- else:
- raise DataclassTypeError(class_name=cls.__name__)
-
-
- def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator':
- yield cls.__validate__
-
-
def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None:
    """
    `__setattr__` replacement that validates assignments to known model fields.

    Before the instance is initialised, and for names that are not model fields, the value
    is stored as-is. Otherwise the field validator runs first; it receives the instance's
    other attribute values, and its (possibly converted) result is what gets stored.

    :raises ValidationError: if the field validator reports an error
    """
    if self.__initialised__:
        model_field = self.__pydantic_model__.__fields__.get(name, None)
        if model_field:
            other_values = {key: val for key, val in self.__dict__.items() if key != name}
            value, error = model_field.validate(value, other_values, loc=name, cls=self.__class__)
            if error:
                raise ValidationError([error], self.__class__)

    object.__setattr__(self, name, value)
-
-
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Return True only for "direct" stdlib dataclasses.

    `dataclasses.is_dataclass` also returns True when any parent is a dataclass, so classes
    carrying the `__processed__` attribute (which pydantic sets on the classes it processes,
    and which subclasses inherit) are excluded explicitly.
    """
    import dataclasses

    if hasattr(_cls, '__processed__'):
        return False
    return dataclasses.is_dataclass(_cls)
-
-
- def _generate_pydantic_post_init(
- post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
- ) -> Callable[..., None]:
- def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
- if post_init_original is not None:
- post_init_original(self, *initvars)
-
- if getattr(self, '__has_field_info_default__', False):
- # We need to remove `FieldInfo` values since they are not valid as input
- # It's ok to do that because they are obviously the default values!
- input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
- else:
- input_data = self.__dict__
- d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
- if validation_error:
- raise validation_error
- object.__setattr__(self, '__dict__', {**getattr(self, '__dict__', {}), **d})
- object.__setattr__(self, '__initialised__', True)
- if post_init_post_parse is not None:
- post_init_post_parse(self, *initvars)
-
- return _pydantic_post_init
-
-
def _process_class(
    _cls: Type[Any],
    init: bool,
    repr: bool,
    eq: bool,
    order: bool,
    unsafe_hash: bool,
    frozen: bool,
    config: Optional[Type[Any]],
) -> Type['Dataclass']:
    """
    Turn `_cls` into a pydantic dataclass.

    Installs a validating `__post_init__`, applies `dataclasses.dataclass` with the given stdlib
    options, builds `__pydantic_model__` from the dataclass fields and attaches the hooks used
    when the class is used as a field type (`__validate__`, `__get_validators__`).

    :param _cls: class to process; either a plain class or an existing stdlib dataclass
    :param init: forwarded to `dataclasses.dataclass`
    :param repr: forwarded to `dataclasses.dataclass`
    :param eq: forwarded to `dataclasses.dataclass`
    :param order: forwarded to `dataclasses.dataclass`
    :param unsafe_hash: forwarded to `dataclasses.dataclass`
    :param frozen: forwarded to `dataclasses.dataclass`; also disables assignment validation
    :param config: pydantic config, passed to `create_model` as `__config__`
    :return: the processed class; a newly created subclass when `_cls` was a stdlib dataclass
    """
    import dataclasses

    # Find the user-defined `__post_init__` to chain to. If `_cls` only inherits the generated
    # `_pydantic_post_init` from a pydantic parent, fall back to the original hook that the
    # parent stored in `__post_init_original__`.
    post_init_original = getattr(_cls, '__post_init__', None)
    if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
        post_init_original = None
    if not post_init_original:
        post_init_original = getattr(_cls, '__post_init_original__', None)

    # Optional user hook that runs after validation, with the same init-only vars.
    post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

    _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

    # If the class is already a dataclass, __post_init__ will not be called automatically
    # so no validation will be added.
    # We hence create dynamically a new dataclass:
    # ```
    # @dataclasses.dataclass
    # class NewClass(_cls):
    #   __post_init__ = _pydantic_post_init
    # ```
    # with the exact same fields as the base dataclass
    # and register it on module level to address pickle problem:
    # https://github.com/samuelcolvin/pydantic/issues/2111
    if is_builtin_dataclass(_cls):
        # `id(_cls)` makes the registered name unique per original class object
        uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}'
        _cls = type(
            # for pretty output new class will have the name as original
            _cls.__name__,
            (_cls,),
            {
                '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
                '__post_init__': _pydantic_post_init,
                # attrs for pickle to find this class
                '__module__': __name__,
                '__qualname__': uniq_class_name,
            },
        )
        globals()[uniq_class_name] = _cls
    else:
        _cls.__post_init__ = _pydantic_post_init
    cls: Type['Dataclass'] = dataclasses.dataclass(  # type: ignore
        _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    # Marker checked by `is_builtin_dataclass`, so this class (and its subclasses) are
    # no longer treated as plain stdlib dataclasses.
    cls.__processed__ = ClassAttribute('__processed__', True)

    # Translate every dataclass field into a `(type, FieldInfo)` pair for `create_model`.
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default_factory = field.default_factory
        else:
            default = Required

        if isinstance(default, FieldInfo):
            # `x: T = Field(...)`: use it directly, and flag the class so `_pydantic_post_init`
            # strips the `FieldInfo` placeholder values out of the validation input.
            field_info = default
            cls.__has_field_info_default__ = True
        else:
            # entries of the stdlib `field(metadata=...)` mapping become `Field` keyword arguments
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(cls)
    cls.__pydantic_model__ = create_model(
        cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
    )

    cls.__initialised__ = False
    cls.__validate__ = classmethod(_validate_dataclass)  # type: ignore[assignment]
    cls.__get_validators__ = classmethod(_get_validators)  # type: ignore[assignment]
    if post_init_original:
        cls.__post_init_original__ = post_init_original

    # Assignment validation is opt-in via the config and never installed on frozen classes.
    if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
        cls.__setattr__ = setattr_validate_assignment  # type: ignore[assignment]

    return cls
-
-
@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Callable[[Type[Any]], Type['Dataclass']]:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Type['Dataclass']:
    ...


def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
    """
    Like the python standard lib dataclasses but with type validation.

    Usable bare (``@dataclass``) or with arguments (``@dataclass(frozen=True)``).
    ``init``, ``repr``, ``eq``, ``order``, ``unsafe_hash`` and ``frozen`` have the same meaning
    as for standard dataclasses. The extra argument is:

    :param config: pydantic config class used for the generated model (the same role as
        ``Config`` on a ``BaseModel``); e.g. ``validate_assignment = True`` on it enables
        validation on attribute assignment for non-frozen dataclasses.
    :return: the processed class when ``_cls`` is given, otherwise a class decorator
    """

    def wrap(cls: Type[Any]) -> Type['Dataclass']:
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

    # called as `@dataclass(...)`: return the decorator itself
    if _cls is None:
        return wrap

    return wrap(_cls)
-
-
def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.

    It retrieves the options the stdlib dataclass was created with and forwards the ones
    `dataclass` accepts (init, repr, eq, order, unsafe_hash, frozen) to the new dataclass.
    """
    dataclass_params = _cls.__dataclass_params__
    # Forward only the options `dataclass` accepts. Newer Python versions also record options
    # such as `match_args`, `kw_only` or `slots` on `__dataclass_params__`; forwarding every
    # slot would make `dataclass(...)` raise TypeError. Fallbacks mirror the stdlib defaults.
    forwarded_defaults = (
        ('init', True),
        ('repr', True),
        ('eq', True),
        ('order', False),
        ('unsafe_hash', False),
        ('frozen', False),
    )
    stdlib_dataclass_parameters = {
        param: getattr(dataclass_params, param, default) for param, default in forwarded_defaults
    }
    cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters)
    yield from _get_validators(cls)
|