Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.
 
 
 
 

1002 wiersze
36 KiB

  1. import inspect
  2. from contextlib import AsyncExitStack, contextmanager
  3. from copy import copy, deepcopy
  4. from dataclasses import dataclass
  5. from typing import (
  6. Any,
  7. Callable,
  8. Coroutine,
  9. Dict,
  10. ForwardRef,
  11. List,
  12. Mapping,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. cast,
  19. )
  20. import anyio
  21. from fastapi import params
  22. from fastapi._compat import (
  23. PYDANTIC_V2,
  24. ErrorWrapper,
  25. ModelField,
  26. RequiredParam,
  27. Undefined,
  28. _regenerate_error_with_loc,
  29. copy_field_info,
  30. create_body_model,
  31. evaluate_forwardref,
  32. field_annotation_is_scalar,
  33. get_annotation_from_field_info,
  34. get_cached_model_fields,
  35. get_missing_field_error,
  36. is_bytes_field,
  37. is_bytes_sequence_field,
  38. is_scalar_field,
  39. is_scalar_sequence_field,
  40. is_sequence_field,
  41. is_uploadfile_or_nonable_uploadfile_annotation,
  42. is_uploadfile_sequence_annotation,
  43. lenient_issubclass,
  44. sequence_types,
  45. serialize_sequence_value,
  46. value_is_sequence,
  47. )
  48. from fastapi.background import BackgroundTasks
  49. from fastapi.concurrency import (
  50. asynccontextmanager,
  51. contextmanager_in_threadpool,
  52. )
  53. from fastapi.dependencies.models import Dependant, SecurityRequirement
  54. from fastapi.logger import logger
  55. from fastapi.security.base import SecurityBase
  56. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  57. from fastapi.security.open_id_connect_url import OpenIdConnect
  58. from fastapi.utils import create_model_field, get_path_param_names
  59. from pydantic import BaseModel
  60. from pydantic.fields import FieldInfo
  61. from starlette.background import BackgroundTasks as StarletteBackgroundTasks
  62. from starlette.concurrency import run_in_threadpool
  63. from starlette.datastructures import (
  64. FormData,
  65. Headers,
  66. ImmutableMultiDict,
  67. QueryParams,
  68. UploadFile,
  69. )
  70. from starlette.requests import HTTPConnection, Request
  71. from starlette.responses import Response
  72. from starlette.websockets import WebSocket
  73. from typing_extensions import Annotated, get_args, get_origin
# Message logged and raised (as RuntimeError) by ensure_multipart_is_installed()
# when neither "python_multipart" nor "multipart" can be imported.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
# Message logged and raised (as RuntimeError) by ensure_multipart_is_installed()
# when a "multipart" module is importable but "multipart.multipart" does not
# provide `parse_options_header`.
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. \n'
    'You can remove "multipart" with: \n\n'
    "pip uninstall multipart\n\n"
    'And then install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
  87. def ensure_multipart_is_installed() -> None:
  88. try:
  89. from python_multipart import __version__
  90. # Import an attribute that can be mocked/deleted in testing
  91. assert __version__ > "0.0.12"
  92. except (ImportError, AssertionError):
  93. try:
  94. # __version__ is available in both multiparts, and can be mocked
  95. from multipart import __version__ # type: ignore[no-redef,import-untyped]
  96. assert __version__
  97. try:
  98. # parse_options_header is only available in the right multipart
  99. from multipart.multipart import ( # type: ignore[import-untyped]
  100. parse_options_header,
  101. )
  102. assert parse_options_header
  103. except ImportError:
  104. logger.error(multipart_incorrect_install_error)
  105. raise RuntimeError(multipart_incorrect_install_error) from None
  106. except ImportError:
  107. logger.error(multipart_not_installed_error)
  108. raise RuntimeError(multipart_not_installed_error) from None
  109. def get_param_sub_dependant(
  110. *,
  111. param_name: str,
  112. depends: params.Depends,
  113. path: str,
  114. security_scopes: Optional[List[str]] = None,
  115. ) -> Dependant:
  116. assert depends.dependency
  117. return get_sub_dependant(
  118. depends=depends,
  119. dependency=depends.dependency,
  120. path=path,
  121. name=param_name,
  122. security_scopes=security_scopes,
  123. )
  124. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  125. assert callable(depends.dependency), (
  126. "A parameter-less dependency must have a callable dependency"
  127. )
  128. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  129. def get_sub_dependant(
  130. *,
  131. depends: params.Depends,
  132. dependency: Callable[..., Any],
  133. path: str,
  134. name: Optional[str] = None,
  135. security_scopes: Optional[List[str]] = None,
  136. ) -> Dependant:
  137. security_requirement = None
  138. security_scopes = security_scopes or []
  139. if isinstance(depends, params.Security):
  140. dependency_scopes = depends.scopes
  141. security_scopes.extend(dependency_scopes)
  142. if isinstance(dependency, SecurityBase):
  143. use_scopes: List[str] = []
  144. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  145. use_scopes = security_scopes
  146. security_requirement = SecurityRequirement(
  147. security_scheme=dependency, scopes=use_scopes
  148. )
  149. sub_dependant = get_dependant(
  150. path=path,
  151. call=dependency,
  152. name=name,
  153. security_scopes=security_scopes,
  154. use_cache=depends.use_cache,
  155. )
  156. if security_requirement:
  157. sub_dependant.security_requirements.append(security_requirement)
  158. return sub_dependant
