You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

770 lines
27 KiB

  1. import dataclasses
  2. import inspect
  3. from contextlib import contextmanager
  4. from copy import deepcopy
  5. from typing import (
  6. Any,
  7. Callable,
  8. Coroutine,
  9. Dict,
  10. List,
  11. Mapping,
  12. Optional,
  13. Sequence,
  14. Tuple,
  15. Type,
  16. Union,
  17. cast,
  18. )
  19. import anyio
  20. from fastapi import params
  21. from fastapi.concurrency import (
  22. AsyncExitStack,
  23. asynccontextmanager,
  24. contextmanager_in_threadpool,
  25. )
  26. from fastapi.dependencies.models import Dependant, SecurityRequirement
  27. from fastapi.logger import logger
  28. from fastapi.security.base import SecurityBase
  29. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  30. from fastapi.security.open_id_connect_url import OpenIdConnect
  31. from fastapi.utils import create_response_field, get_path_param_names
  32. from pydantic import BaseModel, create_model
  33. from pydantic.error_wrappers import ErrorWrapper
  34. from pydantic.errors import MissingError
  35. from pydantic.fields import (
  36. SHAPE_LIST,
  37. SHAPE_SEQUENCE,
  38. SHAPE_SET,
  39. SHAPE_SINGLETON,
  40. SHAPE_TUPLE,
  41. SHAPE_TUPLE_ELLIPSIS,
  42. FieldInfo,
  43. ModelField,
  44. Required,
  45. )
  46. from pydantic.schema import get_annotation_from_field_info
  47. from pydantic.typing import ForwardRef, evaluate_forwardref
  48. from pydantic.utils import lenient_issubclass
  49. from starlette.background import BackgroundTasks
  50. from starlette.concurrency import run_in_threadpool
  51. from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
  52. from starlette.requests import HTTPConnection, Request
  53. from starlette.responses import Response
  54. from starlette.websockets import WebSocket
  55. sequence_shapes = {
  56. SHAPE_LIST,
  57. SHAPE_SET,
  58. SHAPE_TUPLE,
  59. SHAPE_SEQUENCE,
  60. SHAPE_TUPLE_ELLIPSIS,
  61. }
  62. sequence_types = (list, set, tuple)
  63. sequence_shape_to_type = {
  64. SHAPE_LIST: list,
  65. SHAPE_SET: set,
  66. SHAPE_TUPLE: tuple,
  67. SHAPE_SEQUENCE: list,
  68. SHAPE_TUPLE_ELLIPSIS: list,
  69. }
  70. multipart_not_installed_error = (
  71. 'Form data requires "python-multipart" to be installed. \n'
  72. 'You can install "python-multipart" with: \n\n'
  73. "pip install python-multipart\n"
  74. )
  75. multipart_incorrect_install_error = (
  76. 'Form data requires "python-multipart" to be installed. '
  77. 'It seems you installed "multipart" instead. \n'
  78. 'You can remove "multipart" with: \n\n'
  79. "pip uninstall multipart\n\n"
  80. 'And then install "python-multipart" with: \n\n'
  81. "pip install python-multipart\n"
  82. )
  83. def check_file_field(field: ModelField) -> None:
  84. field_info = field.field_info
  85. if isinstance(field_info, params.Form):
  86. try:
  87. # __version__ is available in both multiparts, and can be mocked
  88. from multipart import __version__ # type: ignore
  89. assert __version__
  90. try:
  91. # parse_options_header is only available in the right multipart
  92. from multipart.multipart import parse_options_header # type: ignore
  93. assert parse_options_header
  94. except ImportError:
  95. logger.error(multipart_incorrect_install_error)
  96. raise RuntimeError(multipart_incorrect_install_error)
  97. except ImportError:
  98. logger.error(multipart_not_installed_error)
  99. raise RuntimeError(multipart_not_installed_error)
  100. def get_param_sub_dependant(
  101. *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
  102. ) -> Dependant:
  103. depends: params.Depends = param.default
  104. if depends.dependency:
  105. dependency = depends.dependency
  106. else:
  107. dependency = param.annotation
  108. return get_sub_dependant(
  109. depends=depends,
  110. dependency=dependency,
  111. path=path,
  112. name=param.name,
  113. security_scopes=security_scopes,
  114. )
  115. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  116. assert callable(
  117. depends.dependency
  118. ), "A parameter-less dependency must have a callable dependency"
  119. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  120. def get_sub_dependant(
  121. *,
  122. depends: params.Depends,
  123. dependency: Callable[..., Any],
  124. path: str,
  125. name: Optional[str] = None,
  126. security_scopes: Optional[List[str]] = None,
  127. ) -> Dependant:
  128. security_requirement = None
  129. security_scopes = security_scopes or []
  130. if isinstance(depends, params.Security):
  131. dependency_scopes = depends.scopes
  132. security_scopes.extend(dependency_scopes)
  133. if isinstance(dependency, SecurityBase):
  134. use_scopes: List[str] = []
  135. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  136. use_scopes = security_scopes
  137. security_requirement = SecurityRequirement(
  138. security_scheme=dependency, scopes=use_scopes
  139. )
  140. sub_dependant = get_dependant(
  141. path=path,
  142. call=dependency,
  143. name=name,
  144. security_scopes=security_scopes,
  145. use_cache=depends.use_cache,
  146. )
  147. if security_requirement:
  148. sub_dependant.security_requirements.append(security_requirement)
  149. sub_dependant.security_scopes = security_scopes
  150. return sub_dependant
  151. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  152. def get_flat_dependant(
  153. dependant: Dependant,
  154. *,
  155. skip_repeats: bool = False,
  156. visited: Optional[List[CacheKey]] = None,
  157. ) -> Dependant:
  158. if visited is None:
  159. visited = []
  160. visited.append(dependant.cache_key)
  161. flat_dependant = Dependant(
  162. path_params=dependant.path_params.copy(),
  163. query_params=dependant.query_params.copy(),
  164. header_params=dependant.header_params.copy(),
  165. cookie_params=dependant.cookie_params.copy(),
  166. body_params=dependant.body_params.copy(),
  167. security_schemes=dependant.security_requirements.copy(),
  168. use_cache=dependant.use_cache,
  169. path=dependant.path,
  170. )
  171. for sub_dependant in dependant.dependencies:
  172. if skip_repeats and sub_dependant.cache_key in visited:
  173. continue
  174. flat_sub = get_flat_dependant(
  175. sub_dependant, skip_repeats=skip_repeats, visited=visited
  176. )
  177. flat_dependant.path_params.extend(flat_sub.path_params)
  178. flat_dependant.query_params.extend(flat_sub.query_params)
  179. flat_dependant.header_params.extend(flat_sub.header_params)
  180. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  181. flat_dependant.body_params.extend(flat_sub.body_params)
  182. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  183. return flat_dependant
  184. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  185. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  186. return (
  187. flat_dependant.path_params
  188. + flat_dependant.query_params
  189. + flat_dependant.header_params
  190. + flat_dependant.cookie_params
  191. )
  192. def is_scalar_field(field: ModelField) -> bool:
  193. field_info = field.field_info
  194. if not (
  195. field.shape == SHAPE_SINGLETON
  196. and not lenient_issubclass(field.type_, BaseModel)
  197. and not lenient_issubclass(field.type_, sequence_types + (dict,))
  198. and not dataclasses.is_dataclass(field.type_)
  199. and not isinstance(field_info, params.Body)
  200. ):
  201. return False
  202. if field.sub_fields:
  203. if not all(is_scalar_field(f) for f in field.sub_fields):
  204. return False
  205. return True
  206. def is_scalar_sequence_field(field: ModelField) -> bool:
  207. if (field.shape in sequence_shapes) and not lenient_issubclass(
  208. field.type_, BaseModel
  209. ):
  210. if field.sub_fields is not None:
  211. for sub_field in field.sub_fields:
  212. if not is_scalar_field(sub_field):
  213. return False
  214. return True
  215. if lenient_issubclass(field.type_, sequence_types):
  216. return True
  217. return False
  218. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  219. signature = inspect.signature(call)
  220. globalns = getattr(call, "__globals__", {})
  221. typed_params = [
  222. inspect.Parameter(
  223. name=param.name,
  224. kind=param.kind,
  225. default=param.default,
  226. annotation=get_typed_annotation(param, globalns),
  227. )
  228. for param in signature.parameters.values()
  229. ]
  230. typed_signature = inspect.Signature(typed_params)
  231. return typed_signature
  232. def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
  233. annotation = param.annotation
  234. if isinstance(annotation, str):
  235. annotation = ForwardRef(annotation)
  236. annotation = evaluate_forwardref(annotation, globalns, globalns)
  237. return annotation
  238. def get_dependant(
  239. *,
  240. path: str,
  241. call: Callable[..., Any],
  242. name: Optional[str] = None,
  243. security_scopes: Optional[List[str]] = None,
  244. use_cache: bool = True,
  245. ) -> Dependant:
  246. path_param_names = get_path_param_names(path)
  247. endpoint_signature = get_typed_signature(call)
  248. signature_params = endpoint_signature.parameters
  249. dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
  250. for param_name, param in signature_params.items():
  251. if isinstance(param.default, params.Depends):
  252. sub_dependant = get_param_sub_dependant(
  253. param=param, path=path, security_scopes=security_scopes
  254. )
  255. dependant.dependencies.append(sub_dependant)
  256. continue
  257. if add_non_field_param_to_dependency(param=param, dependant=dependant):
  258. continue
  259. param_field = get_param_field(
  260. param=param, default_field_info=params.Query, param_name=param_name
  261. )
  262. if param_name in path_param_names:
  263. assert is_scalar_field(
  264. field=param_field
  265. ), "Path params must be of one of the supported types"
  266. if isinstance(param.default, params.Path):
  267. ignore_default = False
  268. else:
  269. ignore_default = True
  270. param_field = get_param_field(
  271. param=param,
  272. param_name=param_name,
  273. default_field_info=params.Path,
  274. force_type=params.ParamTypes.path,
  275. ignore_default=ignore_default,
  276. )
  277. add_param_to_fields(field=param_field, dependant=dependant)
  278. elif is_scalar_field(field=param_field):
  279. add_param_to_fields(field=param_field, dependant=dependant)
  280. elif isinstance(
  281. param.default, (params.Query, params.Header)
  282. ) and is_scalar_sequence_field(param_field):
  283. add_param_to_fields(field=param_field, dependant=dependant)
  284. else:
  285. field_info = param_field.field_info
  286. assert isinstance(
  287. field_info, params.Body
  288. ), f"Param: {param_field.name} can only be a request body, using Body(...)"
  289. dependant.body_params.append(param_field)
  290. return dependant
  291. def add_non_field_param_to_dependency(
  292. *, param: inspect.Parameter, dependant: Dependant
  293. ) -> Optional[bool]:
  294. if lenient_issubclass(param.annotation, Request):
  295. dependant.request_param_name = param.name
  296. return True
  297. elif lenient_issubclass(param.annotation, WebSocket):
  298. dependant.websocket_param_name = param.name
  299. return True
  300. elif lenient_issubclass(param.annotation, HTTPConnection):
  301. dependant.http_connection_param_name = param.name
  302. return True
  303. elif lenient_issubclass(param.annotation, Response):
  304. dependant.response_param_name = param.name
  305. return True
  306. elif lenient_issubclass(param.annotation, BackgroundTasks):
  307. dependant.background_tasks_param_name = param.name
  308. return True
  309. elif lenient_issubclass(param.annotation, SecurityScopes):
  310. dependant.security_scopes_param_name = param.name
  311. return True
  312. return None
  313. def get_param_field(
  314. *,
  315. param: inspect.Parameter,
  316. param_name: str,
  317. default_field_info: Type[params.Param] = params.Param,
  318. force_type: Optional[params.ParamTypes] = None,
  319. ignore_default: bool = False,
  320. ) -> ModelField:
  321. default_value = Required
  322. had_schema = False
  323. if not param.default == param.empty and ignore_default is False:
  324. default_value = param.default
  325. if isinstance(default_value, FieldInfo):
  326. had_schema = True
  327. field_info = default_value
  328. default_value = field_info.default
  329. if (
  330. isinstance(field_info, params.Param)
  331. and getattr(field_info, "in_", None) is None
  332. ):
  333. field_info.in_ = default_field_info.in_
  334. if force_type:
  335. field_info.in_ = force_type # type: ignore
  336. else:
  337. field_info = default_field_info(default_value)
  338. required = default_value == Required
  339. annotation: Any = Any
  340. if not param.annotation == param.empty:
  341. annotation = param.annotation
  342. annotation = get_annotation_from_field_info(annotation, field_info, param_name)
  343. if not field_info.alias and getattr(field_info, "convert_underscores", None):
  344. alias = param.name.replace("_", "-")
  345. else:
  346. alias = field_info.alias or param.name
  347. field = create_response_field(
  348. name=param.name,
  349. type_=annotation,
  350. default=None if required else default_value,
  351. alias=alias,
  352. required=required,
  353. field_info=field_info,
  354. )
  355. field.required = required
  356. if not had_schema and not is_scalar_field(field=field):
  357. field.field_info = params.Body(field_info.default)
  358. return field
  359. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  360. field_info = cast(params.Param, field.field_info)
  361. if field_info.in_ == params.ParamTypes.path:
  362. dependant.path_params.append(field)
  363. elif field_info.in_ == params.ParamTypes.query:
  364. dependant.query_params.append(field)
  365. elif field_info.in_ == params.ParamTypes.header:
  366. dependant.header_params.append(field)
  367. else:
  368. assert (
  369. field_info.in_ == params.ParamTypes.cookie
  370. ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
  371. dependant.cookie_params.append(field)
  372. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  373. if inspect.isroutine(call):
  374. return inspect.iscoroutinefunction(call)
  375. if inspect.isclass(call):
  376. return False
  377. call = getattr(call, "__call__", None)
  378. return inspect.iscoroutinefunction(call)
  379. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  380. if inspect.isasyncgenfunction(call):
  381. return True
  382. call = getattr(call, "__call__", None)
  383. return inspect.isasyncgenfunction(call)
  384. def is_gen_callable(call: Callable[..., Any]) -> bool:
  385. if inspect.isgeneratorfunction(call):
  386. return True
  387. call = getattr(call, "__call__", None)
  388. return inspect.isgeneratorfunction(call)
  389. async def solve_generator(
  390. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  391. ) -> Any:
  392. if is_gen_callable(call):
  393. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  394. elif is_async_gen_callable(call):
  395. cm = asynccontextmanager(call)(**sub_values)
  396. return await stack.enter_async_context(cm)
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
) -> Tuple[
    Dict[str, Any],
    List[ErrorWrapper],
    Optional[BackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any],
]:
    """Resolve every value needed to call ``dependant.call`` for one connection.

    Processing happens in three stages:

    1. Sub-dependencies are solved recursively.
    2. Path, query, header, cookie and body parameters are extracted and
       validated.
    3. Framework-provided objects are injected: the connection, the
       request/websocket, background tasks, the response and security scopes.

    Args:
        request: The incoming HTTP request or WebSocket connection.
        dependant: The dependency tree to solve.
        body: Parsed request body (dict or form data), if any.
        background_tasks: Shared ``BackgroundTasks``. Created lazily if a
            dependant asks for one and none exists yet.
        response: ``Response`` shared by the whole tree and injected into
            parameters annotated with ``Response``. A blank one is created
            when not given.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping replaces original callables, or ``None``.
        dependency_cache: Values already solved for this request, keyed by
            ``cache_key``.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
        ``values`` maps parameter names to resolved values and ``errors``
        collects validation/missing-value errors. A sub-dependency whose own
        inputs had errors is not called.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    response = response or Response(
        content=None,
        status_code=None,  # type: ignore
        headers=None,  # type: ignore # in Starlette
        media_type=None,  # type: ignore # in Starlette
        background=None,  # type: ignore # in Starlette
    )
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # These casts only narrow Optional types for the type checker; the
        # values written back are the ones already stored.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            # Re-analyse the (possibly overridden) callable so that its own
            # parameters are the ones solved below.
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
            use_sub_dependant.security_scopes = sub_dependant.security_scopes
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue
        # The cache is keyed by the original sub-dependant, even when its
        # call was overridden above.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            # Generator dependencies are entered on the exit stack stored in
            # the request scope under "fastapi_astack".
            stack = request.scope.get("fastapi_astack")
            assert isinstance(stack, AsyncExitStack)
            solved = await solve_generator(
                call=call, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            # Plain sync callables are run via run_in_threadpool.
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
  526. def request_params_to_args(
  527. required_params: Sequence[ModelField],
  528. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  529. ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
  530. values = {}
  531. errors = []
  532. for field in required_params:
  533. if is_scalar_sequence_field(field) and isinstance(
  534. received_params, (QueryParams, Headers)
  535. ):
  536. value = received_params.getlist(field.alias) or field.default
  537. else:
  538. value = received_params.get(field.alias)
  539. field_info = field.field_info
  540. assert isinstance(
  541. field_info, params.Param
  542. ), "Params must be subclasses of Param"
  543. if value is None:
  544. if field.required:
  545. errors.append(
  546. ErrorWrapper(
  547. MissingError(), loc=(field_info.in_.value, field.alias)
  548. )
  549. )
  550. else:
  551. values[field.name] = deepcopy(field.default)
  552. continue
  553. v_, errors_ = field.validate(
  554. value, values, loc=(field_info.in_.value, field.alias)
  555. )
  556. if isinstance(errors_, ErrorWrapper):
  557. errors.append(errors_)
  558. elif isinstance(errors_, list):
  559. errors.extend(errors_)
  560. else:
  561. values[field.name] = v_
  562. return values, errors
  563. async def request_body_to_args(
  564. required_params: List[ModelField],
  565. received_body: Optional[Union[Dict[str, Any], FormData]],
  566. ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
  567. values = {}
  568. errors = []
  569. if required_params:
  570. field = required_params[0]
  571. field_info = field.field_info
  572. embed = getattr(field_info, "embed", None)
  573. field_alias_omitted = len(required_params) == 1 and not embed
  574. if field_alias_omitted:
  575. received_body = {field.alias: received_body}
  576. for field in required_params:
  577. loc: Tuple[str, ...]
  578. if field_alias_omitted:
  579. loc = ("body",)
  580. else:
  581. loc = ("body", field.alias)
  582. value: Optional[Any] = None
  583. if received_body is not None:
  584. if (
  585. field.shape in sequence_shapes or field.type_ in sequence_types
  586. ) and isinstance(received_body, FormData):
  587. value = received_body.getlist(field.alias)
  588. else:
  589. try:
  590. value = received_body.get(field.alias)
  591. except AttributeError:
  592. errors.append(get_missing_field_error(loc))
  593. continue
  594. if (
  595. value is None
  596. or (isinstance(field_info, params.Form) and value == "")
  597. or (
  598. isinstance(field_info, params.Form)
  599. and field.shape in sequence_shapes
  600. and len(value) == 0
  601. )
  602. ):
  603. if field.required:
  604. errors.append(get_missing_field_error(loc))
  605. else:
  606. values[field.name] = deepcopy(field.default)
  607. continue
  608. if (
  609. isinstance(field_info, params.File)
  610. and lenient_issubclass(field.type_, bytes)
  611. and isinstance(value, UploadFile)
  612. ):
  613. value = await value.read()
  614. elif (
  615. field.shape in sequence_shapes
  616. and isinstance(field_info, params.File)
  617. and lenient_issubclass(field.type_, bytes)
  618. and isinstance(value, sequence_types)
  619. ):
  620. results: List[Union[bytes, str]] = []
  621. async def process_fn(
  622. fn: Callable[[], Coroutine[Any, Any, Any]]
  623. ) -> None:
  624. result = await fn()
  625. results.append(result)
  626. async with anyio.create_task_group() as tg:
  627. for sub_value in value:
  628. tg.start_soon(process_fn, sub_value.read)
  629. value = sequence_shape_to_type[field.shape](results)
  630. v_, errors_ = field.validate(value, values, loc=loc)
  631. if isinstance(errors_, ErrorWrapper):
  632. errors.append(errors_)
  633. elif isinstance(errors_, list):
  634. errors.extend(errors_)
  635. else:
  636. values[field.name] = v_
  637. return values, errors
  638. def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
  639. missing_field_error = ErrorWrapper(MissingError(), loc=loc)
  640. return missing_field_error
  641. def get_schema_compatible_field(*, field: ModelField) -> ModelField:
  642. out_field = field
  643. if lenient_issubclass(field.type_, UploadFile):
  644. use_type: type = bytes
  645. if field.shape in sequence_shapes:
  646. use_type = List[bytes]
  647. out_field = create_response_field(
  648. name=field.name,
  649. type_=use_type,
  650. class_validators=field.class_validators,
  651. model_config=field.model_config,
  652. default=field.default,
  653. required=field.required,
  654. alias=field.alias,
  655. field_info=field.field_info,
  656. )
  657. return out_field
  658. def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
  659. flat_dependant = get_flat_dependant(dependant)
  660. if not flat_dependant.body_params:
  661. return None
  662. first_param = flat_dependant.body_params[0]
  663. field_info = first_param.field_info
  664. embed = getattr(field_info, "embed", None)
  665. body_param_names_set = {param.name for param in flat_dependant.body_params}
  666. if len(body_param_names_set) == 1 and not embed:
  667. final_field = get_schema_compatible_field(field=first_param)
  668. check_file_field(final_field)
  669. return final_field
  670. # If one field requires to embed, all have to be embedded
  671. # in case a sub-dependency is evaluated with a single unique body field
  672. # That is combined (embedded) with other body fields
  673. for param in flat_dependant.body_params:
  674. setattr(param.field_info, "embed", True)
  675. model_name = "Body_" + name
  676. BodyModel: Type[BaseModel] = create_model(model_name)
  677. for f in flat_dependant.body_params:
  678. BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f)
  679. required = any(True for f in flat_dependant.body_params if f.required)
  680. BodyFieldInfo_kwargs: Dict[str, Any] = dict(default=None)
  681. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  682. BodyFieldInfo: Type[params.Body] = params.File
  683. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  684. BodyFieldInfo = params.Form
  685. else:
  686. BodyFieldInfo = params.Body
  687. body_param_media_types = [
  688. getattr(f.field_info, "media_type")
  689. for f in flat_dependant.body_params
  690. if isinstance(f.field_info, params.Body)
  691. ]
  692. if len(set(body_param_media_types)) == 1:
  693. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  694. final_field = create_response_field(
  695. name="body",
  696. type_=BodyModel,
  697. required=required,
  698. alias="body",
  699. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  700. )
  701. check_file_field(final_field)
  702. return final_field