Não pode escolher mais do que 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.
 
 
 
 

950 linhas
38 KiB

  1. import sys
  2. from configparser import ConfigParser
  3. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union
  4. from mypy.errorcodes import ErrorCode
  5. from mypy.nodes import (
  6. ARG_NAMED,
  7. ARG_NAMED_OPT,
  8. ARG_OPT,
  9. ARG_POS,
  10. ARG_STAR2,
  11. MDEF,
  12. Argument,
  13. AssignmentStmt,
  14. Block,
  15. CallExpr,
  16. ClassDef,
  17. Context,
  18. Decorator,
  19. EllipsisExpr,
  20. FuncBase,
  21. FuncDef,
  22. JsonDict,
  23. MemberExpr,
  24. NameExpr,
  25. PassStmt,
  26. PlaceholderNode,
  27. RefExpr,
  28. StrExpr,
  29. SymbolNode,
  30. SymbolTableNode,
  31. TempNode,
  32. TypeInfo,
  33. TypeVarExpr,
  34. Var,
  35. )
  36. from mypy.options import Options
  37. from mypy.plugin import (
  38. CheckerPluginInterface,
  39. ClassDefContext,
  40. FunctionContext,
  41. MethodContext,
  42. Plugin,
  43. ReportConfigContext,
  44. SemanticAnalyzerPluginInterface,
  45. )
  46. from mypy.plugins import dataclasses
  47. from mypy.semanal import set_callable_name # type: ignore
  48. from mypy.server.trigger import make_wildcard_trigger
  49. from mypy.types import (
  50. AnyType,
  51. CallableType,
  52. Instance,
  53. NoneType,
  54. Overloaded,
  55. ProperType,
  56. Type,
  57. TypeOfAny,
  58. TypeType,
  59. TypeVarId,
  60. TypeVarType,
  61. UnionType,
  62. get_proper_type,
  63. )
  64. from mypy.typevars import fill_typevars
  65. from mypy.util import get_unique_redefinition_name
  66. from mypy.version import __version__ as mypy_version
  67. from pydantic.v1.utils import is_valid_field
  68. try:
  69. from mypy.types import TypeVarDef # type: ignore[attr-defined]
  70. except ImportError: # pragma: no cover
  71. # Backward-compatible with TypeVarDef from Mypy 0.910.
  72. from mypy.types import TypeVarType as TypeVarDef
# Name of the INI section / `[tool.<name>]` TOML table from which the plugin settings are read.
CONFIGFILE_KEY = 'pydantic-mypy'
# Key under which the plugin stores per-class data (collected fields and config) in `TypeInfo.metadata`.
METADATA_KEY = 'pydantic-mypy-metadata'
# Drops the last five characters of this module's dotted name (presumably the trailing '.mypy' -- confirm).
_NAMESPACE = __name__[:-5]  # 'pydantic' in 1.10.X, 'pydantic.v1' in v2.X
# Fully qualified names the hooks compare against to recognise pydantic objects.
BASEMODEL_FULLNAME = f'{_NAMESPACE}.main.BaseModel'
BASESETTINGS_FULLNAME = f'{_NAMESPACE}.env_settings.BaseSettings'
MODEL_METACLASS_FULLNAME = f'{_NAMESPACE}.main.ModelMetaclass'
FIELD_FULLNAME = f'{_NAMESPACE}.fields.Field'
DATACLASS_FULLNAME = f'{_NAMESPACE}.dataclasses.dataclass'
  81. def parse_mypy_version(version: str) -> Tuple[int, ...]:
  82. return tuple(map(int, version.partition('+')[0].split('.')))
# Integer form of the running mypy version, used for the version-dependent branches in this module.
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
# Module name used as the prefix when looking up builtin types via `named_type`;
# the value depends on whether mypy is at least 0.930.
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'

