您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 

1151 行
48 KiB

  1. import asyncio
  2. import dataclasses
  3. import email.message
  4. import enum
  5. import inspect
  6. import json
  7. from typing import (
  8. Any,
  9. Callable,
  10. Coroutine,
  11. Dict,
  12. List,
  13. Optional,
  14. Sequence,
  15. Set,
  16. Type,
  17. Union,
  18. )
  19. from fastapi import params
  20. from fastapi.datastructures import Default, DefaultPlaceholder
  21. from fastapi.dependencies.models import Dependant
  22. from fastapi.dependencies.utils import (
  23. get_body_field,
  24. get_dependant,
  25. get_parameterless_sub_dependant,
  26. solve_dependencies,
  27. )
  28. from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
  29. from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
  30. from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
  31. from fastapi.types import DecoratedCallable
  32. from fastapi.utils import (
  33. create_cloned_field,
  34. create_response_field,
  35. generate_operation_id_for_path,
  36. get_value_or_default,
  37. )
  38. from pydantic import BaseModel
  39. from pydantic.error_wrappers import ErrorWrapper, ValidationError
  40. from pydantic.fields import ModelField, Undefined
  41. from starlette import routing
  42. from starlette.concurrency import run_in_threadpool
  43. from starlette.exceptions import HTTPException
  44. from starlette.requests import Request
  45. from starlette.responses import JSONResponse, Response
  46. from starlette.routing import BaseRoute
  47. from starlette.routing import Mount as Mount # noqa
  48. from starlette.routing import (
  49. compile_path,
  50. get_name,
  51. request_response,
  52. websocket_session,
  53. )
  54. from starlette.status import WS_1008_POLICY_VIOLATION
  55. from starlette.types import ASGIApp
  56. from starlette.websockets import WebSocket
  57. def _prepare_response_content(
  58. res: Any,
  59. *,
  60. exclude_unset: bool,
  61. exclude_defaults: bool = False,
  62. exclude_none: bool = False,
  63. ) -> Any:
  64. if isinstance(res, BaseModel):
  65. read_with_orm_mode = getattr(res.__config__, "read_with_orm_mode", None)
  66. if read_with_orm_mode:
  67. # Let from_orm extract the data from this model instead of converting
  68. # it now to a dict.
  69. # Otherwise there's no way to extract lazy data that requires attribute
  70. # access instead of dict iteration, e.g. lazy relationships.
  71. return res
  72. return res.dict(
  73. by_alias=True,
  74. exclude_unset=exclude_unset,
  75. exclude_defaults=exclude_defaults,
  76. exclude_none=exclude_none,
  77. )
  78. elif isinstance(res, list):
  79. return [
  80. _prepare_response_content(
  81. item,
  82. exclude_unset=exclude_unset,
  83. exclude_defaults=exclude_defaults,
  84. exclude_none=exclude_none,
  85. )
  86. for item in res
  87. ]
  88. elif isinstance(res, dict):
  89. return {
  90. k: _prepare_response_content(
  91. v,
  92. exclude_unset=exclude_unset,
  93. exclude_defaults=exclude_defaults,
  94. exclude_none=exclude_none,
  95. )
  96. for k, v in res.items()
  97. }
  98. elif dataclasses.is_dataclass(res):
  99. return dataclasses.asdict(res)
  100. return res
  101. async def serialize_response(
  102. *,
  103. field: Optional[ModelField] = None,
  104. response_content: Any,
  105. include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  106. exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  107. by_alias: bool = True,
  108. exclude_unset: bool = False,
  109. exclude_defaults: bool = False,
  110. exclude_none: bool = False,
  111. is_coroutine: bool = True,
  112. ) -> Any:
  113. if field:
  114. errors = []
  115. response_content = _prepare_response_content(
  116. response_content,
  117. exclude_unset=exclude_unset,
  118. exclude_defaults=exclude_defaults,
  119. exclude_none=exclude_none,
  120. )
  121. if is_coroutine:
  122. value, errors_ = field.validate(response_content, {}, loc=("response",))
  123. else:
  124. value, errors_ = await run_in_threadpool(
  125. field.validate, response_content, {}, loc=("response",)
  126. )
  127. if isinstance(errors_, ErrorWrapper):
  128. errors.append(errors_)
  129. elif isinstance(errors_, list):
  130. errors.extend(errors_)
  131. if errors:
  132. raise ValidationError(errors, field.type_)
  133. return jsonable_encoder(
  134. value,
  135. include=include,
  136. exclude=exclude,
  137. by_alias=by_alias,
  138. exclude_unset=exclude_unset,
  139. exclude_defaults=exclude_defaults,
  140. exclude_none=exclude_none,
  141. )
  142. else:
  143. return jsonable_encoder(response_content)
  144. async def run_endpoint_function(
  145. *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
  146. ) -> Any:
  147. # Only called by get_request_handler. Has been split into its own function to
  148. # facilitate profiling endpoints, since inner functions are harder to profile.
  149. assert dependant.call is not None, "dependant.call must be a function"
  150. if is_coroutine:
  151. return await dependant.call(**values)
  152. else:
  153. return await run_in_threadpool(dependant.call, **values)
  154. def get_request_handler(
  155. dependant: Dependant,
  156. body_field: Optional[ModelField] = None,
  157. status_code: Optional[int] = None,
  158. response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
  159. response_field: Optional[ModelField] = None,
  160. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  161. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  162. response_model_by_alias: bool = True,
  163. response_model_exclude_unset: bool = False,
  164. response_model_exclude_defaults: bool = False,
  165. response_model_exclude_none: bool = False,
  166. dependency_overrides_provider: Optional[Any] = None,
  167. ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
  168. assert dependant.call is not None, "dependant.call must be a function"
  169. is_coroutine = asyncio.iscoroutinefunction(dependant.call)
  170. is_body_form = body_field and isinstance(body_field.field_info, params.Form)
  171. if isinstance(response_class, DefaultPlaceholder):
  172. actual_response_class: Type[Response] = response_class.value
  173. else:
  174. actual_response_class = response_class
  175. async def app(request: Request) -> Response:
  176. try:
  177. body: Any = None
  178. if body_field:
  179. if is_body_form:
  180. body = await request.form()
  181. else:
  182. body_bytes = await request.body()
  183. if body_bytes:
  184. json_body: Any = Undefined
  185. content_type_value = request.headers.get("content-type")
  186. if not content_type_value:
  187. json_body = await request.json()
  188. else:
  189. message = email.message.Message()
  190. message["content-type"] = content_type_value
  191. if message.get_content_maintype() == "application":
  192. subtype = message.get_content_subtype()
  193. if subtype == "json" or subtype.endswith("+json"):
  194. json_body = await request.json()
  195. if json_body != Undefined:
  196. body = json_body
  197. else:
  198. body = body_bytes
  199. except json.JSONDecodeError as e:
  200. raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
  201. except Exception as e:
  202. raise HTTPException(
  203. status_code=400, detail="There was an error parsing the body"
  204. ) from e
  205. solved_result = await solve_dependencies(
  206. request=request,
  207. dependant=dependant,
  208. body=body,
  209. dependency_overrides_provider=dependency_overrides_provider,
  210. )
  211. values, errors, background_tasks, sub_response, _ = solved_result
  212. if errors:
  213. raise RequestValidationError(errors, body=body)
  214. else:
  215. raw_response = await run_endpoint_function(
  216. dependant=dependant, values=values, is_coroutine=is_coroutine
  217. )
  218. if isinstance(raw_response, Response):
  219. if raw_response.background is None:
  220. raw_response.background = background_tasks
  221. return raw_response
  222. response_data = await serialize_response(
  223. field=response_field,
  224. response_content=raw_response,
  225. include=response_model_include,
  226. exclude=response_model_exclude,
  227. by_alias=response_model_by_alias,
  228. exclude_unset=response_model_exclude_unset,
  229. exclude_defaults=response_model_exclude_defaults,
  230. exclude_none=response_model_exclude_none,
  231. is_coroutine=is_coroutine,
  232. )
  233. response_args: Dict[str, Any] = {"background": background_tasks}
  234. # If status_code was set, use it, otherwise use the default from the
  235. # response class, in the case of redirect it's 307
  236. if status_code is not None:
  237. response_args["status_code"] = status_code
  238. response = actual_response_class(response_data, **response_args)
  239. response.headers.raw.extend(sub_response.headers.raw)
  240. if sub_response.status_code:
  241. response.status_code = sub_response.status_code
  242. return response
  243. return app
  244. def get_websocket_app(
  245. dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
  246. ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
  247. async def app(websocket: WebSocket) -> None:
  248. solved_result = await solve_dependencies(
  249. request=websocket,
  250. dependant=dependant,
  251. dependency_overrides_provider=dependency_overrides_provider,
  252. )
  253. values, errors, _, _2, _3 = solved_result
  254. if errors:
  255. await websocket.close(code=WS_1008_POLICY_VIOLATION)
  256. raise WebSocketRequestValidationError(errors)
  257. assert dependant.call is not None, "dependant.call must be a function"
  258. await dependant.call(**values)
  259. return app
  260. class APIWebSocketRoute(routing.WebSocketRoute):
  261. def __init__(
  262. self,
  263. path: str,
  264. endpoint: Callable[..., Any],
  265. *,
  266. name: Optional[str] = None,
  267. dependency_overrides_provider: Optional[Any] = None,
  268. ) -> None:
  269. self.path = path
  270. self.endpoint = endpoint
  271. self.name = get_name(endpoint) if name is None else name
  272. self.dependant = get_dependant(path=path, call=self.endpoint)
  273. self.app = websocket_session(
  274. get_websocket_app(
  275. dependant=self.dependant,
  276. dependency_overrides_provider=dependency_overrides_provider,
  277. )
  278. )
  279. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  280. class APIRoute(routing.Route):
  281. def __init__(
  282. self,
  283. path: str,
  284. endpoint: Callable[..., Any],
  285. *,
  286. response_model: Optional[Type[Any]] = None,
  287. status_code: Optional[int] = None,
  288. tags: Optional[List[str]] = None,
  289. dependencies: Optional[Sequence[params.Depends]] = None,
  290. summary: Optional[str] = None,
  291. description: Optional[str] = None,
  292. response_description: str = "Successful Response",
  293. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  294. deprecated: Optional[bool] = None,
  295. name: Optional[str] = None,
  296. methods: Optional[Union[Set[str], List[str]]] = None,
  297. operation_id: Optional[str] = None,
  298. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  299. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  300. response_model_by_alias: bool = True,
  301. response_model_exclude_unset: bool = False,
  302. response_model_exclude_defaults: bool = False,
  303. response_model_exclude_none: bool = False,
  304. include_in_schema: bool = True,
  305. response_class: Union[Type[Response], DefaultPlaceholder] = Default(
  306. JSONResponse
  307. ),
  308. dependency_overrides_provider: Optional[Any] = None,
  309. callbacks: Optional[List[BaseRoute]] = None,
  310. openapi_extra: Optional[Dict[str, Any]] = None,
  311. ) -> None:
  312. # normalise enums e.g. http.HTTPStatus
  313. if isinstance(status_code, enum.IntEnum):
  314. status_code = int(status_code)
  315. self.path = path
  316. self.endpoint = endpoint
  317. self.name = get_name(endpoint) if name is None else name
  318. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  319. if methods is None:
  320. methods = ["GET"]
  321. self.methods: Set[str] = set([method.upper() for method in methods])
  322. self.unique_id = generate_operation_id_for_path(
  323. name=self.name, path=self.path_format, method=list(methods)[0]
  324. )
  325. self.response_model = response_model
  326. if self.response_model:
  327. assert (
  328. status_code not in STATUS_CODES_WITH_NO_BODY
  329. ), f"Status code {status_code} must not have a response body"
  330. response_name = "Response_" + self.unique_id
  331. self.response_field = create_response_field(
  332. name=response_name, type_=self.response_model
  333. )
  334. # Create a clone of the field, so that a Pydantic submodel is not returned
  335. # as is just because it's an instance of a subclass of a more limited class
  336. # e.g. UserInDB (containing hashed_password) could be a subclass of User
  337. # that doesn't have the hashed_password. But because it's a subclass, it
  338. # would pass the validation and be returned as is.
  339. # By being a new field, no inheritance will be passed as is. A new model
  340. # will be always created.
  341. self.secure_cloned_response_field: Optional[
  342. ModelField
  343. ] = create_cloned_field(self.response_field)
  344. else:
  345. self.response_field = None # type: ignore
  346. self.secure_cloned_response_field = None
  347. self.status_code = status_code
  348. self.tags = tags or []
  349. if dependencies:
  350. self.dependencies = list(dependencies)
  351. else:
  352. self.dependencies = []
  353. self.summary = summary
  354. self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
  355. # if a "form feed" character (page break) is found in the description text,
  356. # truncate description text to the content preceding the first "form feed"
  357. self.description = self.description.split("\f")[0]
  358. self.response_description = response_description
  359. self.responses = responses or {}
  360. response_fields = {}
  361. for additional_status_code, response in self.responses.items():
  362. assert isinstance(response, dict), "An additional response must be a dict"
  363. model = response.get("model")
  364. if model:
  365. assert (
  366. additional_status_code not in STATUS_CODES_WITH_NO_BODY
  367. ), f"Status code {additional_status_code} must not have a response body"
  368. response_name = f"Response_{additional_status_code}_{self.unique_id}"
  369. response_field = create_response_field(name=response_name, type_=model)
  370. response_fields[additional_status_code] = response_field
  371. if response_fields:
  372. self.response_fields: Dict[Union[int, str], ModelField] = response_fields
  373. else:
  374. self.response_fields = {}
  375. self.deprecated = deprecated
  376. self.operation_id = operation_id
  377. self.response_model_include = response_model_include
  378. self.response_model_exclude = response_model_exclude
  379. self.response_model_by_alias = response_model_by_alias
  380. self.response_model_exclude_unset = response_model_exclude_unset
  381. self.response_model_exclude_defaults = response_model_exclude_defaults
  382. self.response_model_exclude_none = response_model_exclude_none
  383. self.include_in_schema = include_in_schema
  384. self.response_class = response_class
  385. assert callable(endpoint), "An endpoint must be a callable"
  386. self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
  387. for depends in self.dependencies[::-1]:
  388. self.dependant.dependencies.insert(
  389. 0,
  390. get_parameterless_sub_dependant(depends=depends, path=self.path_format),
  391. )
  392. self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
  393. self.dependency_overrides_provider = dependency_overrides_provider
  394. self.callbacks = callbacks
  395. self.app = request_response(self.get_route_handler())
  396. self.openapi_extra = openapi_extra
  397. def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
  398. return get_request_handler(
  399. dependant=self.dependant,
  400. body_field=self.body_field,
  401. status_code=self.status_code,
  402. response_class=self.response_class,
  403. response_field=self.secure_cloned_response_field,
  404. response_model_include=self.response_model_include,
  405. response_model_exclude=self.response_model_exclude,
  406. response_model_by_alias=self.response_model_by_alias,
  407. response_model_exclude_unset=self.response_model_exclude_unset,
  408. response_model_exclude_defaults=self.response_model_exclude_defaults,
  409. response_model_exclude_none=self.response_model_exclude_none,
  410. dependency_overrides_provider=self.dependency_overrides_provider,
  411. )
  412. class APIRouter(routing.Router):
  413. def __init__(
  414. self,
  415. *,
  416. prefix: str = "",
  417. tags: Optional[List[str]] = None,
  418. dependencies: Optional[Sequence[params.Depends]] = None,
  419. default_response_class: Type[Response] = Default(JSONResponse),
  420. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  421. callbacks: Optional[List[BaseRoute]] = None,
  422. routes: Optional[List[routing.BaseRoute]] = None,
  423. redirect_slashes: bool = True,
  424. default: Optional[ASGIApp] = None,
  425. dependency_overrides_provider: Optional[Any] = None,
  426. route_class: Type[APIRoute] = APIRoute,
  427. on_startup: Optional[Sequence[Callable[[], Any]]] = None,
  428. on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
  429. deprecated: Optional[bool] = None,
  430. include_in_schema: bool = True,
  431. ) -> None:
  432. super().__init__(
  433. routes=routes, # type: ignore # in Starlette
  434. redirect_slashes=redirect_slashes,
  435. default=default, # type: ignore # in Starlette
  436. on_startup=on_startup, # type: ignore # in Starlette
  437. on_shutdown=on_shutdown, # type: ignore # in Starlette
  438. )
  439. if prefix:
  440. assert prefix.startswith("/"), "A path prefix must start with '/'"
  441. assert not prefix.endswith(
  442. "/"
  443. ), "A path prefix must not end with '/', as the routes will start with '/'"
  444. self.prefix = prefix
  445. self.tags: List[str] = tags or []
  446. self.dependencies = list(dependencies or []) or []
  447. self.deprecated = deprecated
  448. self.include_in_schema = include_in_schema
  449. self.responses = responses or {}
  450. self.callbacks = callbacks or []
  451. self.dependency_overrides_provider = dependency_overrides_provider
  452. self.route_class = route_class
  453. self.default_response_class = default_response_class
  454. def add_api_route(
  455. self,
  456. path: str,
  457. endpoint: Callable[..., Any],
  458. *,
  459. response_model: Optional[Type[Any]] = None,
  460. status_code: Optional[int] = None,
  461. tags: Optional[List[str]] = None,
  462. dependencies: Optional[Sequence[params.Depends]] = None,
  463. summary: Optional[str] = None,
  464. description: Optional[str] = None,
  465. response_description: str = "Successful Response",
  466. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  467. deprecated: Optional[bool] = None,
  468. methods: Optional[Union[Set[str], List[str]]] = None,
  469. operation_id: Optional[str] = None,
  470. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  471. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  472. response_model_by_alias: bool = True,
  473. response_model_exclude_unset: bool = False,
  474. response_model_exclude_defaults: bool = False,
  475. response_model_exclude_none: bool = False,
  476. include_in_schema: bool = True,
  477. response_class: Union[Type[Response], DefaultPlaceholder] = Default(
  478. JSONResponse
  479. ),
  480. name: Optional[str] = None,
  481. route_class_override: Optional[Type[APIRoute]] = None,
  482. callbacks: Optional[List[BaseRoute]] = None,
  483. openapi_extra: Optional[Dict[str, Any]] = None,
  484. ) -> None:
  485. route_class = route_class_override or self.route_class
  486. responses = responses or {}
  487. combined_responses = {**self.responses, **responses}
  488. current_response_class = get_value_or_default(
  489. response_class, self.default_response_class
  490. )
  491. current_tags = self.tags.copy()
  492. if tags:
  493. current_tags.extend(tags)
  494. current_dependencies = self.dependencies.copy()
  495. if dependencies:
  496. current_dependencies.extend(dependencies)
  497. current_callbacks = self.callbacks.copy()
  498. if callbacks:
  499. current_callbacks.extend(callbacks)
  500. route = route_class(
  501. self.prefix + path,
  502. endpoint=endpoint,
  503. response_model=response_model,
  504. status_code=status_code,
  505. tags=current_tags,
  506. dependencies=current_dependencies,
  507. summary=summary,
  508. description=description,
  509. response_description=response_description,
  510. responses=combined_responses,
  511. deprecated=deprecated or self.deprecated,
  512. methods=methods,
  513. operation_id=operation_id,
  514. response_model_include=response_model_include,
  515. response_model_exclude=response_model_exclude,
  516. response_model_by_alias=response_model_by_alias,
  517. response_model_exclude_unset=response_model_exclude_unset,
  518. response_model_exclude_defaults=response_model_exclude_defaults,
  519. response_model_exclude_none=response_model_exclude_none,
  520. include_in_schema=include_in_schema and self.include_in_schema,
  521. response_class=current_response_class,
  522. name=name,
  523. dependency_overrides_provider=self.dependency_overrides_provider,
  524. callbacks=current_callbacks,
  525. openapi_extra=openapi_extra,
  526. )
  527. self.routes.append(route)
  528. def api_route(
  529. self,
  530. path: str,
  531. *,
  532. response_model: Optional[Type[Any]] = None,
  533. status_code: Optional[int] = None,
  534. tags: Optional[List[str]] = None,
  535. dependencies: Optional[Sequence[params.Depends]] = None,
  536. summary: Optional[str] = None,
  537. description: Optional[str] = None,
  538. response_description: str = "Successful Response",
  539. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  540. deprecated: Optional[bool] = None,
  541. methods: Optional[List[str]] = None,
  542. operation_id: Optional[str] = None,
  543. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  544. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  545. response_model_by_alias: bool = True,
  546. response_model_exclude_unset: bool = False,
  547. response_model_exclude_defaults: bool = False,
  548. response_model_exclude_none: bool = False,
  549. include_in_schema: bool = True,
  550. response_class: Type[Response] = Default(JSONResponse),
  551. name: Optional[str] = None,
  552. callbacks: Optional[List[BaseRoute]] = None,
  553. openapi_extra: Optional[Dict[str, Any]] = None,
  554. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  555. def decorator(func: DecoratedCallable) -> DecoratedCallable:
  556. self.add_api_route(
  557. path,
  558. func,
  559. response_model=response_model,
  560. status_code=status_code,
  561. tags=tags,
  562. dependencies=dependencies,
  563. summary=summary,
  564. description=description,
  565. response_description=response_description,
  566. responses=responses,
  567. deprecated=deprecated,
  568. methods=methods,
  569. operation_id=operation_id,
  570. response_model_include=response_model_include,
  571. response_model_exclude=response_model_exclude,
  572. response_model_by_alias=response_model_by_alias,
  573. response_model_exclude_unset=response_model_exclude_unset,
  574. response_model_exclude_defaults=response_model_exclude_defaults,
  575. response_model_exclude_none=response_model_exclude_none,
  576. include_in_schema=include_in_schema,
  577. response_class=response_class,
  578. name=name,
  579. callbacks=callbacks,
  580. openapi_extra=openapi_extra,
  581. )
  582. return func
  583. return decorator
  584. def add_api_websocket_route(
  585. self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
  586. ) -> None:
  587. route = APIWebSocketRoute(
  588. path,
  589. endpoint=endpoint,
  590. name=name,
  591. dependency_overrides_provider=self.dependency_overrides_provider,
  592. )
  593. self.routes.append(route)
  594. def websocket(
  595. self, path: str, name: Optional[str] = None
  596. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  597. def decorator(func: DecoratedCallable) -> DecoratedCallable:
  598. self.add_api_websocket_route(path, func, name=name)
  599. return func
  600. return decorator
  601. def include_router(
  602. self,
  603. router: "APIRouter",
  604. *,
  605. prefix: str = "",
  606. tags: Optional[List[str]] = None,
  607. dependencies: Optional[Sequence[params.Depends]] = None,
  608. default_response_class: Type[Response] = Default(JSONResponse),
  609. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  610. callbacks: Optional[List[BaseRoute]] = None,
  611. deprecated: Optional[bool] = None,
  612. include_in_schema: bool = True,
  613. ) -> None:
  614. if prefix:
  615. assert prefix.startswith("/"), "A path prefix must start with '/'"
  616. assert not prefix.endswith(
  617. "/"
  618. ), "A path prefix must not end with '/', as the routes will start with '/'"
  619. else:
  620. for r in router.routes:
  621. path = getattr(r, "path")
  622. name = getattr(r, "name", "unknown")
  623. if path is not None and not path:
  624. raise Exception(
  625. f"Prefix and path cannot be both empty (path operation: {name})"
  626. )
  627. if responses is None:
  628. responses = {}
  629. for route in router.routes:
  630. if isinstance(route, APIRoute):
  631. combined_responses = {**responses, **route.responses}
  632. use_response_class = get_value_or_default(
  633. route.response_class,
  634. router.default_response_class,
  635. default_response_class,
  636. self.default_response_class,
  637. )
  638. current_tags = []
  639. if tags:
  640. current_tags.extend(tags)
  641. if route.tags:
  642. current_tags.extend(route.tags)
  643. current_dependencies: List[params.Depends] = []
  644. if dependencies:
  645. current_dependencies.extend(dependencies)
  646. if route.dependencies:
  647. current_dependencies.extend(route.dependencies)
  648. current_callbacks = []
  649. if callbacks:
  650. current_callbacks.extend(callbacks)
  651. if route.callbacks:
  652. current_callbacks.extend(route.callbacks)
  653. self.add_api_route(
  654. prefix + route.path,
  655. route.endpoint,
  656. response_model=route.response_model,
  657. status_code=route.status_code,
  658. tags=current_tags,
  659. dependencies=current_dependencies,
  660. summary=route.summary,
  661. description=route.description,
  662. response_description=route.response_description,
  663. responses=combined_responses,
  664. deprecated=route.deprecated or deprecated or self.deprecated,
  665. methods=route.methods,
  666. operation_id=route.operation_id,
  667. response_model_include=route.response_model_include,
  668. response_model_exclude=route.response_model_exclude,
  669. response_model_by_alias=route.response_model_by_alias,
  670. response_model_exclude_unset=route.response_model_exclude_unset,
  671. response_model_exclude_defaults=route.response_model_exclude_defaults,
  672. response_model_exclude_none=route.response_model_exclude_none,
  673. include_in_schema=route.include_in_schema
  674. and self.include_in_schema
  675. and include_in_schema,
  676. response_class=use_response_class,
  677. name=route.name,
  678. route_class_override=type(route),
  679. callbacks=current_callbacks,
  680. openapi_extra=route.openapi_extra,
  681. )
  682. elif isinstance(route, routing.Route):
  683. methods = list(route.methods or []) # type: ignore # in Starlette
  684. self.add_route(
  685. prefix + route.path,
  686. route.endpoint,
  687. methods=methods,
  688. include_in_schema=route.include_in_schema,
  689. name=route.name,
  690. )
  691. elif isinstance(route, APIWebSocketRoute):
  692. self.add_api_websocket_route(
  693. prefix + route.path, route.endpoint, name=route.name
  694. )
  695. elif isinstance(route, routing.WebSocketRoute):
  696. self.add_websocket_route(
  697. prefix + route.path, route.endpoint, name=route.name
  698. )
  699. for handler in router.on_startup:
  700. self.add_event_handler("startup", handler)
  701. for handler in router.on_shutdown:
  702. self.add_event_handler("shutdown", handler)
  703. def get(
  704. self,
  705. path: str,
  706. *,
  707. response_model: Optional[Type[Any]] = None,
  708. status_code: Optional[int] = None,
  709. tags: Optional[List[str]] = None,
  710. dependencies: Optional[Sequence[params.Depends]] = None,
  711. summary: Optional[str] = None,
  712. description: Optional[str] = None,
  713. response_description: str = "Successful Response",
  714. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  715. deprecated: Optional[bool] = None,
  716. operation_id: Optional[str] = None,
  717. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  718. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  719. response_model_by_alias: bool = True,
  720. response_model_exclude_unset: bool = False,
  721. response_model_exclude_defaults: bool = False,
  722. response_model_exclude_none: bool = False,
  723. include_in_schema: bool = True,
  724. response_class: Type[Response] = Default(JSONResponse),
  725. name: Optional[str] = None,
  726. callbacks: Optional[List[BaseRoute]] = None,
  727. openapi_extra: Optional[Dict[str, Any]] = None,
  728. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  729. return self.api_route(
  730. path=path,
  731. response_model=response_model,
  732. status_code=status_code,
  733. tags=tags,
  734. dependencies=dependencies,
  735. summary=summary,
  736. description=description,
  737. response_description=response_description,
  738. responses=responses,
  739. deprecated=deprecated,
  740. methods=["GET"],
  741. operation_id=operation_id,
  742. response_model_include=response_model_include,
  743. response_model_exclude=response_model_exclude,
  744. response_model_by_alias=response_model_by_alias,
  745. response_model_exclude_unset=response_model_exclude_unset,
  746. response_model_exclude_defaults=response_model_exclude_defaults,
  747. response_model_exclude_none=response_model_exclude_none,
  748. include_in_schema=include_in_schema,
  749. response_class=response_class,
  750. name=name,
  751. callbacks=callbacks,
  752. openapi_extra=openapi_extra,
  753. )
  def put(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``PUT`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["PUT"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["PUT"], **route_options)
  def post(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``POST`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["POST"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["POST"], **route_options)
  def delete(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``DELETE`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["DELETE"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["DELETE"], **route_options)
  def options(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``OPTIONS`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["OPTIONS"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["OPTIONS"], **route_options)
  def head(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``HEAD`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["HEAD"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["HEAD"], **route_options)
  def patch(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``PATCH`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["PATCH"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["PATCH"], **route_options)
  def trace(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Register a path operation that answers HTTP ``TRACE`` requests.

      Convenience shortcut for ``self.api_route(...)`` with ``methods``
      pinned to ``["TRACE"]``. ``path`` and every keyword-only option are
      forwarded to ``api_route`` unchanged, so their meaning is defined there.

      Returns:
          Whatever ``self.api_route`` returns, annotated as
          ``Callable[[DecoratedCallable], DecoratedCallable]`` (a decorator).
      """
      # Gather the pass-through options once; only the HTTP verb in the
      # final call is specific to this shortcut.
      route_options: Dict[str, Any] = dict(
          response_model=response_model,
          response_class=response_class,
          status_code=status_code,
          response_description=response_description,
          responses=responses,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          operation_id=operation_id,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          name=name,
          dependencies=dependencies,
      )
      return self.api_route(path=path, methods=["TRACE"], **route_options)