Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

770 řádků
27 KiB

  1. import dataclasses
  2. import inspect
  3. from contextlib import contextmanager
  4. from copy import deepcopy
  5. from typing import (
  6. Any,
  7. Callable,
  8. Coroutine,
  9. Dict,
  10. List,
  11. Mapping,
  12. Optional,
  13. Sequence,
  14. Tuple,
  15. Type,
  16. Union,
  17. cast,
  18. )
  19. import anyio
  20. from fastapi import params
  21. from fastapi.concurrency import (
  22. AsyncExitStack,
  23. asynccontextmanager,
  24. contextmanager_in_threadpool,
  25. )
  26. from fastapi.dependencies.models import Dependant, SecurityRequirement
  27. from fastapi.logger import logger
  28. from fastapi.security.base import SecurityBase
  29. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  30. from fastapi.security.open_id_connect_url import OpenIdConnect
  31. from fastapi.utils import create_response_field, get_path_param_names
  32. from pydantic import BaseModel, create_model
  33. from pydantic.error_wrappers import ErrorWrapper
  34. from pydantic.errors import MissingError
  35. from pydantic.fields import (
  36. SHAPE_LIST,
  37. SHAPE_SEQUENCE,
  38. SHAPE_SET,
  39. SHAPE_SINGLETON,
  40. SHAPE_TUPLE,
  41. SHAPE_TUPLE_ELLIPSIS,
  42. FieldInfo,
  43. ModelField,
  44. Required,
  45. )
  46. from pydantic.schema import get_annotation_from_field_info
  47. from pydantic.typing import ForwardRef, evaluate_forwardref
  48. from pydantic.utils import lenient_issubclass
  49. from starlette.background import BackgroundTasks
  50. from starlette.concurrency import run_in_threadpool
  51. from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
  52. from starlette.requests import HTTPConnection, Request
  53. from starlette.responses import Response
  54. from starlette.websockets import WebSocket
  55. sequence_shapes = {
  56. SHAPE_LIST,
  57. SHAPE_SET,
  58. SHAPE_TUPLE,
  59. SHAPE_SEQUENCE,
  60. SHAPE_TUPLE_ELLIPSIS,
  61. }
  62. sequence_types = (list, set, tuple)
  63. sequence_shape_to_type = {
  64. SHAPE_LIST: list,
  65. SHAPE_SET: set,
  66. SHAPE_TUPLE: tuple,
  67. SHAPE_SEQUENCE: list,
  68. SHAPE_TUPLE_ELLIPSIS: list,
  69. }
  70. multipart_not_installed_error = (
  71. 'Form data requires "python-multipart" to be installed. \n'
  72. 'You can install "python-multipart" with: \n\n'
  73. "pip install python-multipart\n"
  74. )
  75. multipart_incorrect_install_error = (
  76. 'Form data requires "python-multipart" to be installed. '
  77. 'It seems you installed "multipart" instead. \n'
  78. 'You can remove "multipart" with: \n\n'
  79. "pip uninstall multipart\n\n"
  80. 'And then install "python-multipart" with: \n\n'
  81. "pip install python-multipart\n"
  82. )
  83. def check_file_field(field: ModelField) -> None:
  84. field_info = field.field_info
  85. if isinstance(field_info, params.Form):
  86. try:
  87. # __version__ is available in both multiparts, and can be mocked
  88. from multipart import __version__ # type: ignore
  89. assert __version__
  90. try:
  91. # parse_options_header is only available in the right multipart
  92. from multipart.multipart import parse_options_header # type: ignore
  93. assert parse_options_header
  94. except ImportError:
  95. logger.error(multipart_incorrect_install_error)
  96. raise RuntimeError(multipart_incorrect_install_error)
  97. except ImportError:
  98. logger.error(multipart_not_installed_error)
  99. raise RuntimeError(multipart_not_installed_error)
  100. def get_param_sub_dependant(
  101. *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
  102. ) -> Dependant:
  103. depends: params.Depends = param.default
  104. if depends.dependency:
  105. dependency = depends.dependency
  106. else:
  107. dependency = param.annotation
  108. return get_sub_dependant(
  109. depends=depends,
  110. dependency=dependency,
  111. path=path,
  112. name=param.name,
  113. security_scopes=security_scopes,
  114. )
  115. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  116. assert callable(
  117. depends.dependency
  118. ), "A parameter-less dependency must have a callable dependency"
  119. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  120. def get_sub_dependant(
  121. *,
  122. depends: params.Depends,
  123. dependency: Callable[..., Any],
  124. path: str,
  125. name: Optional[str] = None,
  126. security_scopes: Optional[List[str]] = None,
  127. ) -> Dependant:
  128. security_requirement = None
  129. security_scopes = security_scopes or []
  130. if isinstance(depends, params.Security):
  131. dependency_scopes = depends.scopes
  132. security_scopes.extend(dependency_scopes)
  133. if isinstance(dependency, SecurityBase):
  134. use_scopes: List[str] = []
  135. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  136. use_scopes = security_scopes
  137. security_requirement = SecurityRequirement(
  138. security_scheme=dependency, scopes=use_scopes
  139. )
  140. sub_dependant = get_dependant(
  141. path=path,
  142. call=dependency,
  143. name=name,
  144. security_scopes=security_scopes,
  145. use_cache=depends.use_cache,
  146. )
  147. if security_requirement:
  148. sub_dependant.security_requirements.append(security_requirement)
  149. sub_dependant.security_scopes = security_scopes
  150. return sub_dependant
  151. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  152. def get_flat_dependant(
  153. dependant: Dependant,
  154. *,
  155. skip_repeats: bool = False,
  156. visited: Optional[List[CacheKey]] = None,
  157. ) -> Dependant:
  158. if visited is None:
  159. visited = []
  160. visited.append(dependant.cache_key)
  161. flat_dependant = Dependant(
  162. path_params=dependant.path_params.copy(),
  163. query_params=dependant.query_params.copy(),
  164. header_params=dependant.header_params.copy(),
  165. cookie_params=dependant.cookie_params.copy(),
  166. body_params=dependant.body_params.copy(),
  167. security_schemes=dependant.security_requirements.copy(),
  168. use_cache=dependant.use_cache,
  169. path=dependant.path,
  170. )
  171. for sub_dependant in dependant.dependencies:
  172. if skip_repeats and sub_dependant.cache_key in visited:
  173. continue
  174. flat_sub = get_flat_dependant(
  175. sub_dependant, skip_repeats=skip_repeats, visited=visited
  176. )
  177. flat_dependant.path_params.extend(flat_sub.path_params)
  178. flat_dependant.query_params.extend(flat_sub.query_params)
  179. flat_dependant.header_params.extend(flat_sub.header_params)
  180. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  181. flat_dependant.body_params.extend(flat_sub.body_params)
  182. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  183. return flat_dependant
  184. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  185. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  186. return (
  187. flat_dependant.path_params
  188. + flat_dependant.query_params
  189. + flat_dependant.header_params
  190. + flat_dependant.cookie_params
  191. )
  192. def is_scalar_field(field: ModelField) -> bool:
  193. field_info = field.field_info
  194. if not (
  195. field.shape == SHAPE_SINGLETON
  196. and not lenient_issubclass(field.type_, BaseModel)
  197. and not lenient_issubclass(field.type_, sequence_types + (dict,))
  198. and not dataclasses.is_dataclass(field.type_)
  199. and not isinstance(field_info, params.Body)
  200. ):
  201. return False
  202. if field.sub_fields:
  203. if not all(is_scalar_field(f) for f in field.sub_fields):
  204. return False
  205. return True
  206. def is_scalar_sequence_field(field: ModelField) -> bool:
  207. if (field.shape in sequence_shapes) and not lenient_issubclass(
  208. field.type_, BaseModel
  209. ):
  210. if field.sub_fields is not None:
  211. for sub_field in field.sub_fields:
  212. if not is_scalar_field(sub_field):
  213. return False
  214. return True
  215. if lenient_issubclass(field.type_, sequence_types):
  216. return True
  217. return False
  218. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  219. signature = inspect.signature(call)
  220. globalns = getattr(call, "__globals__", {})
  221. typed_params = [
  222. inspect.Parameter(
  223. name=param.name,
  224. kind=param.kind,
  225. default=param.default,
  226. annotation=get_typed_annotation(param, globalns),
  227. )
  228. for param in signature.parameters.values()
  229. ]
  230. typed_signature = inspect.Signature(typed_params)
  231. return typed_signature
  232. def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
  233. annotation = param.annotation
  234. if isinstance(annotation, str):
  235. annotation = ForwardRef(annotation)
  236. annotation = evaluate_forwardref(annotation, globalns, globalns)
  237. return annotation
  238. def get_dependant(
  239. *,
  240. path: str,
  241. call: Callable[..., Any],
  242. name: Optional[str] = None,
  243. security_scopes: Optional[List[str]] = None,
  244. use_cache: bool = True,
  245. ) -> Dependant:
  246. path_param_names = get_path_param_names(path)
  247. endpoint_signature = get_typed_signature(call)
  248. signature_params = endpoint_signature.parameters
  249. dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
  250. for param_name, param in signature_params.items():
  251. if isinstance(param.default, params.Depends):
  252. sub_dependant = get_param_sub_dependant(
  253. param=param, path=path, security_scopes=security_scopes
  254. )
  255. dependant.dependencies.append(sub_dependant)
  256. continue
  257. if add_non_field_param_to_dependency(param=param, dependant=dependant):
  258. continue
  259. param_field = get_param_field(
  260. param=param, default_field_info=params.Query, param_name=param_name
  261. )
  262. if param_name in path_param_names:
  263. assert is_scalar_field(
  264. field=param_field
  265. ), "Path params must be of one of the supported types"
  266. if isinstance(param.default, params.Path):
  267. ignore_default = False
  268. else:
  269. ignore_default = True
  270. param_field = get_param_field(
  271. param=param,
  272. param_name=param_name,
  273. default_field_info=params.Path,
  274. force_type=params.ParamTypes.path,
  275. ignore_default=ignore_default,
  276. )
  277. add_param_to_fields(field=param_field, dependant=dependant)
  278. elif is_scalar_field(field=param_field):
  279. add_param_to_fields(field=param_field, dependant=dependant)
  280. elif isinstance(
  281. param.default, (params.Query, params.Header)
  282. ) and is_scalar_sequence_field(param_field):
  283. add_param_to_fields(field=param_field, dependant=dependant)
  284. else:
  285. field_info = param_field.field_info
  286. assert isinstance(
  287. field_info, params.Body
  288. ), f"Param: {param_field.name} can only be a request body, using Body(...)"
  289. dependant.body_params.append(param_field)
  290. return dependant
  291. def add_non_field_param_to_dependency(
  292. *, param: inspect.Parameter, dependant: Dependant
  293. ) -> Optional[bool]:
  294. if lenient_issubclass(param.annotation, Request):
  295. dependant.request_param_name = param.name
  296. return True
  297. elif lenient_issubclass(param.annotation, WebSocket):
  298. dependant.websocket_param_name = param.name
  299. return True
  300. elif lenient_issubclass(param.annotation, HTTPConnection):
  301. dependant.http_connection_param_name = param.name
  302. return True
  303. elif lenient_issubclass(param.annotation, Response):
  304. dependant.response_param_name = param.name
  305. return True
  306. elif lenient_issubclass(param.annotation, BackgroundTasks):
  307. dependant.background_tasks_param_name = param.name
  308. return True
  309. elif lenient_issubclass(param.annotation, SecurityScopes):
  310. dependant.security_scopes_param_name = param.name
  311. return True
  312. return None
  313. def get_param_field(
  314. *,
  315. param: inspect.Parameter,
  316. param_name: str,
  317. default_field_info: Type[params.Param] = params.Param,
  318. force_type: Optional[params.ParamTypes] = None,
  319. ignore_default: bool = False,
  320. ) -> ModelField:
  321. default_value = Required
  322. had_schema = False
  323. if not param.default == param.empty and ignore_default is False:
  324. default_value = param.default
  325. if isinstance(default_value, FieldInfo):
  326. had_schema = True
  327. field_info = default_value
  328. default_value = field_info.default
  329. if (
  330. isinstance(field_info, params.Param)
  331. and getattr(field_info, "in_", None) is None
  332. ):
  333. field_info.in_ = default_field_info.in_
  334. if force_type:
  335. field_info.in_ = force_type # type: ignore
  336. else:
  337. field_info = default_field_info(default_value)
  338. required = default_value == Required
  339. annotation: Any = Any
  340. if not param.annotation == param.empty:
  341. annotation = param.annotation
  342. annotation = get_annotation_from_field_info(annotation, field_info, param_name)
  343. if not field_info.alias and getattr(field_info, "convert_underscores", None):
  344. alias = param.name.replace("_", "-")
  345. else:
  346. alias = field_info.alias or param.name
  347. field = create_response_field(
  348. name=param.name,
  349. type_=annotation,
  350. default=None if required else default_value,
  351. alias=alias,
  352. required=required,
  353. field_info=field_info,
  354. )
  355. field.required = required
  356. if not had_schema and not is_scalar_field(field=field):
  357. field.field_info = params.Body(field_info.default)
  358. return field
  359. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  360. field_info = cast(params.Param, field.field_info)
  361. if field_info.in_ == params.ParamTypes.path:
  362. dependant.path_params.append(field)
  363. elif field_info.in_ == params.ParamTypes.query:
  364. dependant.query_params.append(field)
  365. elif field_info.in_ == params.ParamTypes.header:
  366. dependant.header_params.append(field)
  367. else:
  368. assert (
  369. field_info.in_ == params.ParamTypes.cookie
  370. ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
  371. dependant.cookie_params.append(field)
  372. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  373. if inspect.isroutine(call):
  374. return inspect.iscoroutinefunction(call)
  375. if inspect.isclass(call):
  376. return False
  377. call = getattr(call, "__call__", None)
  378. return inspect.iscoroutinefunction(call)
  379. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  380. if inspect.isasyncgenfunction(call):
  381. return True
  382. call = getattr(call, "__call__", None)
  383. return inspect.isasyncgenfunction(call)
  384. def is_gen_callable(call: Callable[..., Any]) -> bool:
  385. if inspect.isgeneratorfunction(call):
  386. return True
  387. call = getattr(call, "__call__", None)
  388. return inspect.isgeneratorfunction(call)
  389. async def solve_generator(
  390. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  391. ) -> Any:
  392. if is_gen_callable(call):
  393. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  394. elif is_async_gen_callable(call):
  395. cm = asynccontextmanager(call)(**sub_values)
  396. return await stack.enter_async_context(cm)
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
    Dict[str, Any],
    List[ErrorWrapper],
    Optional[BackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve every argument of *dependant* for the current request.

    Sub-dependencies are solved first, recursively and in declaration
    order, and their results are called or entered as needed. Then path,
    query, header, cookie and body parameters are extracted and validated,
    and finally the special Request/WebSocket/HTTPConnection/BackgroundTasks/
    Response/SecurityScopes parameters are injected.

    :param request: the incoming ``Request`` or ``WebSocket``.
    :param dependant: the dependant whose arguments are being solved.
    :param body: the already-parsed request body, if any.
    :param background_tasks: tasks object shared across the tree (created
        lazily when a dependant asks for ``BackgroundTasks``).
    :param response: response object shared with all sub-dependencies; a
        placeholder ``Response`` with ``status_code=None`` is created if none
        is passed.
    :param dependency_overrides_provider: object whose
        ``dependency_overrides`` mapping replaces dependency callables.
    :param dependency_cache: results already computed during this request,
        keyed by the (original) sub-dependant's ``cache_key``.
    :returns: ``(values, errors, background_tasks, response,
        dependency_cache)``. ``values`` maps parameter names to resolved
        values; ``errors`` collects validation errors.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    response = response or Response(
        content=None,
        status_code=None,  # type: ignore
        headers=None,  # type: ignore # in Starlette
        media_type=None,  # type: ignore # in Starlette
        background=None,  # type: ignore # in Starlette
    )
    # NOTE: an empty dict passed by the caller is replaced by a new one; the
    # caller sees the updated cache only through the returned value.
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # The casts only narrow types for the type checker.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # With any override registered, every sub-dependant is re-analysed
            # on each call, so an overriding callable's own parameters are used.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
            use_sub_dependant.security_scopes = sub_dependant.security_scopes

        # The sub-dependant's own arguments are solved even when its result
        # is later taken from the cache below.
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            # Do not call a dependency whose arguments failed validation.
            errors.extend(sub_errors)
            continue
        # The cache is keyed by the ORIGINAL sub-dependant, even when its
        # callable was overridden above.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            # Generator dependencies are entered on the AsyncExitStack that
            # must be stored in the request scope under "fastapi_astack".
            stack = request.scope.get("fastapi_astack")
            assert isinstance(stack, AsyncExitStack)
            solved = await solve_generator(
                call=call, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            # Plain sync callables run in the threadpool.
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    # Update order matters: on a name clash, later sources overwrite earlier
    # ones (sub-dependency results < path < query < header < cookie < body
    # < special parameters below).
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
  526. def request_params_to_args(
  527. required_params: Sequence[ModelField],
  528. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  529. ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
  530. values = {}
  531. errors = []
  532. for field in required_params:
  533. if is_scalar_sequence_field(field) and isinstance(
  534. received_params, (QueryParams, Headers)
  535. ):
  536. value = received_params.getlist(field.alias) or field.default
  537. else:
  538. value = received_params.get(field.alias)
  539. field_info = field.field_info
  540. assert isinstance(
  541. field_info, params.Param
  542. ), "Params must be subclasses of Param"
  543. if value is None:
  544. if field.required:
  545. errors.append(
  546. ErrorWrapper(
  547. MissingError(), loc=(field_info.in_.value, field.alias)
  548. )
  549. )
  550. else:
  551. values[field.name] = deepcopy(field.default)
  552. continue
  553. v_, errors_ = field.validate(
  554. value, values, loc=(field_info.in_.value, field.alias)
  555. )
  556. if isinstance(errors_, ErrorWrapper):
  557. errors.append(errors_)
  558. elif isinstance(errors_, list):
  559. errors.extend(errors_)
  560. else:
  561. values[field.name] = v_
  562. return values, errors
  563. async def request_body_to_args(
  564. required_params: List[ModelField],
  565. received_body: Optional[Union[Dict[str, Any], FormData]],
  566. ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
  567. values = {}
  568. errors = []
  569. if required_params:
  570. field = required_params[0]
  571. field_info = field.field_info
  572. embed = getattr(field_info, "embed", None)
  573. field_alias_omitted = len(required_params) == 1 and not embed
  574. if field_alias_omitted:
  575. received_body = {field.alias: received_body}
  576. for field in required_params:
  577. loc: Tuple[str, ...]
  578. if field_alias_omitted:
  579. loc = ("body",)
  580. else:
  581. loc = ("body", field.alias)
  582. value: Optional[Any] = None
  583. if received_body is not None:
  584. if (
  585. field.shape in sequence_shapes or field.type_ in sequence_types
  586. ) and isinstance(received_body, FormData):
  587. value = received_body.getlist(field.alias)
  588. else:
  589. try:
  590. value = received_body.get(field.alias)
  591. except AttributeError:
  592. errors.append(get_missing_field_error(loc))
  593. continue
  594. if (
  595. value is None
  596. or (isinstance(field_info, params.Form) and value == "")
  597. or (
  598. isinstance(field_info, params.Form)
  599. and field.shape in sequence_shapes
  600. and len(value) == 0
  601. )
  602. ):
  603. if field.required:
  604. errors.append(get_missing_field_error(loc))
  605. else:
  606. values[field.name] = deepcopy(field.default)
  607. continue
  608. if (
  609. isinstance(field_info, params.File)
  610. and lenient_issubclass(field.type_, bytes)
  611. and isinstance(value, UploadFile)
  612. ):
  613. value = await value.read()
  614. elif (
  615. field.shape in sequence_shapes
  616. and isinstance(field_info, params.File)
  617. and lenient_issubclass(field.type_, bytes)
  618. and isinstance(value, sequence_types)
  619. ):
  620. results: List[Union[bytes, str]] = []
  621. async def process_fn(
  622. fn: Callable[[], Coroutine[Any, Any, Any]]
  623. ) -> None:
  624. result = await fn()
  625. results.append(result)
  626. async with anyio.create_task_group() as tg:
  627. for sub_value in value:
  628. tg.start_soon(process_fn, sub_value.read)
  629. value = sequence_shape_to_type[field.shape](results)
  630. v_, errors_ = field.validate(value, values, loc=loc)
  631. if isinstance(errors_, ErrorWrapper):
  632. errors.append(errors_)
  633. elif isinstance(errors_, list):
  634. errors.extend(errors_)
  635. else:
  636. values[field.name] = v_
  637. return values, errors
  638. def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
  639. missing_field_error = ErrorWrapper(MissingError(), loc=loc)
  640. return missing_field_error
  641. def get_schema_compatible_field(*, field: ModelField) -> ModelField:
  642. out_field = field
  643. if lenient_issubclass(field.type_, UploadFile):
  644. use_type: type = bytes
  645. if field.shape in sequence_shapes:
  646. use_type = List[bytes]
  647. out_field = create_response_field(
  648. name=field.name,
  649. type_=use_type,
  650. class_validators=field.class_validators,
  651. model_config=field.model_config,
  652. default=field.default,
  653. required=field.required,
  654. alias=field.alias,
  655. field_info=field.field_info,
  656. )
  657. return out_field
  658. def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
  659. flat_dependant = get_flat_dependant(dependant)
  660. if not flat_dependant.body_params:
  661. return None
  662. first_param = flat_dependant.body_params[0]
  663. field_info = first_param.field_info
  664. embed = getattr(field_info, "embed", None)
  665. body_param_names_set = {param.name for param in flat_dependant.body_params}
  666. if len(body_param_names_set) == 1 and not embed:
  667. final_field = get_schema_compatible_field(field=first_param)
  668. check_file_field(final_field)
  669. return final_field
  670. # If one field requires to embed, all have to be embedded
  671. # in case a sub-dependency is evaluated with a single unique body field
  672. # That is combined (embedded) with other body fields
  673. for param in flat_dependant.body_params:
  674. setattr(param.field_info, "embed", True)
  675. model_name = "Body_" + name
  676. BodyModel: Type[BaseModel] = create_model(model_name)
  677. for f in flat_dependant.body_params:
  678. BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f)
  679. required = any(True for f in flat_dependant.body_params if f.required)
  680. BodyFieldInfo_kwargs: Dict[str, Any] = dict(default=None)
  681. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  682. BodyFieldInfo: Type[params.Body] = params.File
  683. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  684. BodyFieldInfo = params.Form
  685. else:
  686. BodyFieldInfo = params.Body
  687. body_param_media_types = [
  688. getattr(f.field_info, "media_type")
  689. for f in flat_dependant.body_params
  690. if isinstance(f.field_info, params.Body)
  691. ]
  692. if len(set(body_param_media_types)) == 1:
  693. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  694. final_field = create_response_field(
  695. name="body",
  696. type_=BodyModel,
  697. required=required,
  698. alias="body",
  699. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  700. )
  701. check_file_field(final_field)
  702. return final_field