You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

158 lines
6.1 KiB

  1. """RootModel class and type definitions."""
  2. from __future__ import annotations as _annotations
  3. import typing
  4. from copy import copy, deepcopy
  5. from pydantic_core import PydanticUndefined
  6. from . import PydanticUserError
  7. from ._internal import _model_construction, _repr
  8. from .main import BaseModel, _object_setattr
# Metaclass for `RootModel`. At runtime it is simply pydantic's `ModelMetaclass` (the `else` branch);
# only static type checkers see the `dataclass_transform`-decorated subclass defined below.
if typing.TYPE_CHECKING:
    # Names used only inside annotations in this module; annotations are lazy strings at runtime
    # (`from __future__ import annotations`), so these imports are never needed when running.
    from typing import Any, Literal

    from typing_extensions import Self, dataclass_transform

    from .fields import Field as PydanticModelField
    from .fields import PrivateAttr as PydanticModelPrivateAttr

    # dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
    # takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
    # on a new metaclass.
    # `kw_only_default=False` tells type checkers that fields are not keyword-only by default, which lines up
    # with `RootModel.__init__` accepting `root` positionally.
    @dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
    class _RootModelMetaclass(_model_construction.ModelMetaclass): ...
else:
    # Runtime: no extra metaclass behaviour, just an alias.
    _RootModelMetaclass = _model_construction.ModelMetaclass
# Public API of this module.
__all__ = ('RootModel',)

# Type variable for the type of `RootModel.root`; a concrete root type is supplied by
# parametrizing the class, e.g. `RootModel[list[int]]`.
RootModelRootType = typing.TypeVar('RootModelRootType')
  23. class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
  24. """!!! abstract "Usage Documentation"
  25. [`RootModel` and Custom Root Types](../concepts/models.md#rootmodel-and-custom-root-types)
  26. A Pydantic `BaseModel` for the root object of the model.
  27. Attributes:
  28. root: The root object of the model.
  29. __pydantic_root_model__: Whether the model is a RootModel.
  30. __pydantic_private__: Private fields in the model.
  31. __pydantic_extra__: Extra fields in the model.
  32. """
  33. __pydantic_root_model__ = True
  34. __pydantic_private__ = None
  35. __pydantic_extra__ = None
  36. root: RootModelRootType
  37. def __init_subclass__(cls, **kwargs):
  38. extra = cls.model_config.get('extra')
  39. if extra is not None:
  40. raise PydanticUserError(
  41. "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
  42. )
  43. super().__init_subclass__(**kwargs)
  44. def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
  45. __tracebackhide__ = True
  46. if data:
  47. if root is not PydanticUndefined:
  48. raise ValueError(
  49. '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
  50. )
  51. root = data # type: ignore
  52. self.__pydantic_validator__.validate_python(root, self_instance=self)
  53. __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
  54. @classmethod
  55. def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self: # type: ignore
  56. """Create a new model using the provided root object and update fields set.
  57. Args:
  58. root: The root object of the model.
  59. _fields_set: The set of fields to be updated.
  60. Returns:
  61. The new model.
  62. Raises:
  63. NotImplemented: If the model is not a subclass of `RootModel`.
  64. """
  65. return super().model_construct(root=root, _fields_set=_fields_set)
  66. def __getstate__(self) -> dict[Any, Any]:
  67. return {
  68. '__dict__': self.__dict__,
  69. '__pydantic_fields_set__': self.__pydantic_fields_set__,
  70. }
  71. def __setstate__(self, state: dict[Any, Any]) -> None:
  72. _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
  73. _object_setattr(self, '__dict__', state['__dict__'])
  74. def __copy__(self) -> Self:
  75. """Returns a shallow copy of the model."""
  76. cls = type(self)
  77. m = cls.__new__(cls)
  78. _object_setattr(m, '__dict__', copy(self.__dict__))
  79. _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
  80. return m
  81. def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
  82. """Returns a deep copy of the model."""
  83. cls = type(self)
  84. m = cls.__new__(cls)
  85. _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
  86. # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
  87. # and attempting a deepcopy would be marginally slower.
  88. _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
  89. return m
  90. if typing.TYPE_CHECKING:
  91. def model_dump( # type: ignore
  92. self,
  93. *,
  94. mode: Literal['json', 'python'] | str = 'python',
  95. include: Any = None,
  96. exclude: Any = None,
  97. context: dict[str, Any] | None = None,
  98. by_alias: bool | None = None,
  99. exclude_unset: bool = False,
  100. exclude_defaults: bool = False,
  101. exclude_none: bool = False,
  102. round_trip: bool = False,
  103. warnings: bool | Literal['none', 'warn', 'error'] = True,
  104. serialize_as_any: bool = False,
  105. ) -> Any:
  106. """This method is included just to get a more accurate return type for type checkers.
  107. It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
  108. See the documentation of `BaseModel.model_dump` for more details about the arguments.
  109. Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is
  110. not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return
  111. type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could
  112. even be something different, in the case of a custom serializer.
  113. Thus, `Any` is used here to catch all of these cases.
  114. """
  115. ...
  116. def __eq__(self, other: Any) -> bool:
  117. if not isinstance(other, RootModel):
  118. return NotImplemented
  119. return self.__pydantic_fields__['root'].annotation == other.__pydantic_fields__[
  120. 'root'
  121. ].annotation and super().__eq__(other)
  122. def __repr_args__(self) -> _repr.ReprArgs:
  123. yield 'root', self.root