You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
4.8 KiB
139 lines
4.8 KiB
"""RootModel class and type definitions.""" |
|
|
|
from __future__ import annotations as _annotations |
|
|
|
import typing |
|
from copy import copy, deepcopy |
|
|
|
from pydantic_core import PydanticUndefined |
|
|
|
from . import PydanticUserError |
|
from ._internal import _repr |
|
from .main import BaseModel, _object_setattr |
|
|
|
if typing.TYPE_CHECKING: |
|
from typing import Any |
|
|
|
from typing_extensions import Literal |
|
|
|
Model = typing.TypeVar('Model', bound='BaseModel') |
|
|
|
|
|
# Type parameter for the value wrapped by a root model, used as `RootModel[T]`.
RootModelRootType = typing.TypeVar('RootModelRootType')

# Names exported by `from ... import *`; only the model class is public here.
__all__ = ('RootModel',)
|
|
|
|
|
class RootModel(BaseModel, typing.Generic[RootModelRootType]):
    """Usage docs: https://docs.pydantic.dev/2.2/usage/models/#rootmodel-and-custom-root-types

    A Pydantic `BaseModel` for the root object of the model.

    Attributes:
        root: The root object of the model.
        __pydantic_root_model__: Whether the model is a RootModel.
        __pydantic_private__: Private fields in the model (fixed to `None` at class level here).
        __pydantic_extra__: Extra fields in the model (fixed to `None` at class level here).
    """

    __pydantic_root_model__ = True
    __pydantic_private__ = None
    __pydantic_extra__ = None

    root: RootModelRootType

    def __init_subclass__(cls, **kwargs):
        """Refuse subclasses that set `model_config['extra']`, then defer to `BaseModel`.

        Raises:
            PydanticUserError: If the subclass's `model_config` has a non-`None` `'extra'` entry
                (error code `'root-model-extra'`).
        """
        extra = cls.model_config.get('extra')
        if extra is not None:
            raise PydanticUserError(
                "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
            )
        super().__init_subclass__(**kwargs)

    def __init__(__pydantic_self__, root: RootModelRootType = PydanticUndefined, **data) -> None:  # type: ignore
        """Validate the input and populate this instance.

        The root value can be passed either as the single `root` argument, or as arbitrary
        keyword arguments. In the keyword form, the collected `data` dict itself is validated
        as the root value. Any error raised by `__pydantic_validator__.validate_python`
        propagates unchanged.

        NOTE(review): the first parameter is named `__pydantic_self__` rather than `self`,
        presumably so a keyword argument called `self` can be forwarded in `data` — confirm.

        Raises:
            ValueError: If both `root` and keyword arguments are supplied.
        """
        # Frames that set this local are hidden from pytest tracebacks.
        __tracebackhide__ = True
        if data:
            if root is not PydanticUndefined:
                raise ValueError(
                    '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
                )
            root = data  # type: ignore
        __pydantic_self__.__pydantic_validator__.validate_python(root, self_instance=__pydantic_self__)

    # Marker attribute on the function object; presumably read elsewhere in pydantic to
    # recognise its own base `__init__` (vs. a user override) — confirm.
    __init__.__pydantic_base_init__ = True

    @classmethod
    def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model:
        """Create a new model using the provided root object.

        This forwards `root` as the `root` field to `BaseModel.model_construct`.

        Args:
            root: The root object of the model.
            _fields_set: Optional set of field names to record as explicitly set; passed
                through unchanged to `BaseModel.model_construct`.

        Returns:
            The new model.
        """
        return super().model_construct(root=root, _fields_set=_fields_set)

    def __getstate__(self) -> dict[Any, Any]:
        """Return the pickling state: the instance `__dict__` and the set of explicitly-set fields.

        `__pydantic_private__` and `__pydantic_extra__` are not included in the state.
        """
        return {
            '__dict__': self.__dict__,
            '__pydantic_fields_set__': self.__pydantic_fields_set__,
        }

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore an instance from the mapping produced by `__getstate__`."""
        # `_object_setattr` is used instead of normal assignment so that the model's own
        # `__setattr__` is bypassed while restoring internal attributes.
        _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
        _object_setattr(self, '__dict__', state['__dict__'])

    def __copy__(self: Model) -> Model:
        """Returns a shallow copy of the model.

        The new instance is created with `cls.__new__`, so `__init__` (and thus validation)
        is not run. `__dict__` and the fields-set are shallow-copied.
        """
        cls = type(self)
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', copy(self.__dict__))
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
        """Returns a deep copy of the model.

        The new instance is created with `cls.__new__`, so `__init__` (and thus validation)
        is not run. `__dict__` is deep-copied, sharing `memo` with the outer `deepcopy` call.
        """
        cls = type(self)
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
        # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
        # and attempting a deepcopy would be marginally slower.
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    if typing.TYPE_CHECKING:

        def model_dump(
            self,
            *,
            mode: Literal['json', 'python'] | str = 'python',
            include: Any = None,
            exclude: Any = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool = True,
        ) -> RootModelRootType:
            """This method is included just to get a more accurate return type for type checkers.
            It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.

            See the documentation of `BaseModel.model_dump` for more details about the arguments.
            """
            ...

    def __eq__(self, other: Any) -> bool:
        """Compare with another `RootModel`.

        Two root models are equal only when their `root` fields have the same annotation
        and `BaseModel.__eq__` also reports them equal. For any non-`RootModel` operand this
        returns `NotImplemented`, letting Python try the reflected comparison.
        """
        if not isinstance(other, RootModel):
            return NotImplemented
        return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)

    def __repr_args__(self) -> _repr.ReprArgs:
        """Yield the single `('root', value)` pair used to build the model's repr."""
        yield 'root', self.root
|
|
|