Non puoi selezionare più di 25 argomenti. Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 

1002 righe
36 KiB

  1. import inspect
  2. from contextlib import AsyncExitStack, contextmanager
  3. from copy import copy, deepcopy
  4. from dataclasses import dataclass
  5. from typing import (
  6. Any,
  7. Callable,
  8. Coroutine,
  9. Dict,
  10. ForwardRef,
  11. List,
  12. Mapping,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. cast,
  19. )
  20. import anyio
  21. from fastapi import params
  22. from fastapi._compat import (
  23. PYDANTIC_V2,
  24. ErrorWrapper,
  25. ModelField,
  26. RequiredParam,
  27. Undefined,
  28. _regenerate_error_with_loc,
  29. copy_field_info,
  30. create_body_model,
  31. evaluate_forwardref,
  32. field_annotation_is_scalar,
  33. get_annotation_from_field_info,
  34. get_cached_model_fields,
  35. get_missing_field_error,
  36. is_bytes_field,
  37. is_bytes_sequence_field,
  38. is_scalar_field,
  39. is_scalar_sequence_field,
  40. is_sequence_field,
  41. is_uploadfile_or_nonable_uploadfile_annotation,
  42. is_uploadfile_sequence_annotation,
  43. lenient_issubclass,
  44. sequence_types,
  45. serialize_sequence_value,
  46. value_is_sequence,
  47. )
  48. from fastapi.background import BackgroundTasks
  49. from fastapi.concurrency import (
  50. asynccontextmanager,
  51. contextmanager_in_threadpool,
  52. )
  53. from fastapi.dependencies.models import Dependant, SecurityRequirement
  54. from fastapi.logger import logger
  55. from fastapi.security.base import SecurityBase
  56. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  57. from fastapi.security.open_id_connect_url import OpenIdConnect
  58. from fastapi.utils import create_model_field, get_path_param_names
  59. from pydantic import BaseModel
  60. from pydantic.fields import FieldInfo
  61. from starlette.background import BackgroundTasks as StarletteBackgroundTasks
  62. from starlette.concurrency import run_in_threadpool
  63. from starlette.datastructures import (
  64. FormData,
  65. Headers,
  66. ImmutableMultiDict,
  67. QueryParams,
  68. UploadFile,
  69. )
  70. from starlette.requests import HTTPConnection, Request
  71. from starlette.responses import Response
  72. from starlette.websockets import WebSocket
  73. from typing_extensions import Annotated, get_args, get_origin