# Increment version if plugin changes and mypy caches should be invalidated
__version__ = 2
  87. def plugin(version: str) -> 'TypingType[Plugin]':
  88. """
  89. `version` is the mypy version string
  90. We might want to use this to print a warning if the mypy version being used is
  91. newer, or especially older, than we expect (or need).
  92. """
  93. return PydanticPlugin
  94. class PydanticPlugin(Plugin):
  95. def __init__(self, options: Options) -> None:
  96. self.plugin_config = PydanticPluginConfig(options)
  97. self._plugin_data = self.plugin_config.to_data()
  98. super().__init__(options)
  99. def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
  100. sym = self.lookup_fully_qualified(fullname)
  101. if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
  102. # No branching may occur if the mypy cache has not been cleared
  103. if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro):
  104. return self._pydantic_model_class_maker_callback
  105. return None
  106. def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
  107. if fullname == MODEL_METACLASS_FULLNAME:
  108. return self._pydantic_model_metaclass_marker_callback
  109. return None
  110. def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
  111. sym = self.lookup_fully_qualified(fullname)
  112. if sym and sym.fullname == FIELD_FULLNAME:
  113. return self._pydantic_field_callback
  114. return None
  115. def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
  116. if fullname.endswith('.from_orm'):
  117. return from_orm_callback
  118. return None
  119. def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
  120. """Mark pydantic.dataclasses as dataclass.
  121. Mypy version 1.1.1 added support for `@dataclass_transform` decorator.
  122. """
  123. if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
  124. return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
  125. return None
  126. def report_config_data(self, ctx: ReportConfigContext) -> Dict[str, Any]:
  127. """Return all plugin config data.
  128. Used by mypy to determine if cache needs to be discarded.
  129. """
  130. return self._plugin_data
  131. def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
  132. transformer = PydanticModelTransformer(ctx, self.plugin_config)
  133. transformer.transform()
  134. def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
  135. """Reset dataclass_transform_spec attribute of ModelMetaclass.
  136. Let the plugin handle it. This behavior can be disabled
  137. if 'debug_dataclass_transform' is set to True', for testing purposes.
  138. """
  139. if self.plugin_config.debug_dataclass_transform:
  140. return
  141. info_metaclass = ctx.cls.info.declared_metaclass
  142. assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
  143. if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
  144. info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined]
  145. def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
  146. """
  147. Extract the type of the `default` argument from the Field function, and use it as the return type.
  148. In particular:
  149. * Check whether the default and default_factory argument is specified.
  150. * Output an error if both are specified.
  151. * Retrieve the type of the argument which is specified, and use it as return type for the function.
  152. """
  153. default_any_type = ctx.default_return_type
  154. assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
  155. assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
  156. default_args = ctx.args[0]
  157. default_factory_args = ctx.args[1]
  158. if default_args and default_factory_args:
  159. error_default_and_default_factory_specified(ctx.api, ctx.context)
  160. return default_any_type
  161. if default_args:
  162. default_type = ctx.arg_types[0][0]
  163. default_arg = default_args[0]
  164. # Fallback to default Any type if the field is required
  165. if not isinstance(default_arg, EllipsisExpr):
  166. return default_type
  167. elif default_factory_args:
  168. default_factory_type = ctx.arg_types[1][0]
  169. # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
  170. # Pydantic calls the default factory without any argument, so we retrieve the first item
  171. if isinstance(default_factory_type, Overloaded):
  172. if MYPY_VERSION_TUPLE > (0, 910):
  173. default_factory_type = default_factory_type.items[0]
  174. else:
  175. # Mypy0.910 exposes the items of overloaded types in a function
  176. default_factory_type = default_factory_type.items()[0] # type: ignore[operator]
  177. if isinstance(default_factory_type, CallableType):
  178. ret_type = default_factory_type.ret_type
  179. # mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
  180. # add this check in case it varies by version
  181. args = getattr(ret_type, 'args', None)
  182. if args:
  183. if all(isinstance(arg, TypeVarType) for arg in args):
  184. # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
  185. ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
  186. return ret_type
  187. return default_any_type
  188. class PydanticPluginConfig:
  189. __slots__ = (
  190. 'init_forbid_extra',
  191. 'init_typed',
  192. 'warn_required_dynamic_aliases',
  193. 'warn_untyped_fields',
  194. 'debug_dataclass_transform',
  195. )
  196. init_forbid_extra: bool
  197. init_typed: bool
  198. warn_required_dynamic_aliases: bool
  199. warn_untyped_fields: bool
  200. debug_dataclass_transform: bool # undocumented
  201. def __init__(self, options: Options) -> None:
  202. if options.config_file is None: # pragma: no cover
  203. return
  204. toml_config = parse_toml(options.config_file)
  205. if toml_config is not None:
  206. config = toml_config.get('tool', {}).get('pydantic-mypy', {})
  207. for key in self.__slots__:
  208. setting = config.get(key, False)
  209. if not isinstance(setting, bool):
  210. raise ValueError(f'Configuration value must be a boolean for key: {key}')
  211. setattr(self, key, setting)
  212. else:
  213. plugin_config = ConfigParser()
  214. plugin_config.read(options.config_file)
  215. for key in self.__slots__:
  216. setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
  217. setattr(self, key, setting)
  218. def to_data(self) -> Dict[str, Any]:
  219. return {key: getattr(self, key) for key in self.__slots__}
  220. def from_orm_callback(ctx: MethodContext) -> Type:
  221. """
  222. Raise an error if orm_mode is not enabled
  223. """
  224. model_type: Instance
  225. ctx_type = ctx.type
  226. if isinstance(ctx_type, TypeType):
  227. ctx_type = ctx_type.item
  228. if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance):
  229. model_type = ctx_type.ret_type # called on the class
  230. elif isinstance(ctx_type, Instance):
  231. model_type = ctx_type # called on an instance (unusual, but still valid)
  232. else: # pragma: no cover
  233. detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})'
  234. error_unexpected_behavior(detail, ctx.api, ctx.context)
  235. return ctx.default_return_type
  236. pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
  237. if pydantic_metadata is None:
  238. return ctx.default_return_type
  239. orm_mode = pydantic_metadata.get('config', {}).get('orm_mode')
  240. if orm_mode is not True:
  241. error_from_orm(get_name(model_type.type), ctx.api, ctx.context)
  242. return ctx.default_return_type