# Type of the keys collected in `visited` by get_flat_dependant(): an optional
# callable paired with a tuple of strings (the value of `Dependant.cache_key`).
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  160. def get_flat_dependant(
  161. dependant: Dependant,
  162. *,
  163. skip_repeats: bool = False,
  164. visited: Optional[List[CacheKey]] = None,
  165. ) -> Dependant:
  166. if visited is None:
  167. visited = []
  168. visited.append(dependant.cache_key)
  169. flat_dependant = Dependant(
  170. path_params=dependant.path_params.copy(),
  171. query_params=dependant.query_params.copy(),
  172. header_params=dependant.header_params.copy(),
  173. cookie_params=dependant.cookie_params.copy(),
  174. body_params=dependant.body_params.copy(),
  175. security_requirements=dependant.security_requirements.copy(),
  176. use_cache=dependant.use_cache,
  177. path=dependant.path,
  178. )
  179. for sub_dependant in dependant.dependencies:
  180. if skip_repeats and sub_dependant.cache_key in visited:
  181. continue
  182. flat_sub = get_flat_dependant(
  183. sub_dependant, skip_repeats=skip_repeats, visited=visited
  184. )
  185. flat_dependant.path_params.extend(flat_sub.path_params)
  186. flat_dependant.query_params.extend(flat_sub.query_params)
  187. flat_dependant.header_params.extend(flat_sub.header_params)
  188. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  189. flat_dependant.body_params.extend(flat_sub.body_params)
  190. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  191. return flat_dependant
  192. def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
  193. if not fields:
  194. return fields
  195. first_field = fields[0]
  196. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  197. fields_to_extract = get_cached_model_fields(first_field.type_)
  198. return fields_to_extract
  199. return fields
  200. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  201. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  202. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  203. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  204. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  205. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  206. return path_params + query_params + header_params + cookie_params
  207. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  208. signature = inspect.signature(call)
  209. globalns = getattr(call, "__globals__", {})
  210. typed_params = [
  211. inspect.Parameter(
  212. name=param.name,
  213. kind=param.kind,
  214. default=param.default,
  215. annotation=get_typed_annotation(param.annotation, globalns),
  216. )
  217. for param in signature.parameters.values()
  218. ]
  219. typed_signature = inspect.Signature(typed_params)
  220. return typed_signature
  221. def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
  222. if isinstance(annotation, str):
  223. annotation = ForwardRef(annotation)
  224. annotation = evaluate_forwardref(annotation, globalns, globalns)
  225. return annotation
  226. def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
  227. signature = inspect.signature(call)
  228. annotation = signature.return_annotation
  229. if annotation is inspect.Signature.empty:
  230. return None
  231. globalns = getattr(call, "__globals__", {})
  232. return get_typed_annotation(annotation, globalns)
  233. def get_dependant(
  234. *,
  235. path: str,
  236. call: Callable[..., Any],
  237. name: Optional[str] = None,
  238. security_scopes: Optional[List[str]] = None,
  239. use_cache: bool = True,
  240. ) -> Dependant:
  241. path_param_names = get_path_param_names(path)
  242. endpoint_signature = get_typed_signature(call)
  243. signature_params = endpoint_signature.parameters
  244. dependant = Dependant(
  245. call=call,
  246. name=name,
  247. path=path,
  248. security_scopes=security_scopes,
  249. use_cache=use_cache,
  250. )
  251. for param_name, param in signature_params.items():
  252. is_path_param = param_name in path_param_names
  253. param_details = analyze_param(
  254. param_name=param_name,
  255. annotation=param.annotation,
  256. value=param.default,
  257. is_path_param=is_path_param,
  258. )
  259. if param_details.depends is not None:
  260. sub_dependant = get_param_sub_dependant(
  261. param_name=param_name,
  262. depends=param_details.depends,
  263. path=path,
  264. security_scopes=security_scopes,
  265. )
  266. dependant.dependencies.append(sub_dependant)
  267. continue
  268. if add_non_field_param_to_dependency(
  269. param_name=param_name,
  270. type_annotation=param_details.type_annotation,
  271. dependant=dependant,
  272. ):
  273. assert param_details.field is None, (
  274. f"Cannot specify multiple FastAPI annotations for {param_name!r}"
  275. )
  276. continue
  277. assert param_details.field is not None
  278. if isinstance(param_details.field.field_info, params.Body):
  279. dependant.body_params.append(param_details.field)
  280. else:
  281. add_param_to_fields(field=param_details.field, dependant=dependant)
  282. return dependant
  283. def add_non_field_param_to_dependency(
  284. *, param_name: str, type_annotation: Any, dependant: Dependant
  285. ) -> Optional[bool]:
  286. if lenient_issubclass(type_annotation, Request):
  287. dependant.request_param_name = param_name
  288. return True
  289. elif lenient_issubclass(type_annotation, WebSocket):
  290. dependant.websocket_param_name = param_name
  291. return True
  292. elif lenient_issubclass(type_annotation, HTTPConnection):
  293. dependant.http_connection_param_name = param_name
  294. return True
  295. elif lenient_issubclass(type_annotation, Response):
  296. dependant.response_param_name = param_name
  297. return True
  298. elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
  299. dependant.background_tasks_param_name = param_name
  300. return True
  301. elif lenient_issubclass(type_annotation, SecurityScopes):
  302. dependant.security_scopes_param_name = param_name
  303. return True
  304. return None
