You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

338 lines
13 KiB

  1. import warnings
  2. from collections import ChainMap
  3. from functools import wraps
  4. from itertools import chain
  5. from types import FunctionType
  6. from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
  7. from .errors import ConfigError
  8. from .typing import AnyCallable
  9. from .utils import ROOT_KEY, in_ipython
  10. if TYPE_CHECKING:
  11. from .typing import AnyClassMethod
  12. class Validator:
  13. __slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
  14. def __init__(
  15. self,
  16. func: AnyCallable,
  17. pre: bool = False,
  18. each_item: bool = False,
  19. always: bool = False,
  20. check_fields: bool = False,
  21. skip_on_failure: bool = False,
  22. ):
  23. self.func = func
  24. self.pre = pre
  25. self.each_item = each_item
  26. self.always = always
  27. self.check_fields = check_fields
  28. self.skip_on_failure = skip_on_failure
  29. if TYPE_CHECKING:
  30. from inspect import Signature
  31. from .config import BaseConfig
  32. from .fields import ModelField
  33. from .types import ModelOrDc
  34. ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
  35. ValidatorsList = List[ValidatorCallable]
  36. ValidatorListDict = Dict[str, List[Validator]]
  37. _FUNCS: Set[str] = set()
  38. VALIDATOR_CONFIG_KEY = '__validator_config__'
  39. ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
  40. def validator(
  41. *fields: str,
  42. pre: bool = False,
  43. each_item: bool = False,
  44. always: bool = False,
  45. check_fields: bool = True,
  46. whole: bool = None,
  47. allow_reuse: bool = False,
  48. ) -> Callable[[AnyCallable], 'AnyClassMethod']:
  49. """
  50. Decorate methods on the class indicating that they should be used to validate fields
  51. :param fields: which field(s) the method should be called on
  52. :param pre: whether or not this validator should be called before the standard validators (else after)
  53. :param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
  54. whole object
  55. :param always: whether this method and other validators should be called even if the value is missing
  56. :param check_fields: whether to check that the fields actually exist on the model
  57. :param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
  58. """
  59. if not fields:
  60. raise ConfigError('validator with no fields specified')
  61. elif isinstance(fields[0], FunctionType):
  62. raise ConfigError(
  63. "validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
  64. "E.g. usage should be `@validator('<field_name>', ...)`"
  65. )
  66. if whole is not None:
  67. warnings.warn(
  68. 'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
  69. DeprecationWarning,
  70. )
  71. assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
  72. each_item = not whole
  73. def dec(f: AnyCallable) -> 'AnyClassMethod':
  74. f_cls = _prepare_validator(f, allow_reuse)
  75. setattr(
  76. f_cls,
  77. VALIDATOR_CONFIG_KEY,
  78. (
  79. fields,
  80. Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
  81. ),
  82. )
  83. return f_cls
  84. return dec
  85. @overload
  86. def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
  87. ...
  88. @overload
  89. def root_validator(
  90. *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
  91. ) -> Callable[[AnyCallable], 'AnyClassMethod']:
  92. ...
  93. def root_validator(
  94. _func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
  95. ) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
  96. """
  97. Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
  98. before or after standard model parsing/validation is performed.
  99. """
  100. if _func:
  101. f_cls = _prepare_validator(_func, allow_reuse)
  102. setattr(
  103. f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
  104. )
  105. return f_cls
  106. def dec(f: AnyCallable) -> 'AnyClassMethod':
  107. f_cls = _prepare_validator(f, allow_reuse)
  108. setattr(
  109. f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
  110. )
  111. return f_cls
  112. return dec
  113. def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
  114. """
  115. Avoid validators with duplicated names since without this, validators can be overwritten silently
  116. which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False.
  117. """
  118. f_cls = function if isinstance(function, classmethod) else classmethod(function)
  119. if not in_ipython() and not allow_reuse:
  120. ref = f_cls.__func__.__module__ + '.' + f_cls.__func__.__qualname__
  121. if ref in _FUNCS:
  122. raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
  123. _FUNCS.add(ref)
  124. return f_cls
  125. class ValidatorGroup:
  126. def __init__(self, validators: 'ValidatorListDict') -> None:
  127. self.validators = validators
  128. self.used_validators = {'*'}
  129. def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
  130. self.used_validators.add(name)
  131. validators = self.validators.get(name, [])
  132. if name != ROOT_KEY:
  133. validators += self.validators.get('*', [])
  134. if validators:
  135. return {v.func.__name__: v for v in validators}
  136. else:
  137. return None
  138. def check_for_unused(self) -> None:
  139. unused_validators = set(
  140. chain.from_iterable(
  141. (v.func.__name__ for v in self.validators[f] if v.check_fields)
  142. for f in (self.validators.keys() - self.used_validators)
  143. )
  144. )
  145. if unused_validators:
  146. fn = ', '.join(unused_validators)
  147. raise ConfigError(
  148. f"Validators defined with incorrect fields: {fn} " # noqa: Q000
  149. f"(use check_fields=False if you're inheriting from the model and intended this)"
  150. )
  151. def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
  152. validators: Dict[str, List[Validator]] = {}
  153. for var_name, value in namespace.items():
  154. validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
  155. if validator_config:
  156. fields, v = validator_config
  157. for field in fields:
  158. if field in validators:
  159. validators[field].append(v)
  160. else:
  161. validators[field] = [v]
  162. return validators
  163. def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
  164. from inspect import signature
  165. pre_validators: List[AnyCallable] = []
  166. post_validators: List[Tuple[bool, AnyCallable]] = []
  167. for name, value in namespace.items():
  168. validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
  169. if validator_config:
  170. sig = signature(validator_config.func)
  171. args = list(sig.parameters.keys())
  172. if args[0] == 'self':
  173. raise ConfigError(
  174. f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
  175. f'should be: (cls, values).'
  176. )
  177. if len(args) != 2:
  178. raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
  179. # check function signature
  180. if validator_config.pre:
  181. pre_validators.append(validator_config.func)
  182. else:
  183. post_validators.append((validator_config.skip_on_failure, validator_config.func))
  184. return pre_validators, post_validators
  185. def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
  186. for field, field_validators in base_validators.items():
  187. if field not in validators:
  188. validators[field] = []
  189. validators[field] += field_validators
  190. return validators
  191. def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
  192. """
  193. Make a generic function which calls a validator with the right arguments.
  194. Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow,
  195. hence this laborious way of doing things.
  196. It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
  197. the arguments "values", "fields" and/or "config" are permitted.
  198. """
  199. from inspect import signature
  200. sig = signature(validator)
  201. args = list(sig.parameters.keys())
  202. first_arg = args.pop(0)
  203. if first_arg == 'self':
  204. raise ConfigError(
  205. f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
  206. f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
  207. )
  208. elif first_arg == 'cls':
  209. # assume the second argument is value
  210. return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
  211. else:
  212. # assume the first argument was value which has already been removed
  213. return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
  214. def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
  215. return [make_generic_validator(f) for f in v_funcs if f]
  216. all_kwargs = {'values', 'field', 'config'}
  217. def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
  218. # assume the first argument is value
  219. has_kwargs = False
  220. if 'kwargs' in args:
  221. has_kwargs = True
  222. args -= {'kwargs'}
  223. if not args.issubset(all_kwargs):
  224. raise ConfigError(
  225. f'Invalid signature for validator {validator}: {sig}, should be: '
  226. f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
  227. )
  228. if has_kwargs:
  229. return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
  230. elif args == set():
  231. return lambda cls, v, values, field, config: validator(cls, v)
  232. elif args == {'values'}:
  233. return lambda cls, v, values, field, config: validator(cls, v, values=values)
  234. elif args == {'field'}:
  235. return lambda cls, v, values, field, config: validator(cls, v, field=field)
  236. elif args == {'config'}:
  237. return lambda cls, v, values, field, config: validator(cls, v, config=config)
  238. elif args == {'values', 'field'}:
  239. return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
  240. elif args == {'values', 'config'}:
  241. return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
  242. elif args == {'field', 'config'}:
  243. return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
  244. else:
  245. # args == {'values', 'field', 'config'}
  246. return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
  247. def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
  248. has_kwargs = False
  249. if 'kwargs' in args:
  250. has_kwargs = True
  251. args -= {'kwargs'}
  252. if not args.issubset(all_kwargs):
  253. raise ConfigError(
  254. f'Invalid signature for validator {validator}: {sig}, should be: '
  255. f'(value, values, config, field), "values", "config" and "field" are all optional.'
  256. )
  257. if has_kwargs:
  258. return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
  259. elif args == set():
  260. return lambda cls, v, values, field, config: validator(v)
  261. elif args == {'values'}:
  262. return lambda cls, v, values, field, config: validator(v, values=values)
  263. elif args == {'field'}:
  264. return lambda cls, v, values, field, config: validator(v, field=field)
  265. elif args == {'config'}:
  266. return lambda cls, v, values, field, config: validator(v, config=config)
  267. elif args == {'values', 'field'}:
  268. return lambda cls, v, values, field, config: validator(v, values=values, field=field)
  269. elif args == {'values', 'config'}:
  270. return lambda cls, v, values, field, config: validator(v, values=values, config=config)
  271. elif args == {'field', 'config'}:
  272. return lambda cls, v, values, field, config: validator(v, field=field, config=config)
  273. else:
  274. # args == {'values', 'field', 'config'}
  275. return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
  276. def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
  277. all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__])
  278. return {
  279. k: v
  280. for k, v in all_attributes.items()
  281. if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
  282. }