class PydanticModelTransformer:
    """
    Applies the plugin's changes to a single class definition.

    An instance wraps the `ClassDefContext` of the class being analyzed together with the
    plugin configuration; `transform()` is the entry point.
    """

    # `Config` attribute names that `get_config_update` interprets; assignments to other names are ignored.
    tracked_config_fields: Set[str] = {
        'extra',
        'allow_mutation',
        'frozen',
        'orm_mode',
        'allow_population_by_field_name',
        'alias_generator',
    }

    def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None:
        self._ctx = ctx
        self.plugin_config = plugin_config

    def transform(self) -> None:
        """
        Configures the BaseModel subclass according to the plugin settings.

        In particular:
        * marks `validator(...)`-decorated functions as classmethods,
        * determines the model config and fields,
        * adds a fields-aware signature for the initializer and construct methods
        * freezes the class if allow_mutation = False or frozen = True
        * stores the fields and config in the mypy metadata for access by subclasses
        """
        ctx = self._ctx
        info = ctx.cls.info

        self.adjust_validator_signatures()
        config = self.collect_config()
        fields = self.collect_fields(config)
        # The last MRO entry is excluded from the BaseSettings search.
        is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1])
        self.add_initializer(fields, config, is_settings)
        self.add_construct_method(fields)
        self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True)
        info.metadata[METADATA_KEY] = {
            'fields': {field.name: field.serialize() for field in fields},
            'config': config.set_values_dict(),
        }

    def adjust_validator_signatures(self) -> None:
        """When we decorate a function `f` with `pydantic.validator(...)`, mypy sees
        `f` as a regular method taking a `self` instance, even though pydantic
        internally wraps `f` with `classmethod` if necessary.

        Teach mypy this by marking any function whose outermost decorator is a
        `validator()` call as a classmethod.
        """
        for name, sym in self._ctx.cls.info.names.items():
            if isinstance(sym.node, Decorator):
                first_dec = sym.node.original_decorators[0]
                # Only a call whose callee is a plain name (`NameExpr`) resolving to the validator is matched.
                if (
                    isinstance(first_dec, CallExpr)
                    and isinstance(first_dec.callee, NameExpr)
                    and first_dec.callee.fullname == f'{_NAMESPACE}.class_validators.validator'
                ):
                    sym.node.func.is_class = True

    def collect_config(self) -> 'ModelConfigData':
        """
        Collects the values of the config attributes that are used by the plugin, accounting for parent classes.

        Values set in a nested `Config` class of the current class take precedence; values stored
        in the metadata of classes later in the MRO only fill in attributes that are still unset.
        """
        ctx = self._ctx
        cls = ctx.cls
        config = ModelConfigData()
        for stmt in cls.defs.body:
            if not isinstance(stmt, ClassDef):
                continue
            if stmt.name == 'Config':
                for substmt in stmt.defs.body:
                    if not isinstance(substmt, AssignmentStmt):
                        continue
                    config.update(self.get_config_update(substmt))
                # Reported on the `Config` class statement, and only when the plugin setting asks for it.
                if (
                    config.has_alias_generator
                    and not config.allow_population_by_field_name
                    and self.plugin_config.warn_required_dynamic_aliases
                ):
                    error_required_dynamic_aliases(ctx.api, stmt)
        for info in cls.info.mro[1:]:  # 0 is the current class
            if METADATA_KEY not in info.metadata:
                continue

            # Each class depends on the set of fields in its ancestors
            ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
            for name, value in info.metadata[METADATA_KEY]['config'].items():
                config.setdefault(name, value)
        return config

    def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']:
        """
        Collects the fields for the model, accounting for parent classes

        Returns the fields with those coming from classes later in the MRO placed first; a field
        redefined in the current class keeps its new definition but takes the position it had
        in the parent.
        """
        # First, collect fields belonging to the current class.
        ctx = self._ctx
        cls = self._ctx.cls
        fields = []  # type: List[PydanticModelField]
        known_fields = set()  # type: Set[str]
        for stmt in cls.defs.body:
            if not isinstance(stmt, AssignmentStmt):  # `and stmt.new_syntax` to require annotation
                continue

            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name):
                continue

            if not stmt.new_syntax and self.plugin_config.warn_untyped_fields:
                error_untyped_fields(ctx.api, stmt)

            # if lhs.name == '__config__':  # BaseConfig not well handled; I'm not sure why yet
            #     continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:  # pragma: no cover
                # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
                # This is the same logic used in the dataclasses plugin
                continue

            node = sym.node
            if isinstance(node, PlaceholderNode):  # pragma: no cover
                # See the PlaceholderNode docstring for more detail about how this can occur
                # Basically, it is an edge case when dealing with complex import logic
                # This is the same logic used in the dataclasses plugin
                continue
            if not isinstance(node, Var):  # pragma: no cover
                # Don't know if this edge case still happens with the `is_valid_field` check above
                # but better safe than sorry
                continue

            # x: ClassVar[int] is ignored by dataclasses.
            if node.is_classvar:
                continue

            is_required = self.get_is_required(cls, stmt, lhs)
            alias, has_dynamic_alias = self.get_alias_info(stmt)
            if (
                has_dynamic_alias
                and not model_config.allow_population_by_field_name
                and self.plugin_config.warn_required_dynamic_aliases
            ):
                error_required_dynamic_aliases(ctx.api, stmt)
            fields.append(
                PydanticModelField(
                    name=lhs.name,
                    is_required=is_required,
                    alias=alias,
                    has_dynamic_alias=has_dynamic_alias,
                    line=stmt.line,
                    column=stmt.column,
                )
            )
            known_fields.add(lhs.name)
        # Next, merge in the fields recorded in the metadata of the parent classes.
        all_fields = fields.copy()
        for info in cls.info.mro[1:]:  # 0 is the current class, -2 is BaseModel, -1 is object
            if METADATA_KEY not in info.metadata:
                continue

            superclass_fields = []
            # Each class depends on the set of fields in its ancestors
            ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))

            for name, data in info.metadata[METADATA_KEY]['fields'].items():
                if name not in known_fields:
                    field = PydanticModelField.deserialize(info, data)
                    known_fields.add(name)
                    superclass_fields.append(field)
                else:
                    # Already known: move the existing entry so it sits at the parent's position.
                    # The unpacking requires exactly one entry with this name in `all_fields`.
                    (field,) = (a for a in all_fields if a.name == name)
                    all_fields.remove(field)
                    superclass_fields.append(field)
            all_fields = superclass_fields + all_fields
        return all_fields

    def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None:
        """
        Adds a fields-aware `__init__` method to the class.

        The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
        Nothing is added when the class already has an `__init__` entry in its symbol table.
        """
        ctx = self._ctx
        typed = self.plugin_config.init_typed
        use_alias = config.allow_population_by_field_name is not True
        # All arguments become optional for settings classes, and when an alias generator is
        # configured without population by field name.
        force_all_optional = is_settings or bool(
            config.has_alias_generator and not config.allow_population_by_field_name
        )
        init_arguments = self.get_field_arguments(
            fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias
        )
        if not self.should_init_forbid_extra(fields, config):
            # Extra keyword arguments are allowed: append `**kwargs: Any` to the signature.
            var = Var('kwargs')
            init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))

        if '__init__' not in ctx.cls.info.names:
            add_method(ctx, '__init__', init_arguments, NoneType())

    def add_construct_method(self, fields: List['PydanticModelField']) -> None:
        """
        Adds a fully typed `construct` classmethod to the class.

        Similar to the fields-aware __init__ method, but always uses the field names (not aliases),
        and does not treat settings fields as optional.
        """
        ctx = self._ctx
        # First argument: `_fields_set: Optional[Set[str]]`.
        set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')])
        optional_set_str = UnionType([set_str, NoneType()])
        fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
        construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False)
        construct_arguments = [fields_set_argument] + construct_arguments

        # Build a type variable used as both the self type and the return type of `construct`.
        # The constructor signatures differ between mypy versions, hence the branches below.
        obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object')
        self_tvar_name = '_PydanticBaseModel'  # Make sure it does not conflict with other names in the class
        tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name
        if MYPY_VERSION_TUPLE >= (1, 4):
            tvd = TypeVarType(
                self_tvar_name,
                tvar_fullname,
                (
                    TypeVarId(-1, namespace=ctx.cls.fullname + '.construct')
                    if MYPY_VERSION_TUPLE >= (1, 11)
                    else TypeVarId(-1)
                ),
                [],
                obj_type,
                AnyType(TypeOfAny.from_omitted_generics),  # type: ignore[arg-type]
            )
            self_tvar_expr = TypeVarExpr(
                self_tvar_name,
                tvar_fullname,
                [],
                obj_type,
                AnyType(TypeOfAny.from_omitted_generics),  # type: ignore[arg-type]
            )
        else:
            tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type)
            self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type)
        ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)

        # Backward-compatible with TypeVarDef from Mypy 0.910.
        if isinstance(tvd, TypeVarType):
            self_type = tvd
        else:
            self_type = TypeVarType(tvd)

        add_method(
            ctx,
            'construct',
            construct_arguments,
            return_type=self_type,
            self_type=self_type,
            tvar_def=tvd,
            is_classmethod=True,
        )

    def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None:
        """
        Sets `is_property` to `frozen` on every field, so that attempts to set the fields of a
        frozen model trigger mypy errors.

        This is the same approach used by the attrs and dataclasses plugins.
        """
        ctx = self._ctx
        info = ctx.cls.info
        for field in fields:
            sym_node = info.names.get(field.name)
            if sym_node is not None:
                var = sym_node.node
                if isinstance(var, Var):
                    var.is_property = frozen
                elif isinstance(var, PlaceholderNode) and not ctx.api.final_iteration:
                    # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
                    ctx.api.defer()
                else:  # pragma: no cover
                    # I don't know whether it's possible to hit this branch, but I've added it for safety
                    try:
                        var_str = str(var)
                    except TypeError:
                        # This happens for PlaceholderNode; perhaps it will happen for other types in the future..
                        var_str = repr(var)
                    detail = f'sym_node.node: {var_str} (of type {var.__class__})'
                    error_unexpected_behavior(detail, ctx.api, ctx.cls)
            else:
                # The field has no entry in this class's symbol table: create a Var for it here.
                var = field.to_var(info, use_alias=False)
                var.info = info
                var.is_property = frozen
                var._fullname = get_fullname(info) + '.' + get_name(var)
                info.names[get_name(var)] = SymbolTableNode(MDEF, var)

    def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']:
        """
        Determines the config update due to a single statement in the Config class definition.

        Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
        Returns None for untracked attributes and for values that could not be interpreted.
        """
        lhs = substmt.lvalues[0]
        if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields):
            return None
        if lhs.name == 'extra':
            # Accepts a string literal or an attribute access whose attribute name is 'forbid'.
            if isinstance(substmt.rvalue, StrExpr):
                forbid_extra = substmt.rvalue.value == 'forbid'
            elif isinstance(substmt.rvalue, MemberExpr):
                forbid_extra = substmt.rvalue.name == 'forbid'
            else:
                error_invalid_config_value(lhs.name, self._ctx.api, substmt)
                return None
            return ModelConfigData(forbid_extra=forbid_extra)
        if lhs.name == 'alias_generator':
            # Any value other than a literal `None` counts as having an alias generator.
            has_alias_generator = True
            if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None':
                has_alias_generator = False
            return ModelConfigData(has_alias_generator=has_alias_generator)
        # The remaining tracked attributes must be literal `True` / `False`.
        if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'):
            return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'})
        error_invalid_config_value(lhs.name, self._ctx.api, substmt)
        return None

    @staticmethod
    def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool:
        """
        Returns a boolean indicating whether the field defined in `stmt` is a required field.
        """
        expr = stmt.rvalue
        if isinstance(expr, TempNode):
            # TempNode means annotation-only, so only non-required if Optional
            value_type = get_proper_type(cls.info[lhs.name].type)
            return not PydanticModelTransformer.type_has_implicit_default(value_type)
        if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
            # The "default value" is a call to `Field`; at this point, the field is
            # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`);
            # a `default_factory` argument makes it non-required.
            for arg, name in zip(expr.args, expr.arg_names):
                # If name is None, then this arg is the default because it is the only positional argument.
                if name is None or name == 'default':
                    return arg.__class__ is EllipsisExpr
                if name == 'default_factory':
                    return False
            # In this case, default and default_factory are not specified, so we need to look at the annotation
            value_type = get_proper_type(cls.info[lhs.name].type)
            return not PydanticModelTransformer.type_has_implicit_default(value_type)
        # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
        return isinstance(expr, EllipsisExpr)

    @staticmethod
    def type_has_implicit_default(type_: Optional[ProperType]) -> bool:
        """
        Returns True if the passed type will be given an implicit default value.

        In pydantic v1, this is the case for Optional types and Any (with default value None).
        """
        if isinstance(type_, AnyType):
            # Annotated as Any
            return True
        if isinstance(type_, UnionType) and any(
            isinstance(item, NoneType) or isinstance(item, AnyType) for item in type_.items
        ):
            # Annotated as Optional, or otherwise having NoneType or AnyType in the union
            return True
        return False

    @staticmethod
    def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]:
        """
        Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.

        `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
        If `has_dynamic_alias` is True, `alias` will be None.
        """
        expr = stmt.rvalue
        if isinstance(expr, TempNode):
            # TempNode means annotation-only
            return None, False

        if not (
            isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
        ):
            # Assigned value is not a call to pydantic.fields.Field
            return None, False

        # Only an `alias` passed by keyword is considered; the first one found decides the result.
        for i, arg_name in enumerate(expr.arg_names):
            if arg_name != 'alias':
                continue
            arg = expr.args[i]
            if isinstance(arg, StrExpr):
                return arg.value, False
            else:
                return None, True
        return None, False

    def get_field_arguments(
        self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool
    ) -> List[Argument]:
        """
        Helper function used during the construction of the `__init__` and `construct` method signatures.

        Returns a list of mypy Argument instances for use in the generated signatures.
        Fields with a dynamic alias are left out when `use_alias` is True.
        """
        info = self._ctx.cls.info
        arguments = [
            field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias)
            for field in fields
            if not (use_alias and field.has_dynamic_alias)
        ]
        return arguments

    def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool:
        """
        Indicates whether the generated `__init__` should reject unknown keyword arguments, i.e.
        whether the `**kwargs` at the end of its signature should be left out.

        We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
        *unless* a required dynamic alias is present (since then we can't determine a valid signature).
        """
        if not config.allow_population_by_field_name:
            if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
                return False
        if config.forbid_extra:
            return True
        return self.plugin_config.init_forbid_extra

    @staticmethod
    def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool:
        """
        Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
        determined during static analysis.

        With an alias generator, every field without an explicit alias counts as dynamic.
        """
        for field in fields:
            if field.has_dynamic_alias:
                return True
        if has_alias_generator:
            for field in fields:
                if field.alias is None:
                    return True
        return False
  630. class PydanticModelField:
  631. def __init__(
  632. self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int
  633. ):
  634. self.name = name
  635. self.is_required = is_required
  636. self.alias = alias
  637. self.has_dynamic_alias = has_dynamic_alias
  638. self.line = line
  639. self.column = column
  640. def to_var(self, info: TypeInfo, use_alias: bool) -> Var:
  641. name = self.name
  642. if use_alias and self.alias is not None:
  643. name = self.alias
  644. return Var(name, info[self.name].type)
  645. def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
  646. if typed and info[self.name].type is not None:
  647. type_annotation = info[self.name].type
  648. else:
  649. type_annotation = AnyType(TypeOfAny.explicit)
  650. return Argument(
  651. variable=self.to_var(info, use_alias),
  652. type_annotation=type_annotation,
  653. initializer=None,
  654. kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED,
  655. )
  656. def serialize(self) -> JsonDict:
  657. return self.__dict__
  658. @classmethod
  659. def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField':
  660. return cls(**data)
  661. class ModelConfigData:
  662. def __init__(
  663. self,
  664. forbid_extra: Optional[bool] = None,
  665. allow_mutation: Optional[bool] = None,
  666. frozen: Optional[bool] = None,
  667. orm_mode: Optional[bool] = None,
  668. allow_population_by_field_name: Optional[bool] = None,
  669. has_alias_generator: Optional[bool] = None,
  670. ):
  671. self.forbid_extra = forbid_extra
  672. self.allow_mutation = allow_mutation
  673. self.frozen = frozen
  674. self.orm_mode = orm_mode
  675. self.allow_population_by_field_name = allow_population_by_field_name
  676. self.has_alias_generator = has_alias_generator
  677. def set_values_dict(self) -> Dict[str, Any]:
  678. return {k: v for k, v in self.__dict__.items() if v is not None}
  679. def update(self, config: Optional['ModelConfigData']) -> None:
  680. if config is None:
  681. return
  682. for k, v in config.set_values_dict().items():
  683. setattr(self, k, v)
  684. def setdefault(self, key: str, value: Any) -> None:
  685. if getattr(self, key) is None:
  686. setattr(self, key, value)