# Message logged and raised by ensure_multipart_is_installed() when no
# multipart package can be imported at all.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
# Message logged and raised by ensure_multipart_is_installed() when a
# "multipart" module is importable but does not provide
# `multipart.multipart.parse_options_header`, i.e. it is not the module
# shipped by "python-multipart".
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. \n'
    'You can remove "multipart" with: \n\n'
    "pip uninstall multipart\n\n"
    'And then install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
  87. def ensure_multipart_is_installed() -> None:
  88. try:
  89. from python_multipart import __version__
  90. # Import an attribute that can be mocked/deleted in testing
  91. assert __version__ > "0.0.12"
  92. except (ImportError, AssertionError):
  93. try:
  94. # __version__ is available in both multiparts, and can be mocked
  95. from multipart import __version__ # type: ignore[no-redef,import-untyped]
  96. assert __version__
  97. try:
  98. # parse_options_header is only available in the right multipart
  99. from multipart.multipart import ( # type: ignore[import-untyped]
  100. parse_options_header,
  101. )
  102. assert parse_options_header
  103. except ImportError:
  104. logger.error(multipart_incorrect_install_error)
  105. raise RuntimeError(multipart_incorrect_install_error) from None
  106. except ImportError:
  107. logger.error(multipart_not_installed_error)
  108. raise RuntimeError(multipart_not_installed_error) from None
  109. def get_param_sub_dependant(
  110. *,
  111. param_name: str,
  112. depends: params.Depends,
  113. path: str,
  114. security_scopes: Optional[List[str]] = None,
  115. ) -> Dependant:
  116. assert depends.dependency
  117. return get_sub_dependant(
  118. depends=depends,
  119. dependency=depends.dependency,
  120. path=path,
  121. name=param_name,
  122. security_scopes=security_scopes,
  123. )
  124. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  125. assert callable(depends.dependency), (
  126. "A parameter-less dependency must have a callable dependency"
  127. )
  128. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  129. def get_sub_dependant(
  130. *,
  131. depends: params.Depends,
  132. dependency: Callable[..., Any],
  133. path: str,
  134. name: Optional[str] = None,
  135. security_scopes: Optional[List[str]] = None,
  136. ) -> Dependant:
  137. security_requirement = None
  138. security_scopes = security_scopes or []
  139. if isinstance(depends, params.Security):
  140. dependency_scopes = depends.scopes
  141. security_scopes.extend(dependency_scopes)
  142. if isinstance(dependency, SecurityBase):
  143. use_scopes: List[str] = []
  144. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  145. use_scopes = security_scopes
  146. security_requirement = SecurityRequirement(
  147. security_scheme=dependency, scopes=use_scopes
  148. )
  149. sub_dependant = get_dependant(
  150. path=path,
  151. call=dependency,
  152. name=name,
  153. security_scopes=security_scopes,
  154. use_cache=depends.use_cache,
  155. )
  156. if security_requirement:
  157. sub_dependant.security_requirements.append(security_requirement)
  158. return sub_dependant
  159. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  160. def get_flat_dependant(
  161. dependant: Dependant,
  162. *,
  163. skip_repeats: bool = False,
  164. visited: Optional[List[CacheKey]] = None,
  165. ) -> Dependant:
  166. if visited is None:
  167. visited = []
  168. visited.append(dependant.cache_key)
  169. flat_dependant = Dependant(
  170. path_params=dependant.path_params.copy(),
  171. query_params=dependant.query_params.copy(),
  172. header_params=dependant.header_params.copy(),
  173. cookie_params=dependant.cookie_params.copy(),
  174. body_params=dependant.body_params.copy(),
  175. security_requirements=dependant.security_requirements.copy(),
  176. use_cache=dependant.use_cache,
  177. path=dependant.path,
  178. )
  179. for sub_dependant in dependant.dependencies:
  180. if skip_repeats and sub_dependant.cache_key in visited:
  181. continue
  182. flat_sub = get_flat_dependant(
  183. sub_dependant, skip_repeats=skip_repeats, visited=visited
  184. )
  185. flat_dependant.path_params.extend(flat_sub.path_params)
  186. flat_dependant.query_params.extend(flat_sub.query_params)
  187. flat_dependant.header_params.extend(flat_sub.header_params)
  188. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  189. flat_dependant.body_params.extend(flat_sub.body_params)
  190. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  191. return flat_dependant
  192. def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
  193. if not fields:
  194. return fields
  195. first_field = fields[0]
  196. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  197. fields_to_extract = get_cached_model_fields(first_field.type_)
  198. return fields_to_extract
  199. return fields
  200. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  201. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  202. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  203. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  204. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  205. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  206. return path_params + query_params + header_params + cookie_params
  207. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  208. signature = inspect.signature(call)
  209. globalns = getattr(call, "__globals__", {})
  210. typed_params = [
  211. inspect.Parameter(
  212. name=param.name,
  213. kind=param.kind,
  214. default=param.default,
  215. annotation=get_typed_annotation(param.annotation, globalns),
  216. )
  217. for param in signature.parameters.values()
  218. ]
  219. typed_signature = inspect.Signature(typed_params)
  220. return typed_signature
  221. def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
  222. if isinstance(annotation, str):
  223. annotation = ForwardRef(annotation)
  224. annotation = evaluate_forwardref(annotation, globalns, globalns)
  225. return annotation
  226. def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
  227. signature = inspect.signature(call)
  228. annotation = signature.return_annotation
  229. if annotation is inspect.Signature.empty:
  230. return None
  231. globalns = getattr(call, "__globals__", {})
  232. return get_typed_annotation(annotation, globalns)
  233. def get_dependant(
  234. *,
  235. path: str,
  236. call: Callable[..., Any],
  237. name: Optional[str] = None,
  238. security_scopes: Optional[List[str]] = None,
  239. use_cache: bool = True,
  240. ) -> Dependant:
  241. path_param_names = get_path_param_names(path)
  242. endpoint_signature = get_typed_signature(call)
  243. signature_params = endpoint_signature.parameters
  244. dependant = Dependant(
  245. call=call,
  246. name=name,
  247. path=path,
  248. security_scopes=security_scopes,
  249. use_cache=use_cache,
  250. )
  251. for param_name, param in signature_params.items():
  252. is_path_param = param_name in path_param_names
  253. param_details = analyze_param(
  254. param_name=param_name,
  255. annotation=param.annotation,
  256. value=param.default,
  257. is_path_param=is_path_param,
  258. )
  259. if param_details.depends is not None:
  260. sub_dependant = get_param_sub_dependant(
  261. param_name=param_name,
  262. depends=param_details.depends,
  263. path=path,
  264. security_scopes=security_scopes,
  265. )
  266. dependant.dependencies.append(sub_dependant)
  267. continue
  268. if add_non_field_param_to_dependency(
  269. param_name=param_name,
  270. type_annotation=param_details.type_annotation,
  271. dependant=dependant,
  272. ):
  273. assert param_details.field is None, (
  274. f"Cannot specify multiple FastAPI annotations for {param_name!r}"
  275. )
  276. continue
  277. assert param_details.field is not None
  278. if isinstance(param_details.field.field_info, params.Body):
  279. dependant.body_params.append(param_details.field)
  280. else:
  281. add_param_to_fields(field=param_details.field, dependant=dependant)
  282. return dependant
  283. def add_non_field_param_to_dependency(
  284. *, param_name: str, type_annotation: Any, dependant: Dependant
  285. ) -> Optional[bool]:
  286. if lenient_issubclass(type_annotation, Request):
  287. dependant.request_param_name = param_name
  288. return True
  289. elif lenient_issubclass(type_annotation, WebSocket):
  290. dependant.websocket_param_name = param_name
  291. return True
  292. elif lenient_issubclass(type_annotation, HTTPConnection):
  293. dependant.http_connection_param_name = param_name
  294. return True
  295. elif lenient_issubclass(type_annotation, Response):
  296. dependant.response_param_name = param_name
  297. return True
  298. elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
  299. dependant.background_tasks_param_name = param_name
  300. return True
  301. elif lenient_issubclass(type_annotation, SecurityScopes):
  302. dependant.security_scopes_param_name = param_name
  303. return True
  304. return None
  305. @dataclass
  306. class ParamDetails:
  307. type_annotation: Any
  308. depends: Optional[params.Depends]
  309. field: Optional[ModelField]
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> ParamDetails:
    """Classify one signature parameter and build its field or dependency marker.

    The FastAPI marker comes from ``Annotated[...]`` metadata (the last
    ``Param``/``Body``/``Depends`` instance) and/or from the default value.
    Conflicting declarations, such as ``Depends`` both in ``Annotated`` and
    as the default, are rejected with ``AssertionError``.

    When no marker is given and the type is not a framework-injected type
    (``Request``, ``Response``, ...), the parameter becomes:
    - a ``Path`` param if it appears in the route path;
    - a ``File`` param for ``UploadFile`` annotations;
    - a ``Body`` param for non-scalar annotations;
    - a ``Query`` param otherwise.

    Args:
        param_name: Name of the parameter in the signature.
        annotation: Resolved annotation, or ``inspect.Signature.empty``.
        value: Default value, or ``inspect.Signature.empty``.
        is_path_param: Whether ``param_name`` appears in the route path.

    Returns:
        ``ParamDetails`` with the bare type, the ``Depends`` marker (if any)
        and the created ``ModelField`` (if a ``FieldInfo`` was determined).

    Raises:
        AssertionError: on conflicting or unsupported declarations.
    """
    field_info = None
    depends = None
    # `Any` stands in for a missing annotation.
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Extract Annotated info
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        # The first Annotated argument is the real type; the rest is metadata.
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        # Only FastAPI's own Param/Body/Depends markers are considered here
        # (plain FieldInfo metadata is not); the last one wins.
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        if fastapi_specific_annotations:
            fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
                fastapi_specific_annotations[-1]
            )
        else:
            fastapi_annotation = None
        # Set default for Annotated FieldInfo
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            # Defaults must come from `=`, not from the marker inside Annotated.
            assert (
                field_info.default is Undefined or field_info.default is RequiredParam
            ), (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = RequiredParam
        # Get Annotated Depends
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Get Depends from default value
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    # Get FieldInfo from default value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        # NOTE(review): unlike the Annotated case, this FieldInfo is not
        # copied, so the default object in the signature is mutated in place
        # here and below (annotation, in_, alias).
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation
    # Get Depends from type annotation
    if depends is not None and depends.dependency is None:
        # `Depends()` without a callable uses the parameter's type as the
        # dependency.
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation
    # Handle non-param type annotations like Request
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert field_info is None, (
            f"Cannot specify FastAPI annotation for type {type_annotation!r}"
        )
    # Handle default assignments: neither field_info nor depends was found in
    # Annotated or in the default value
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else RequiredParam
        if is_path_param:
            # We might check here that `default_value is RequiredParam`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)
    field = None
    # It's a field_info, not a dependency
    if field_info is not None:
        # Handle field_info.in_
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A Param without a location defaults to the query string.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        if isinstance(field_info, params.Form):
            # Fail early, at route definition time, if form parsing is unavailable.
            ensure_multipart_is_installed()
        # Without an explicit alias, a field whose FieldInfo has a truthy
        # `convert_underscores` gets "a_b" -> "a-b" as its alias; otherwise
        # the alias defaults to the parameter name.
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_model_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (RequiredParam, Undefined),
            field_info=field_info,
        )
        if is_path_param:
            assert is_scalar_field(field=field), (
                "Path params must be of one of the supported types"
            )
        elif isinstance(field_info, params.Query):
            # Query params must be scalars, sequences of scalars, or a single
            # Pydantic model.
            assert (
                is_scalar_field(field)
                or is_scalar_sequence_field(field)
                or (
                    lenient_issubclass(field.type_, BaseModel)
                    # For Pydantic v1
                    and getattr(field, "shape", 1) == 1
                )
            )
    return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
  470. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  471. field_info = field.field_info
  472. field_info_in = getattr(field_info, "in_", None)
  473. if field_info_in == params.ParamTypes.path:
  474. dependant.path_params.append(field)
  475. elif field_info_in == params.ParamTypes.query:
  476. dependant.query_params.append(field)
  477. elif field_info_in == params.ParamTypes.header:
  478. dependant.header_params.append(field)
  479. else:
  480. assert field_info_in == params.ParamTypes.cookie, (
  481. f"non-body parameters must be in path, query, header or cookie: {field.name}"
  482. )
  483. dependant.cookie_params.append(field)
  484. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  485. if inspect.isroutine(call):
  486. return inspect.iscoroutinefunction(call)
  487. if inspect.isclass(call):
  488. return False
  489. dunder_call = getattr(call, "__call__", None) # noqa: B004
  490. return inspect.iscoroutinefunction(dunder_call)
  491. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  492. if inspect.isasyncgenfunction(call):
  493. return True
  494. dunder_call = getattr(call, "__call__", None) # noqa: B004
  495. return inspect.isasyncgenfunction(dunder_call)
  496. def is_gen_callable(call: Callable[..., Any]) -> bool:
  497. if inspect.isgeneratorfunction(call):
  498. return True
  499. dunder_call = getattr(call, "__call__", None) # noqa: B004
  500. return inspect.isgeneratorfunction(dunder_call)
  501. async def solve_generator(
  502. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  503. ) -> Any:
  504. if is_gen_callable(call):
  505. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  506. elif is_async_gen_callable(call):
  507. cm = asynccontextmanager(call)(**sub_values)
  508. return await stack.enter_async_context(cm)
  509. @dataclass
  510. class SolvedDependency:
  511. values: Dict[str, Any]
  512. errors: List[Any]
  513. background_tasks: Optional[StarletteBackgroundTasks]
  514. response: Response
  515. dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str, ...]], Any]] = None,
    async_exit_stack: AsyncExitStack,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Resolve ``dependant``'s sub-dependencies, then extract its own parameters.

    Sub-dependencies are solved recursively, in declaration order. Each one
    whose own resolution produced no errors is then called, or its value is
    taken from ``dependency_cache``. After that, ``dependant``'s path, query,
    header, cookie and body params are read from ``request``/``body``, and
    the framework-provided parameters are filled in: connection, request or
    websocket, response, background tasks and security scopes.

    Args:
        request: The incoming HTTP request or WebSocket connection.
        dependant: The dependency (or endpoint) to solve.
        body: Parsed request body, if any.
        background_tasks: Existing background tasks object to reuse, if any.
        response: Response shared with sub-dependencies. A blank one is
            created when ``None``.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping replaces dependency callables (original -> override).
        dependency_cache: Values already solved, keyed by
            ``Dependant.cache_key``.
        async_exit_stack: Stack that generator dependencies are entered on.
        embed_body_fields: Passed to ``request_body_to_args`` and to the
            recursive calls.

    Returns:
        A ``SolvedDependency`` holding the keyword values for
        ``dependant.call``, the accumulated errors, the background tasks
        object, the response and the updated dependency cache.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        # Fresh placeholder response: drop the default content-length header
        # and clear the status code.
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    # NOTE: an empty dict passed by the caller is replaced by a new one here,
    # so cache updates reach the caller through the returned
    # SolvedDependency.dependency_cache rather than by mutation.
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # These assign the same objects back; `cast` only narrows the types
        # for the type checker.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # When any overrides are registered, rebuild the sub-dependant
            # from the override (or the original call if it has none),
            # keeping its name and scopes. `use_cache` is not forwarded; the
            # cache checks below use the original sub_dependant's use_cache
            # and cache_key.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
            embed_body_fields=embed_body_fields,
        )
        background_tasks = solved_result.background_tasks
        dependency_cache.update(solved_result.dependency_cache)
        if solved_result.errors:
            # A sub-dependency whose own inputs failed is not called.
            errors.extend(solved_result.errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            # Generator dependencies are entered on async_exit_stack, which
            # owns their teardown.
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=solved_result.values
            )
        elif is_coroutine_callable(call):
            solved = await call(**solved_result.values)
        else:
            # Plain sync callables run in a worker thread.
            solved = await run_in_threadpool(call, **solved_result.values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        # The first value solved for a key stays cached; a later uncached
        # call does not overwrite it.
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    # A Request parameter is only filled for HTTP requests, and a WebSocket
    # parameter only for WebSocket connections.
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        # Created lazily the first time a dependant asks for it; the same
        # object is then returned up the tree via SolvedDependency.
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return SolvedDependency(
        values=values,
        errors=errors,
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
    )
  639. def _validate_value_with_model_field(
  640. *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
  641. ) -> Tuple[Any, List[Any]]:
  642. if value is None:
  643. if field.required:
  644. return None, [get_missing_field_error(loc=loc)]
  645. else:
  646. return deepcopy(field.default), []
  647. v_, errors_ = field.validate(value, values, loc=loc)
  648. if isinstance(errors_, ErrorWrapper):
  649. return None, [errors_]
  650. elif isinstance(errors_, list):
  651. new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
  652. return None, new_errors
  653. else:
  654. return v_, []
  655. def _get_multidict_value(
  656. field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
  657. ) -> Any:
  658. alias = alias or field.alias
  659. if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
  660. value = values.getlist(alias)
  661. else:
  662. value = values.get(alias, None)
  663. if (
  664. value is None
  665. or (
  666. isinstance(field.field_info, params.Form)
  667. and isinstance(value, str) # For type checks
  668. and value == ""
  669. )
  670. or (is_sequence_field(field) and len(value) == 0)
  671. ):
  672. if field.required:
  673. return
  674. else:
  675. return deepcopy(field.default)
  676. return value
  677. def request_params_to_args(
  678. fields: Sequence[ModelField],
  679. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  680. ) -> Tuple[Dict[str, Any], List[Any]]:
  681. values: Dict[str, Any] = {}
  682. errors: List[Dict[str, Any]] = []
  683. if not fields:
  684. return values, errors
  685. first_field = fields[0]
  686. fields_to_extract = fields
  687. single_not_embedded_field = False
  688. default_convert_underscores = True
  689. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  690. fields_to_extract = get_cached_model_fields(first_field.type_)
  691. single_not_embedded_field = True
  692. # If headers are in a Pydantic model, the way to disable convert_underscores
  693. # would be with Header(convert_underscores=False) at the Pydantic model level
  694. default_convert_underscores = getattr(
  695. first_field.field_info, "convert_underscores", True
  696. )
  697. params_to_process: Dict[str, Any] = {}
  698. processed_keys = set()
  699. for field in fields_to_extract:
  700. alias = None
  701. if isinstance(received_params, Headers):
  702. # Handle fields extracted from a Pydantic Model for a header, each field
  703. # doesn't have a FieldInfo of type Header with the default convert_underscores=True
  704. convert_underscores = getattr(
  705. field.field_info, "convert_underscores", default_convert_underscores
  706. )
  707. if convert_underscores:
  708. alias = (
  709. field.alias
  710. if field.alias != field.name
  711. else field.name.replace("_", "-")
  712. )
  713. value = _get_multidict_value(field, received_params, alias=alias)
  714. if value is not None:
  715. params_to_process[field.name] = value
  716. processed_keys.add(alias or field.alias)
  717. processed_keys.add(field.name)
  718. for key, value in received_params.items():
  719. if key not in processed_keys:
  720. params_to_process[key] = value
  721. if single_not_embedded_field:
  722. field_info = first_field.field_info
  723. assert isinstance(field_info, params.Param), (
  724. "Params must be subclasses of Param"
  725. )
  726. loc: Tuple[str, ...] = (field_info.in_.value,)
  727. v_, errors_ = _validate_value_with_model_field(
  728. field=first_field, value=params_to_process, values=values, loc=loc
  729. )
  730. return {first_field.name: v_}, errors_
  731. for field in fields:
  732. value = _get_multidict_value(field, received_params)
  733. field_info = field.field_info
  734. assert isinstance(field_info, params.Param), (
  735. "Params must be subclasses of Param"
  736. )
  737. loc = (field_info.in_.value, field.alias)
  738. v_, errors_ = _validate_value_with_model_field(
  739. field=field, value=value, values=values, loc=loc
  740. )
  741. if errors_:
  742. errors.extend(errors_)
  743. else:
  744. values[field.name] = v_
  745. return values, errors
  746. def is_union_of_base_models(field_type: Any) -> bool:
  747. """Check if field type is a Union where all members are BaseModel subclasses."""
  748. from fastapi.types import UnionType
  749. origin = get_origin(field_type)
  750. # Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
  751. if origin is not Union and origin is not UnionType:
  752. return False
  753. union_args = get_args(field_type)
  754. for arg in union_args:
  755. if not lenient_issubclass(arg, BaseModel):
  756. return False
  757. return True
  758. def _should_embed_body_fields(fields: List[ModelField]) -> bool:
  759. if not fields:
  760. return False
  761. # More than one dependency could have the same field, it would show up as multiple
  762. # fields but it's the same one, so count them by name
  763. body_param_names_set = {field.name for field in fields}
  764. # A top level field has to be a single field, not multiple
  765. if len(body_param_names_set) > 1:
  766. return True
  767. first_field = fields[0]
  768. # If it explicitly specifies it is embedded, it has to be embedded
  769. if getattr(first_field.field_info, "embed", None):
  770. return True
  771. # If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
  772. # otherwise it has to be embedded, so that the key value pair can be extracted
  773. if (
  774. isinstance(first_field.field_info, params.Form)
  775. and not lenient_issubclass(first_field.type_, BaseModel)
  776. and not is_union_of_base_models(first_field.type_)
  777. ):
  778. return True
  779. return False
  async def _read_uploads_in_order(items: Sequence[Any]) -> List[Any]:
      """
      Read every ``UploadFile`` in ``items`` concurrently and return the results.

      The returned list has the same length and order as ``items``. Each upload is
      replaced by the data its ``read()`` returns. Any other item is kept as-is, so
      it reaches validation instead of failing with an AttributeError on ``.read``.
      """
      originals = list(items)
      contents: List[Any] = list(originals)

      async def read_at(index: int, upload: UploadFile) -> None:
          # Store by position: concurrent reads may complete in any order.
          contents[index] = await upload.read()

      async with anyio.create_task_group() as tg:
          for index, item in enumerate(originals):
              if isinstance(item, UploadFile):
                  tg.start_soon(read_at, index, item)
      return contents


  async def _extract_form_body(
      body_fields: List[ModelField],
      received_body: FormData,
  ) -> Dict[str, Any]:
      """
      Flatten ``received_body`` into a dict keyed by field alias.

      For each field in ``body_fields``, the value is taken from the form data.
      For ``File`` fields typed as ``bytes`` (or a sequence of ``bytes``), the
      uploaded file(s) are read into memory. Form entries not claimed by any field
      are copied over unchanged. A ``None`` value is omitted, so a matching raw
      form entry is copied instead.

      :param body_fields: fields to extract; may be empty.
      :param received_body: the parsed form data of the request.
      :returns: a new dict, ready for validation.
      """
      values: Dict[str, Any] = {}
      for field in body_fields:
          value = _get_multidict_value(field, received_body)
          # Classify each field by its own field_info. With mixed Form/File
          # parameters, a field must not inherit the kind of the first one.
          is_file_field = isinstance(field.field_info, params.File)
          if is_file_field and is_bytes_field(field) and isinstance(value, UploadFile):
              value = await value.read()
          elif (
              is_file_field
              and is_bytes_sequence_field(field)
              and value_is_sequence(value)
          ):
              assert isinstance(value, sequence_types)  # type: ignore[arg-type]
              contents = await _read_uploads_in_order(value)
              value = serialize_sequence_value(field=field, value=contents)
          if value is not None:
              values[field.alias] = value
      for key, raw_value in received_body.items():
          if key not in values:
              values[key] = raw_value
      return values
  async def request_body_to_args(
      body_fields: List[ModelField],
      received_body: Optional[Union[Dict[str, Any], FormData]],
      embed_body_fields: bool,
  ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
      """
      Validate the received request body against the declared body parameters.

      :param body_fields: the body parameters; must not be empty.
      :param received_body: the decoded body (a dict, form data, other decoded
          content, or ``None``).
      :param embed_body_fields: whether parameters are embedded under their names.
      :returns: ``(values, errors)``.
          - Single, non-embedded parameter: the whole body is validated against it
            (errors located at ``("body",)``), and ``values`` always has its
            result.
          - Otherwise: each parameter is looked up by alias (errors located at
            ``("body", alias)``), and only error-free parameters go into
            ``values``, keyed by name.
      """
      assert body_fields, "request_body_to_args() should be called with fields"
      first_field = body_fields[0]
      is_single_top_level = len(body_fields) == 1 and not embed_body_fields

      # A lone top-level model parameter has its form data extracted against the
      # model's own fields rather than against the parameter itself.
      if is_single_top_level and lenient_issubclass(first_field.type_, BaseModel):
          extraction_fields: List[ModelField] = get_cached_model_fields(
              first_field.type_
          )
      else:
          extraction_fields = body_fields

      payload = received_body
      if isinstance(received_body, FormData):
          payload = await _extract_form_body(extraction_fields, received_body)

      validated: Dict[str, Any] = {}
      if is_single_top_level:
          single_value, single_errors = _validate_value_with_model_field(
              field=first_field, value=payload, values=validated, loc=("body",)
          )
          return {first_field.name: single_value}, single_errors

      collected_errors: List[Dict[str, Any]] = []
      for field in body_fields:
          field_loc: Tuple[str, ...] = ("body", field.alias)
          raw: Optional[Any] = None
          if payload is not None:
              try:
                  raw = payload.get(field.alias)
              except AttributeError:
                  # The body is not a mapping (e.g. a JSON list): report as missing.
                  collected_errors.append(get_missing_field_error(field_loc))
                  continue
          parsed, field_errors = _validate_value_with_model_field(
              field=field, value=raw, values=validated, loc=field_loc
          )
          if field_errors:
              collected_errors.extend(field_errors)
          else:
              validated[field.name] = parsed
      return validated, collected_errors
  def get_body_field(
      *, flat_dependant: Dependant, name: str, embed_body_fields: bool
  ) -> Optional[ModelField]:
      """
      Return one ModelField that describes the whole request body of an operation.

      - No body parameters: returns ``None``.
      - Not embedded: the first body parameter is returned unchanged.
      - Embedded: a ``Body_<name>`` model with one field per body parameter is
        created and wrapped in a field named and aliased ``"body"``.

      The wrapper's ``field_info`` is built from ``params.File`` if any parameter is
      a File, else ``params.Form`` if any is a Form, else ``params.Body``. Callers
      check that ``field_info`` to tell form data from JSON, and use the field to
      generate the request body JSON Schema. It is **not** used to validate or
      parse the body; each body parameter is validated on its own.
      """
      body_params = flat_dependant.body_params
      if not body_params:
          return None
      if not embed_body_fields:
          return body_params[0]

      body_model = create_body_model(fields=body_params, model_name="Body_" + name)
      required = any(param.required for param in body_params)

      field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"}
      if not required:
          field_info_kwargs["default"] = None

      param_infos = [param.field_info for param in body_params]
      if any(isinstance(info, params.File) for info in param_infos):
          field_info_class: Type[params.Body] = params.File
      elif any(isinstance(info, params.Form) for info in param_infos):
          field_info_class = params.Form
      else:
          field_info_class = params.Body
          # Plain body params: propagate their media_type only if they all agree.
          media_types = {
              info.media_type for info in param_infos if isinstance(info, params.Body)
          }
          if len(media_types) == 1:
              field_info_kwargs["media_type"] = next(iter(media_types))

      return create_model_field(
          name="body",
          type_=body_model,
          required=required,
          alias="body",
          field_info=field_info_class(**field_info_kwargs),
      )