@dataclass
class ParamDetails:
    # Result of analyze_param() for a single parameter.

    # The parameter's annotation with any `Annotated[...]` wrapper removed
    # (`Any` when the parameter has no annotation).
    type_annotation: Any
    # The `Depends` found in `Annotated` metadata or as the default value, if any.
    depends: Optional[params.Depends]
    # The model field built for the parameter; None when no field info applies.
    field: Optional[ModelField]
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> ParamDetails:
    """Classify one callable parameter as a dependency, a field, or neither.

    Args:
        param_name: Name of the parameter.
        annotation: Its annotation, or ``inspect.Signature.empty``.
        value: Its default value, or ``inspect.Signature.empty``.
        is_path_param: Whether the name appears as a placeholder in the path.

    Returns:
        ``ParamDetails`` with the unwrapped type annotation, the ``Depends``
        found (if any) and the ``ModelField`` created (if any).

    Raises:
        AssertionError: on conflicting or unsupported declarations (see the
            individual assertion messages below).
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Extract Annotated info
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        # When several FastAPI-specific annotations are present, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
                fastapi_specific_annotations[-1]
            )
        else:
            fastapi_annotation = None
        # Set default for Annotated FieldInfo
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert (
                field_info.default is Undefined or field_info.default is RequiredParam
            ), (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = RequiredParam
        # Get Annotated Depends
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Get Depends from default value
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    # Get FieldInfo from default value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        # NOTE: the default-value object itself is used (not a copy), so the
        # assignments to `annotation`, `in_` and `alias` below modify it in place.
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation
    # Get Depends from type annotation
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation
    # Handle non-param type annotations like Request
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert field_info is None, (
            f"Cannot specify FastAPI annotation for type {type_annotation!r}"
        )
    # Handle default assignments: neither field_info nor depends was found in
    # Annotated or in the default value, so infer the kind of parameter.
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else RequiredParam
        if is_path_param:
            # We might check here that `default_value is RequiredParam`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)
    field = None
    # It's a field_info, not a dependency
    if field_info is not None:
        # Handle field_info.in_
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A Param without an explicit location is treated as a query parameter.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        if isinstance(field_info, params.Form):
            ensure_multipart_is_installed()
        # Without an explicit alias, a field info with a truthy
        # `convert_underscores` gets the name with "_" replaced by "-".
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_model_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (RequiredParam, Undefined),
            field_info=field_info,
        )
        if is_path_param:
            assert is_scalar_field(field=field), (
                "Path params must be of one of the supported types"
            )
        elif isinstance(field_info, params.Query):
            # Query params must be scalars, sequences of scalars, or a BaseModel.
            assert (
                is_scalar_field(field)
                or is_scalar_sequence_field(field)
                or (
                    lenient_issubclass(field.type_, BaseModel)
                    # For Pydantic v1
                    and getattr(field, "shape", 1) == 1
                )
            )
    return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
  470. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  471. field_info = field.field_info
  472. field_info_in = getattr(field_info, "in_", None)
  473. if field_info_in == params.ParamTypes.path:
  474. dependant.path_params.append(field)
  475. elif field_info_in == params.ParamTypes.query:
  476. dependant.query_params.append(field)
  477. elif field_info_in == params.ParamTypes.header:
  478. dependant.header_params.append(field)
  479. else:
  480. assert field_info_in == params.ParamTypes.cookie, (
  481. f"non-body parameters must be in path, query, header or cookie: {field.name}"
  482. )
  483. dependant.cookie_params.append(field)
  484. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  485. if inspect.isroutine(call):
  486. return inspect.iscoroutinefunction(call)
  487. if inspect.isclass(call):
  488. return False
  489. dunder_call = getattr(call, "__call__", None) # noqa: B004
  490. return inspect.iscoroutinefunction(dunder_call)
  491. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  492. if inspect.isasyncgenfunction(call):
  493. return True
  494. dunder_call = getattr(call, "__call__", None) # noqa: B004
  495. return inspect.isasyncgenfunction(dunder_call)
  496. def is_gen_callable(call: Callable[..., Any]) -> bool:
  497. if inspect.isgeneratorfunction(call):
  498. return True
  499. dunder_call = getattr(call, "__call__", None) # noqa: B004
  500. return inspect.isgeneratorfunction(dunder_call)
  501. async def solve_generator(
  502. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  503. ) -> Any:
  504. if is_gen_callable(call):
  505. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  506. elif is_async_gen_callable(call):
  507. cm = asynccontextmanager(call)(**sub_values)
  508. return await stack.enter_async_context(cm)
@dataclass
class SolvedDependency:
    # Result returned by solve_dependencies().

    # Resolved argument values, keyed by parameter / sub-dependant name.
    values: Dict[str, Any]
    # Errors collected while resolving sub-dependencies and request parameters.
    errors: List[Any]
    # Background tasks object in use, or None when none was passed in or created.
    background_tasks: Optional[StarletteBackgroundTasks]
    # The response object passed in, or the one created by solve_dependencies().
    response: Response
    # Solved dependency values keyed by the sub-dependant's cache key.
    dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Resolve every argument value ``dependant.call`` needs for ``request``.

    Sub-dependencies are solved recursively first, then path, query, header,
    cookie and body parameters are extracted and validated, and finally the
    special parameters (connection, request/websocket, background tasks,
    response, security scopes) are filled in.

    Args:
        request: The incoming HTTP request or WebSocket.
        dependant: The dependant whose parameters are being resolved.
        body: Already-parsed request body, if any.
        background_tasks: Existing background tasks object to share.
        response: Response object to share; a fresh one is created when None.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping can replace sub-dependency callables.
        dependency_cache: Previously solved dependency values.
        async_exit_stack: Stack on which generator dependencies are entered.
        embed_body_fields: Forwarded to ``request_body_to_args``.

    Returns:
        A ``SolvedDependency`` with the values, the collected errors and the
        shared background tasks, response and dependency cache.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        # Fresh response handed to dependencies: it starts without a
        # content-length header and without a status code.
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    # NOTE: an empty dict passed by the caller is replaced by a new one here;
    # the cache in use is handed back through the returned SolvedDependency.
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # These casts only narrow the types for the type checker.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if one is registered for this callable) and
            # re-analyse it, since its signature may differ from the original's.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
        # The sub-dependant's own arguments are resolved before the cache is
        # consulted for its value.
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
            embed_body_fields=embed_body_fields,
        )
        background_tasks = solved_result.background_tasks
        dependency_cache.update(solved_result.dependency_cache)
        if solved_result.errors:
            # The dependency cannot be called without valid arguments.
            errors.extend(solved_result.errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=solved_result.values
            )
        elif is_coroutine_callable(call):
            solved = await call(**solved_result.values)
        else:
            solved = await run_in_threadpool(call, **solved_result.values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        # Only the first solved value for a given cache key is stored.
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject the special parameters recorded on the dependant.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        # Created lazily, only when some dependant actually asks for it.
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return SolvedDependency(
        values=values,
        errors=errors,
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
    )
  639. def _validate_value_with_model_field(
  640. *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
  641. ) -> Tuple[Any, List[Any]]:
  642. if value is None:
  643. if field.required:
  644. return None, [get_missing_field_error(loc=loc)]
  645. else:
  646. return deepcopy(field.default), []
  647. v_, errors_ = field.validate(value, values, loc=loc)
  648. if isinstance(errors_, ErrorWrapper):
  649. return None, [errors_]
  650. elif isinstance(errors_, list):
  651. new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
  652. return None, new_errors
  653. else:
  654. return v_, []
  655. def _get_multidict_value(
  656. field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
  657. ) -> Any:
  658. alias = alias or field.alias
  659. if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
  660. value = values.getlist(alias)
  661. else:
  662. value = values.get(alias, None)
  663. if (
  664. value is None
  665. or (
  666. isinstance(field.field_info, params.Form)
  667. and isinstance(value, str) # For type checks
  668. and value == ""
  669. )
  670. or (is_sequence_field(field) and len(value) == 0)
  671. ):
  672. if field.required:
  673. return
  674. else:
  675. return deepcopy(field.default)
  676. return value
  677. def request_params_to_args(
  678. fields: Sequence[ModelField],
  679. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  680. ) -> Tuple[Dict[str, Any], List[Any]]:
  681. values: Dict[str, Any] = {}
  682. errors: List[Dict[str, Any]] = []
  683. if not fields:
  684. return values, errors
  685. first_field = fields[0]
  686. fields_to_extract = fields
  687. single_not_embedded_field = False
  688. default_convert_underscores = True
  689. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  690. fields_to_extract = get_cached_model_fields(first_field.type_)
  691. single_not_embedded_field = True
  692. # If headers are in a Pydantic model, the way to disable convert_underscores
  693. # would be with Header(convert_underscores=False) at the Pydantic model level
  694. default_convert_underscores = getattr(
  695. first_field.field_info, "convert_underscores", True
  696. )
  697. params_to_process: Dict[str, Any] = {}
  698. processed_keys = set()
  699. for field in fields_to_extract:
  700. alias = None
  701. if isinstance(received_params, Headers):
  702. # Handle fields extracted from a Pydantic Model for a header, each field
  703. # doesn't have a FieldInfo of type Header with the default convert_underscores=True
  704. convert_underscores = getattr(
  705. field.field_info, "convert_underscores", default_convert_underscores
  706. )
  707. if convert_underscores:
  708. alias = (
  709. field.alias
  710. if field.alias != field.name
  711. else field.name.replace("_", "-")
  712. )
  713. value = _get_multidict_value(field, received_params, alias=alias)
  714. if value is not None:
  715. params_to_process[field.name] = value
  716. processed_keys.add(alias or field.alias)
  717. processed_keys.add(field.name)
  718. for key, value in received_params.items():
  719. if key not in processed_keys:
  720. params_to_process[key] = value
  721. if single_not_embedded_field:
  722. field_info = first_field.field_info
  723. assert isinstance(field_info, params.Param), (
  724. "Params must be subclasses of Param"
  725. )
  726. loc: Tuple[str, ...] = (field_info.in_.value,)
  727. v_, errors_ = _validate_value_with_model_field(
  728. field=first_field, value=params_to_process, values=values, loc=loc
  729. )
  730. return {first_field.name: v_}, errors_
  731. for field in fields:
  732. value = _get_multidict_value(field, received_params)
  733. field_info = field.field_info
  734. assert isinstance(field_info, params.Param), (
  735. "Params must be subclasses of Param"
  736. )
  737. loc = (field_info.in_.value, field.alias)
  738. v_, errors_ = _validate_value_with_model_field(
  739. field=field, value=value, values=values, loc=loc
  740. )
  741. if errors_:
  742. errors.extend(errors_)
  743. else:
  744. values[field.name] = v_
  745. return values, errors
  746. def is_union_of_base_models(field_type: Any) -> bool:
  747. """Check if field type is a Union where all members are BaseModel subclasses."""
  748. from fastapi.types import UnionType
  749. origin = get_origin(field_type)
  750. # Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
  751. if origin is not Union and origin is not UnionType:
  752. return False
  753. union_args = get_args(field_type)
  754. for arg in union_args:
  755. if not lenient_issubclass(arg, BaseModel):
  756. return False
  757. return True
  758. def _should_embed_body_fields(fields: List[ModelField]) -> bool:
  759. if not fields:
  760. return False
  761. # More than one dependency could have the same field, it would show up as multiple
  762. # fields but it's the same one, so count them by name
  763. body_param_names_set = {field.name for field in fields}
  764. # A top level field has to be a single field, not multiple
  765. if len(body_param_names_set) > 1:
  766. return True
  767. first_field = fields[0]
  768. # If it explicitly specifies it is embedded, it has to be embedded
  769. if getattr(first_field.field_info, "embed", None):
  770. return True
  771. # If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
  772. # otherwise it has to be embedded, so that the key value pair can be extracted
  773. if (
  774. isinstance(first_field.field_info, params.Form)
  775. and not lenient_issubclass(first_field.type_, BaseModel)
  776. and not is_union_of_base_models(first_field.type_)
  777. ):
  778. return True
  779. return False
  780. async def _extract_form_body(
  781. body_fields: List[ModelField],
  782. received_body: FormData,
  783. ) -> Dict[str, Any]:
  784. values = {}
  785. first_field = body_fields[0]
  786. first_field_info = first_field.field_info
  787. for field in body_fields:
  788. value = _get_multidict_value(field, received_body)
  789. if (
  790. isinstance(first_field_info, params.File)
  791. and is_bytes_field(field)
  792. and isinstance(value, UploadFile)
  793. ):
  794. value = await value.read()
  795. elif (
  796. is_bytes_sequence_field(field)
  797. and isinstance(first_field_info, params.File)
  798. and value_is_sequence(value)
  799. ):
  800. # For types
  801. assert isinstance(value, sequence_types) # type: ignore[arg-type]
  802. results: List[Union[bytes, str]] = []
  803. async def process_fn(
  804. fn: Callable[[], Coroutine[Any, Any, Any]],
  805. ) -> None:
  806. result = await fn()
  807. results.append(result) # noqa: B023
  808. async with anyio.create_task_group() as tg:
  809. for sub_value in value:
  810. tg.start_soon(process_fn, sub_value.read)
  811. value = serialize_sequence_value(field=field, value=results)
  812. if value is not None:
  813. values[field.alias] = value
  814. for key, value in received_body.items():
  815. if key not in values:
  816. values[key] = value
  817. return values
  818. async def request_body_to_args(
  819. body_fields: List[ModelField],
  820. received_body: Optional[Union[Dict[str, Any], FormData]],
  821. embed_body_fields: bool,
  822. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  823. values: Dict[str, Any] = {}
  824. errors: List[Dict[str, Any]] = []
  825. assert body_fields, "request_body_to_args() should be called with fields"
  826. single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
  827. first_field = body_fields[0]
  828. body_to_process = received_body
  829. fields_to_extract: List[ModelField] = body_fields
  830. if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
  831. fields_to_extract = get_cached_model_fields(first_field.type_)
  832. if isinstance(received_body, FormData):
  833. body_to_process = await _extract_form_body(fields_to_extract, received_body)
  834. if single_not_embedded_field:
  835. loc: Tuple[str, ...] = ("body",)
  836. v_, errors_ = _validate_value_with_model_field(
  837. field=first_field, value=body_to_process, values=values, loc=loc
  838. )
  839. return {first_field.name: v_}, errors_
  840. for field in body_fields:
  841. loc = ("body", field.alias)
  842. value: Optional[Any] = None
  843. if body_to_process is not None:
  844. try:
  845. value = body_to_process.get(field.alias)
  846. # If the received body is a list, not a dict
  847. except AttributeError:
  848. errors.append(get_missing_field_error(loc))
  849. continue
  850. v_, errors_ = _validate_value_with_model_field(
  851. field=field, value=value, values=values, loc=loc
  852. )
  853. if errors_:
  854. errors.extend(errors_)
  855. else:
  856. values[field.name] = v_
  857. return values, errors
  858. def get_body_field(
  859. *, flat_dependant: Dependant, name: str, embed_body_fields: bool
  860. ) -> Optional[ModelField]:
  861. """
  862. Get a ModelField representing the request body for a path operation, combining
  863. all body parameters into a single field if necessary.
  864. Used to check if it's form data (with `isinstance(body_field, params.Form)`)
  865. or JSON and to generate the JSON Schema for a request body.
  866. This is **not** used to validate/parse the request body, that's done with each
  867. individual body parameter.
  868. """
  869. if not flat_dependant.body_params:
  870. return None
  871. first_param = flat_dependant.body_params[0]
  872. if not embed_body_fields:
  873. return first_param
  874. model_name = "Body_" + name
  875. BodyModel = create_body_model(
  876. fields=flat_dependant.body_params, model_name=model_name
  877. )
  878. required = any(True for f in flat_dependant.body_params if f.required)
  879. BodyFieldInfo_kwargs: Dict[str, Any] = {
  880. "annotation": BodyModel,
  881. "alias": "body",
  882. }
  883. if not required:
  884. BodyFieldInfo_kwargs["default"] = None
  885. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  886. BodyFieldInfo: Type[params.Body] = params.File
  887. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  888. BodyFieldInfo = params.Form
  889. else:
  890. BodyFieldInfo = params.Body
  891. body_param_media_types = [
  892. f.field_info.media_type
  893. for f in flat_dependant.body_params
  894. if isinstance(f.field_info, params.Body)
  895. ]
  896. if len(set(body_param_media_types)) == 1:
  897. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  898. final_field = create_model_field(
  899. name="body",
  900. type_=BodyModel,
  901. required=required,
  902. alias="body",
  903. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  904. )
  905. return final_field