# mypy error codes reported by this plugin; each is passed as `code=` by the matching `error_*` helper.
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic')
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
# NOTE(review): ERROR_UNTYPED and ERROR_FIELD_DEFAULTS share the code string 'pydantic-field' --
# confirm this is intended before changing either one (the string is visible to users).
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
  693. def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
  694. api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM)
  695. def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
  696. api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)
  697. def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
  698. api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)
  699. def error_unexpected_behavior(
  700. detail: str, api: Union[CheckerPluginInterface, SemanticAnalyzerPluginInterface], context: Context
  701. ) -> None: # pragma: no cover
  702. # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
  703. link = 'https://github.com/pydantic/pydantic/issues/new/choose'
  704. full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
  705. full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
  706. api.fail(full_message, context, code=ERROR_UNEXPECTED)
  707. def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
  708. api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
  709. def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
  710. api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
def add_method(
    ctx: ClassDefContext,
    name: str,
    args: List[Argument],
    return_type: Type,
    self_type: Optional[Type] = None,
    tvar_def: Optional[TypeVarDef] = None,
    is_classmethod: bool = False,
    is_new: bool = False,
    # is_staticmethod: bool = False,
) -> None:
    """
    Adds a new method to a class.

    This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged

    A `FuncDef` whose body is a single `pass` statement is built, registered under `name` in
    `ctx.cls.info.names` as a plugin-generated symbol, and appended to the class body.

    Parameters:
        ctx: class definition context; `ctx.cls.info` is the class that receives the method.
        name: name of the method to add.
        args: the method's arguments without the implicit first argument, which is prepended here.
            Every argument must carry a type annotation (checked with an assert).
        return_type: return type of the generated signature.
        self_type: type used for the implicit first argument; defaults to `fill_typevars(info)`.
        tvar_def: when given, becomes the only type variable of the generated signature.
        is_classmethod: name the first argument `_cls`, type it `TypeType.make_normalized(self_type)`,
            and register the method wrapped in a `classmethod` decorator node.
        is_new: name and type the first argument as for `is_classmethod`, but without the decorator node.
    """
    info = ctx.cls.info

    # First remove any previously generated methods with the same name
    # to avoid clashes and problems in the semantic analyzer.
    if name in info.names:
        sym = info.names[name]
        # Only a plugin-generated plain FuncDef is removed from the class body here.
        if sym.plugin_generated and isinstance(sym.node, FuncDef):
            ctx.cls.defs.body.remove(sym.node)  # pragma: no cover

    # Fall back to the class's own type when the caller did not supply a self type.
    self_type = self_type or fill_typevars(info)
    if is_classmethod or is_new:
        # The implicit first argument receives the class rather than an instance.
        first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)]
    # elif is_staticmethod:
    #     first = []
    else:
        # NOTE(review): self_type was already defaulted above, so this line looks like a no-op;
        # confirm before removing.
        self_type = self_type or fill_typevars(info)
        first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
    args = first + args

    # Split the arguments into the parallel lists that CallableType expects.
    arg_types, arg_names, arg_kinds = [], [], []
    for arg in args:
        assert arg.type_annotation, 'All arguments must be fully typed.'
        arg_types.append(arg.type_annotation)
        arg_names.append(get_name(arg.variable))
        arg_kinds.append(arg.kind)

    function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
    signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
    if tvar_def:
        signature.variables = [tvar_def]

    # The generated function has a body consisting of a single `pass`.
    func = FuncDef(name, args, Block([PassStmt()]))
    func.info = info
    func.type = set_callable_name(signature, func)
    # Only `is_classmethod` marks the function as a class method; `is_new` does not.
    func.is_class = is_classmethod
    # func.is_static = is_staticmethod
    func._fullname = get_fullname(info) + '.' + name
    # The generated node reuses the line of the class.
    func.line = info.line

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed.
    if name in info.names:
        # Get a nice unique name instead.
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]

    if is_classmethod:  # or is_staticmethod:
        # The symbol table entry for a classmethod is a Decorator wrapping the function and a Var.
        func.is_decorated = True
        v = Var(name, func.type)
        v.info = info
        v._fullname = func._fullname
        # if is_classmethod:
        v.is_classmethod = True
        dec = Decorator(func, [NameExpr('classmethod')], v)
        # else:
        #     v.is_staticmethod = True
        #     dec = Decorator(func, [NameExpr('staticmethod')], v)

        dec.line = info.line
        sym = SymbolTableNode(MDEF, dec)
    else:
        sym = SymbolTableNode(MDEF, func)
    sym.plugin_generated = True

    # The new symbol takes over `name`; the FuncDef itself (not the Decorator) is appended to the class body.
    info.names[name] = sym
    info.defn.defs.body.append(func)
  783. def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
  784. """
  785. Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
  786. """
  787. fn = x.fullname
  788. if callable(fn): # pragma: no cover
  789. return fn()
  790. return fn
  791. def get_name(x: Union[FuncBase, SymbolNode]) -> str:
  792. """
  793. Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
  794. """
  795. fn = x.name
  796. if callable(fn): # pragma: no cover
  797. return fn()
  798. return fn
  799. def parse_toml(config_file: str) -> Optional[Dict[str, Any]]:
  800. if not config_file.endswith('.toml'):
  801. return None
  802. read_mode = 'rb'
  803. if sys.version_info >= (3, 11):
  804. import tomllib as toml_
  805. else:
  806. try:
  807. import tomli as toml_
  808. except ImportError:
  809. # older versions of mypy have toml as a dependency, not tomli
  810. read_mode = 'r'
  811. try:
  812. import toml as toml_ # type: ignore[no-redef]
  813. except ImportError: # pragma: no cover
  814. import warnings
  815. warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
  816. return None
  817. with open(config_file, read_mode) as rf:
  818. return toml_.load(rf) # type: ignore[arg-type]