Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.
 
 
 
 

1151 linhas
48 KiB

  1. import asyncio
  2. import dataclasses
  3. import email.message
  4. import enum
  5. import inspect
  6. import json
  7. from typing import (
  8. Any,
  9. Callable,
  10. Coroutine,
  11. Dict,
  12. List,
  13. Optional,
  14. Sequence,
  15. Set,
  16. Type,
  17. Union,
  18. )
  19. from fastapi import params
  20. from fastapi.datastructures import Default, DefaultPlaceholder
  21. from fastapi.dependencies.models import Dependant
  22. from fastapi.dependencies.utils import (
  23. get_body_field,
  24. get_dependant,
  25. get_parameterless_sub_dependant,
  26. solve_dependencies,
  27. )
  28. from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
  29. from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
  30. from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
  31. from fastapi.types import DecoratedCallable
  32. from fastapi.utils import (
  33. create_cloned_field,
  34. create_response_field,
  35. generate_operation_id_for_path,
  36. get_value_or_default,
  37. )
  38. from pydantic import BaseModel
  39. from pydantic.error_wrappers import ErrorWrapper, ValidationError
  40. from pydantic.fields import ModelField, Undefined
  41. from starlette import routing
  42. from starlette.concurrency import run_in_threadpool
  43. from starlette.exceptions import HTTPException
  44. from starlette.requests import Request
  45. from starlette.responses import JSONResponse, Response
  46. from starlette.routing import BaseRoute
  47. from starlette.routing import Mount as Mount # noqa
  48. from starlette.routing import (
  49. compile_path,
  50. get_name,
  51. request_response,
  52. websocket_session,
  53. )
  54. from starlette.status import WS_1008_POLICY_VIOLATION
  55. from starlette.types import ASGIApp
  56. from starlette.websockets import WebSocket
  57. def _prepare_response_content(
  58. res: Any,
  59. *,
  60. exclude_unset: bool,
  61. exclude_defaults: bool = False,
  62. exclude_none: bool = False,
  63. ) -> Any:
  64. if isinstance(res, BaseModel):
  65. read_with_orm_mode = getattr(res.__config__, "read_with_orm_mode", None)
  66. if read_with_orm_mode:
  67. # Let from_orm extract the data from this model instead of converting
  68. # it now to a dict.
  69. # Otherwise there's no way to extract lazy data that requires attribute
  70. # access instead of dict iteration, e.g. lazy relationships.
  71. return res
  72. return res.dict(
  73. by_alias=True,
  74. exclude_unset=exclude_unset,
  75. exclude_defaults=exclude_defaults,
  76. exclude_none=exclude_none,
  77. )
  78. elif isinstance(res, list):
  79. return [
  80. _prepare_response_content(
  81. item,
  82. exclude_unset=exclude_unset,
  83. exclude_defaults=exclude_defaults,
  84. exclude_none=exclude_none,
  85. )
  86. for item in res
  87. ]
  88. elif isinstance(res, dict):
  89. return {
  90. k: _prepare_response_content(
  91. v,
  92. exclude_unset=exclude_unset,
  93. exclude_defaults=exclude_defaults,
  94. exclude_none=exclude_none,
  95. )
  96. for k, v in res.items()
  97. }
  98. elif dataclasses.is_dataclass(res):
  99. return dataclasses.asdict(res)
  100. return res
  101. async def serialize_response(
  102. *,
  103. field: Optional[ModelField] = None,
  104. response_content: Any,
  105. include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  106. exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  107. by_alias: bool = True,
  108. exclude_unset: bool = False,
  109. exclude_defaults: bool = False,
  110. exclude_none: bool = False,
  111. is_coroutine: bool = True,
  112. ) -> Any:
  113. if field:
  114. errors = []
  115. response_content = _prepare_response_content(
  116. response_content,
  117. exclude_unset=exclude_unset,
  118. exclude_defaults=exclude_defaults,
  119. exclude_none=exclude_none,
  120. )
  121. if is_coroutine:
  122. value, errors_ = field.validate(response_content, {}, loc=("response",))
  123. else:
  124. value, errors_ = await run_in_threadpool(
  125. field.validate, response_content, {}, loc=("response",)
  126. )
  127. if isinstance(errors_, ErrorWrapper):
  128. errors.append(errors_)
  129. elif isinstance(errors_, list):
  130. errors.extend(errors_)
  131. if errors:
  132. raise ValidationError(errors, field.type_)
  133. return jsonable_encoder(
  134. value,
  135. include=include,
  136. exclude=exclude,
  137. by_alias=by_alias,
  138. exclude_unset=exclude_unset,
  139. exclude_defaults=exclude_defaults,
  140. exclude_none=exclude_none,
  141. )
  142. else:
  143. return jsonable_encoder(response_content)
  144. async def run_endpoint_function(
  145. *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
  146. ) -> Any:
  147. # Only called by get_request_handler. Has been split into its own function to
  148. # facilitate profiling endpoints, since inner functions are harder to profile.
  149. assert dependant.call is not None, "dependant.call must be a function"
  150. if is_coroutine:
  151. return await dependant.call(**values)
  152. else:
  153. return await run_in_threadpool(dependant.call, **values)
  154. def get_request_handler(
  155. dependant: Dependant,
  156. body_field: Optional[ModelField] = None,
  157. status_code: Optional[int] = None,
  158. response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
  159. response_field: Optional[ModelField] = None,
  160. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  161. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  162. response_model_by_alias: bool = True,
  163. response_model_exclude_unset: bool = False,
  164. response_model_exclude_defaults: bool = False,
  165. response_model_exclude_none: bool = False,
  166. dependency_overrides_provider: Optional[Any] = None,
  167. ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
  168. assert dependant.call is not None, "dependant.call must be a function"
  169. is_coroutine = asyncio.iscoroutinefunction(dependant.call)
  170. is_body_form = body_field and isinstance(body_field.field_info, params.Form)
  171. if isinstance(response_class, DefaultPlaceholder):
  172. actual_response_class: Type[Response] = response_class.value
  173. else:
  174. actual_response_class = response_class
  175. async def app(request: Request) -> Response:
  176. try:
  177. body: Any = None
  178. if body_field:
  179. if is_body_form:
  180. body = await request.form()
  181. else:
  182. body_bytes = await request.body()
  183. if body_bytes:
  184. json_body: Any = Undefined
  185. content_type_value = request.headers.get("content-type")
  186. if not content_type_value:
  187. json_body = await request.json()
  188. else:
  189. message = email.message.Message()
  190. message["content-type"] = content_type_value
  191. if message.get_content_maintype() == "application":
  192. subtype = message.get_content_subtype()
  193. if subtype == "json" or subtype.endswith("+json"):
  194. json_body = await request.json()
  195. if json_body != Undefined:
  196. body = json_body
  197. else:
  198. body = body_bytes
  199. except json.JSONDecodeError as e:
  200. raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
  201. except Exception as e:
  202. raise HTTPException(
  203. status_code=400, detail="There was an error parsing the body"
  204. ) from e
  205. solved_result = await solve_dependencies(
  206. request=request,
  207. dependant=dependant,
  208. body=body,
  209. dependency_overrides_provider=dependency_overrides_provider,
  210. )
  211. values, errors, background_tasks, sub_response, _ = solved_result
  212. if errors:
  213. raise RequestValidationError(errors, body=body)
  214. else:
  215. raw_response = await run_endpoint_function(
  216. dependant=dependant, values=values, is_coroutine=is_coroutine
  217. )
  218. if isinstance(raw_response, Response):
  219. if raw_response.background is None:
  220. raw_response.background = background_tasks
  221. return raw_response
  222. response_data = await serialize_response(
  223. field=response_field,
  224. response_content=raw_response,
  225. include=response_model_include,
  226. exclude=response_model_exclude,
  227. by_alias=response_model_by_alias,
  228. exclude_unset=response_model_exclude_unset,
  229. exclude_defaults=response_model_exclude_defaults,
  230. exclude_none=response_model_exclude_none,
  231. is_coroutine=is_coroutine,
  232. )
  233. response_args: Dict[str, Any] = {"background": background_tasks}
  234. # If status_code was set, use it, otherwise use the default from the
  235. # response class, in the case of redirect it's 307
  236. if status_code is not None:
  237. response_args["status_code"] = status_code
  238. response = actual_response_class(response_data, **response_args)
  239. response.headers.raw.extend(sub_response.headers.raw)
  240. if sub_response.status_code:
  241. response.status_code = sub_response.status_code
  242. return response
  243. return app
  244. def get_websocket_app(
  245. dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
  246. ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
  247. async def app(websocket: WebSocket) -> None:
  248. solved_result = await solve_dependencies(
  249. request=websocket,
  250. dependant=dependant,
  251. dependency_overrides_provider=dependency_overrides_provider,
  252. )
  253. values, errors, _, _2, _3 = solved_result
  254. if errors:
  255. await websocket.close(code=WS_1008_POLICY_VIOLATION)
  256. raise WebSocketRequestValidationError(errors)
  257. assert dependant.call is not None, "dependant.call must be a function"
  258. await dependant.call(**values)
  259. return app
  260. class APIWebSocketRoute(routing.WebSocketRoute):
  261. def __init__(
  262. self,
  263. path: str,
  264. endpoint: Callable[..., Any],
  265. *,
  266. name: Optional[str] = None,
  267. dependency_overrides_provider: Optional[Any] = None,
  268. ) -> None:
  269. self.path = path
  270. self.endpoint = endpoint
  271. self.name = get_name(endpoint) if name is None else name
  272. self.dependant = get_dependant(path=path, call=self.endpoint)
  273. self.app = websocket_session(
  274. get_websocket_app(
  275. dependant=self.dependant,
  276. dependency_overrides_provider=dependency_overrides_provider,
  277. )
  278. )
  279. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  280. class APIRoute(routing.Route):
  281. def __init__(
  282. self,
  283. path: str,
  284. endpoint: Callable[..., Any],
  285. *,
  286. response_model: Optional[Type[Any]] = None,
  287. status_code: Optional[int] = None,
  288. tags: Optional[List[str]] = None,
  289. dependencies: Optional[Sequence[params.Depends]] = None,
  290. summary: Optional[str] = None,
  291. description: Optional[str] = None,
  292. response_description: str = "Successful Response",
  293. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  294. deprecated: Optional[bool] = None,
  295. name: Optional[str] = None,
  296. methods: Optional[Union[Set[str], List[str]]] = None,
  297. operation_id: Optional[str] = None,
  298. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  299. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  300. response_model_by_alias: bool = True,
  301. response_model_exclude_unset: bool = False,
  302. response_model_exclude_defaults: bool = False,
  303. response_model_exclude_none: bool = False,
  304. include_in_schema: bool = True,
  305. response_class: Union[Type[Response], DefaultPlaceholder] = Default(
  306. JSONResponse
  307. ),
  308. dependency_overrides_provider: Optional[Any] = None,
  309. callbacks: Optional[List[BaseRoute]] = None,
  310. openapi_extra: Optional[Dict[str, Any]] = None,
  311. ) -> None:
  312. # normalise enums e.g. http.HTTPStatus
  313. if isinstance(status_code, enum.IntEnum):
  314. status_code = int(status_code)
  315. self.path = path
  316. self.endpoint = endpoint
  317. self.name = get_name(endpoint) if name is None else name
  318. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  319. if methods is None:
  320. methods = ["GET"]
  321. self.methods: Set[str] = set([method.upper() for method in methods])
  322. self.unique_id = generate_operation_id_for_path(
  323. name=self.name, path=self.path_format, method=list(methods)[0]
  324. )
  325. self.response_model = response_model
  326. if self.response_model:
  327. assert (
  328. status_code not in STATUS_CODES_WITH_NO_BODY
  329. ), f"Status code {status_code} must not have a response body"
  330. response_name = "Response_" + self.unique_id
  331. self.response_field = create_response_field(
  332. name=response_name, type_=self.response_model
  333. )
  334. # Create a clone of the field, so that a Pydantic submodel is not returned
  335. # as is just because it's an instance of a subclass of a more limited class
  336. # e.g. UserInDB (containing hashed_password) could be a subclass of User
  337. # that doesn't have the hashed_password. But because it's a subclass, it
  338. # would pass the validation and be returned as is.
  339. # By being a new field, no inheritance will be passed as is. A new model
  340. # will be always created.
  341. self.secure_cloned_response_field: Optional[
  342. ModelField
  343. ] = create_cloned_field(self.response_field)
  344. else:
  345. self.response_field = None # type: ignore
  346. self.secure_cloned_response_field = None
  347. self.status_code = status_code
  348. self.tags = tags or []
  349. if dependencies:
  350. self.dependencies = list(dependencies)
  351. else:
  352. self.dependencies = []
  353. self.summary = summary
  354. self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
  355. # if a "form feed" character (page break) is found in the description text,
  356. # truncate description text to the content preceding the first "form feed"
  357. self.description = self.description.split("\f")[0]
  358. self.response_description = response_description
  359. self.responses = responses or {}
  360. response_fields = {}
  361. for additional_status_code, response in self.responses.items():
  362. assert isinstance(response, dict), "An additional response must be a dict"
  363. model = response.get("model")
  364. if model:
  365. assert (
  366. additional_status_code not in STATUS_CODES_WITH_NO_BODY
  367. ), f"Status code {additional_status_code} must not have a response body"
  368. response_name = f"Response_{additional_status_code}_{self.unique_id}"
  369. response_field = create_response_field(name=response_name, type_=model)
  370. response_fields[additional_status_code] = response_field
  371. if response_fields:
  372. self.response_fields: Dict[Union[int, str], ModelField] = response_fields
  373. else:
  374. self.response_fields = {}
  375. self.deprecated = deprecated
  376. self.operation_id = operation_id
  377. self.response_model_include = response_model_include
  378. self.response_model_exclude = response_model_exclude
  379. self.response_model_by_alias = response_model_by_alias
  380. self.response_model_exclude_unset = response_model_exclude_unset
  381. self.response_model_exclude_defaults = response_model_exclude_defaults
  382. self.response_model_exclude_none = response_model_exclude_none
  383. self.include_in_schema = include_in_schema
  384. self.response_class = response_class
  385. assert callable(endpoint), "An endpoint must be a callable"
  386. self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
  387. for depends in self.dependencies[::-1]:
  388. self.dependant.dependencies.insert(
  389. 0,
  390. get_parameterless_sub_dependant(depends=depends, path=self.path_format),
  391. )
  392. self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
  393. self.dependency_overrides_provider = dependency_overrides_provider
  394. self.callbacks = callbacks
  395. self.app = request_response(self.get_route_handler())
  396. self.openapi_extra = openapi_extra
  397. def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
  398. return get_request_handler(
  399. dependant=self.dependant,
  400. body_field=self.body_field,
  401. status_code=self.status_code,
  402. response_class=self.response_class,
  403. response_field=self.secure_cloned_response_field,
  404. response_model_include=self.response_model_include,
  405. response_model_exclude=self.response_model_exclude,
  406. response_model_by_alias=self.response_model_by_alias,
  407. response_model_exclude_unset=self.response_model_exclude_unset,
  408. response_model_exclude_defaults=self.response_model_exclude_defaults,
  409. response_model_exclude_none=self.response_model_exclude_none,
  410. dependency_overrides_provider=self.dependency_overrides_provider,
  411. )
  412. class APIRouter(routing.Router):
  413. def __init__(
  414. self,
  415. *,
  416. prefix: str = "",
  417. tags: Optional[List[str]] = None,
  418. dependencies: Optional[Sequence[params.Depends]] = None,
  419. default_response_class: Type[Response] = Default(JSONResponse),
  420. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  421. callbacks: Optional[List[BaseRoute]] = None,
  422. routes: Optional[List[routing.BaseRoute]] = None,
  423. redirect_slashes: bool = True,
  424. default: Optional[ASGIApp] = None,
  425. dependency_overrides_provider: Optional[Any] = None,
  426. route_class: Type[APIRoute] = APIRoute,
  427. on_startup: Optional[Sequence[Callable[[], Any]]] = None,
  428. on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
  429. deprecated: Optional[bool] = None,
  430. include_in_schema: bool = True,
  431. ) -> None:
  432. super().__init__(
  433. routes=routes, # type: ignore # in Starlette
  434. redirect_slashes=redirect_slashes,
  435. default=default, # type: ignore # in Starlette
  436. on_startup=on_startup, # type: ignore # in Starlette
  437. on_shutdown=on_shutdown, # type: ignore # in Starlette
  438. )
  439. if prefix:
  440. assert prefix.startswith("/"), "A path prefix must start with '/'"
  441. assert not prefix.endswith(
  442. "/"
  443. ), "A path prefix must not end with '/', as the routes will start with '/'"
  444. self.prefix = prefix
  445. self.tags: List[str] = tags or []
  446. self.dependencies = list(dependencies or []) or []
  447. self.deprecated = deprecated
  448. self.include_in_schema = include_in_schema
  449. self.responses = responses or {}
  450. self.callbacks = callbacks or []
  451. self.dependency_overrides_provider = dependency_overrides_provider
  452. self.route_class = route_class
  453. self.default_response_class = default_response_class
  454. def add_api_route(
  455. self,
  456. path: str,
  457. endpoint: Callable[..., Any],
  458. *,
  459. response_model: Optional[Type[Any]] = None,
  460. status_code: Optional[int] = None,
  461. tags: Optional[List[str]] = None,
  462. dependencies: Optional[Sequence[params.Depends]] = None,
  463. summary: Optional[str] = None,
  464. description: Optional[str] = None,
  465. response_description: str = "Successful Response",
  466. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  467. deprecated: Optional[bool] = None,
  468. methods: Optional[Union[Set[str], List[str]]] = None,
  469. operation_id: Optional[str] = None,
  470. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  471. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  472. response_model_by_alias: bool = True,
  473. response_model_exclude_unset: bool = False,
  474. response_model_exclude_defaults: bool = False,
  475. response_model_exclude_none: bool = False,
  476. include_in_schema: bool = True,
  477. response_class: Union[Type[Response], DefaultPlaceholder] = Default(
  478. JSONResponse
  479. ),
  480. name: Optional[str] = None,
  481. route_class_override: Optional[Type[APIRoute]] = None,
  482. callbacks: Optional[List[BaseRoute]] = None,
  483. openapi_extra: Optional[Dict[str, Any]] = None,
  484. ) -> None:
  485. route_class = route_class_override or self.route_class
  486. responses = responses or {}
  487. combined_responses = {**self.responses, **responses}
  488. current_response_class = get_value_or_default(
  489. response_class, self.default_response_class
  490. )
  491. current_tags = self.tags.copy()
  492. if tags:
  493. current_tags.extend(tags)
  494. current_dependencies = self.dependencies.copy()
  495. if dependencies:
  496. current_dependencies.extend(dependencies)
  497. current_callbacks = self.callbacks.copy()
  498. if callbacks:
  499. current_callbacks.extend(callbacks)
  500. route = route_class(
  501. self.prefix + path,
  502. endpoint=endpoint,
  503. response_model=response_model,
  504. status_code=status_code,
  505. tags=current_tags,
  506. dependencies=current_dependencies,
  507. summary=summary,
  508. description=description,
  509. response_description=response_description,
  510. responses=combined_responses,
  511. deprecated=deprecated or self.deprecated,
  512. methods=methods,
  513. operation_id=operation_id,
  514. response_model_include=response_model_include,
  515. response_model_exclude=response_model_exclude,
  516. response_model_by_alias=response_model_by_alias,
  517. response_model_exclude_unset=response_model_exclude_unset,
  518. response_model_exclude_defaults=response_model_exclude_defaults,
  519. response_model_exclude_none=response_model_exclude_none,
  520. include_in_schema=include_in_schema and self.include_in_schema,
  521. response_class=current_response_class,
  522. name=name,
  523. dependency_overrides_provider=self.dependency_overrides_provider,
  524. callbacks=current_callbacks,
  525. openapi_extra=openapi_extra,
  526. )
  527. self.routes.append(route)
  528. def api_route(
  529. self,
  530. path: str,
  531. *,
  532. response_model: Optional[Type[Any]] = None,
  533. status_code: Optional[int] = None,
  534. tags: Optional[List[str]] = None,
  535. dependencies: Optional[Sequence[params.Depends]] = None,
  536. summary: Optional[str] = None,
  537. description: Optional[str] = None,
  538. response_description: str = "Successful Response",
  539. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  540. deprecated: Optional[bool] = None,
  541. methods: Optional[List[str]] = None,
  542. operation_id: Optional[str] = None,
  543. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  544. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  545. response_model_by_alias: bool = True,
  546. response_model_exclude_unset: bool = False,
  547. response_model_exclude_defaults: bool = False,
  548. response_model_exclude_none: bool = False,
  549. include_in_schema: bool = True,
  550. response_class: Type[Response] = Default(JSONResponse),
  551. name: Optional[str] = None,
  552. callbacks: Optional[List[BaseRoute]] = None,
  553. openapi_extra: Optional[Dict[str, Any]] = None,
  554. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  555. def decorator(func: DecoratedCallable) -> DecoratedCallable:
  556. self.add_api_route(
  557. path,
  558. func,
  559. response_model=response_model,
  560. status_code=status_code,
  561. tags=tags,
  562. dependencies=dependencies,
  563. summary=summary,
  564. description=description,
  565. response_description=response_description,
  566. responses=responses,
  567. deprecated=deprecated,
  568. methods=methods,
  569. operation_id=operation_id,
  570. response_model_include=response_model_include,
  571. response_model_exclude=response_model_exclude,
  572. response_model_by_alias=response_model_by_alias,
  573. response_model_exclude_unset=response_model_exclude_unset,
  574. response_model_exclude_defaults=response_model_exclude_defaults,
  575. response_model_exclude_none=response_model_exclude_none,
  576. include_in_schema=include_in_schema,
  577. response_class=response_class,
  578. name=name,
  579. callbacks=callbacks,
  580. openapi_extra=openapi_extra,
  581. )
  582. return func
  583. return decorator
  584. def add_api_websocket_route(
  585. self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
  586. ) -> None:
  587. route = APIWebSocketRoute(
  588. path,
  589. endpoint=endpoint,
  590. name=name,
  591. dependency_overrides_provider=self.dependency_overrides_provider,
  592. )
  593. self.routes.append(route)
  594. def websocket(
  595. self, path: str, name: Optional[str] = None
  596. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  597. def decorator(func: DecoratedCallable) -> DecoratedCallable:
  598. self.add_api_websocket_route(path, func, name=name)
  599. return func
  600. return decorator
  601. def include_router(
  602. self,
  603. router: "APIRouter",
  604. *,
  605. prefix: str = "",
  606. tags: Optional[List[str]] = None,
  607. dependencies: Optional[Sequence[params.Depends]] = None,
  608. default_response_class: Type[Response] = Default(JSONResponse),
  609. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  610. callbacks: Optional[List[BaseRoute]] = None,
  611. deprecated: Optional[bool] = None,
  612. include_in_schema: bool = True,
  613. ) -> None:
  614. if prefix:
  615. assert prefix.startswith("/"), "A path prefix must start with '/'"
  616. assert not prefix.endswith(
  617. "/"
  618. ), "A path prefix must not end with '/', as the routes will start with '/'"
  619. else:
  620. for r in router.routes:
  621. path = getattr(r, "path")
  622. name = getattr(r, "name", "unknown")
  623. if path is not None and not path:
  624. raise Exception(
  625. f"Prefix and path cannot be both empty (path operation: {name})"
  626. )
  627. if responses is None:
  628. responses = {}
  629. for route in router.routes:
  630. if isinstance(route, APIRoute):
  631. combined_responses = {**responses, **route.responses}
  632. use_response_class = get_value_or_default(
  633. route.response_class,
  634. router.default_response_class,
  635. default_response_class,
  636. self.default_response_class,
  637. )
  638. current_tags = []
  639. if tags:
  640. current_tags.extend(tags)
  641. if route.tags:
  642. current_tags.extend(route.tags)
  643. current_dependencies: List[params.Depends] = []
  644. if dependencies:
  645. current_dependencies.extend(dependencies)
  646. if route.dependencies:
  647. current_dependencies.extend(route.dependencies)
  648. current_callbacks = []
  649. if callbacks:
  650. current_callbacks.extend(callbacks)
  651. if route.callbacks:
  652. current_callbacks.extend(route.callbacks)
  653. self.add_api_route(
  654. prefix + route.path,
  655. route.endpoint,
  656. response_model=route.response_model,
  657. status_code=route.status_code,
  658. tags=current_tags,
  659. dependencies=current_dependencies,
  660. summary=route.summary,
  661. description=route.description,
  662. response_description=route.response_description,
  663. responses=combined_responses,
  664. deprecated=route.deprecated or deprecated or self.deprecated,
  665. methods=route.methods,
  666. operation_id=route.operation_id,
  667. response_model_include=route.response_model_include,
  668. response_model_exclude=route.response_model_exclude,
  669. response_model_by_alias=route.response_model_by_alias,
  670. response_model_exclude_unset=route.response_model_exclude_unset,
  671. response_model_exclude_defaults=route.response_model_exclude_defaults,
  672. response_model_exclude_none=route.response_model_exclude_none,
  673. include_in_schema=route.include_in_schema
  674. and self.include_in_schema
  675. and include_in_schema,
  676. response_class=use_response_class,
  677. name=route.name,
  678. route_class_override=type(route),
  679. callbacks=current_callbacks,
  680. openapi_extra=route.openapi_extra,
  681. )
  682. elif isinstance(route, routing.Route):
  683. methods = list(route.methods or []) # type: ignore # in Starlette
  684. self.add_route(
  685. prefix + route.path,
  686. route.endpoint,
  687. methods=methods,
  688. include_in_schema=route.include_in_schema,
  689. name=route.name,
  690. )
  691. elif isinstance(route, APIWebSocketRoute):
  692. self.add_api_websocket_route(
  693. prefix + route.path, route.endpoint, name=route.name
  694. )
  695. elif isinstance(route, routing.WebSocketRoute):
  696. self.add_websocket_route(
  697. prefix + route.path, route.endpoint, name=route.name
  698. )
  699. for handler in router.on_startup:
  700. self.add_event_handler("startup", handler)
  701. for handler in router.on_shutdown:
  702. self.add_event_handler("shutdown", handler)
  703. def get(
  704. self,
  705. path: str,
  706. *,
  707. response_model: Optional[Type[Any]] = None,
  708. status_code: Optional[int] = None,
  709. tags: Optional[List[str]] = None,
  710. dependencies: Optional[Sequence[params.Depends]] = None,
  711. summary: Optional[str] = None,
  712. description: Optional[str] = None,
  713. response_description: str = "Successful Response",
  714. responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
  715. deprecated: Optional[bool] = None,
  716. operation_id: Optional[str] = None,
  717. response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  718. response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
  719. response_model_by_alias: bool = True,
  720. response_model_exclude_unset: bool = False,
  721. response_model_exclude_defaults: bool = False,
  722. response_model_exclude_none: bool = False,
  723. include_in_schema: bool = True,
  724. response_class: Type[Response] = Default(JSONResponse),
  725. name: Optional[str] = None,
  726. callbacks: Optional[List[BaseRoute]] = None,
  727. openapi_extra: Optional[Dict[str, Any]] = None,
  728. ) -> Callable[[DecoratedCallable], DecoratedCallable]:
  729. return self.api_route(
  730. path=path,
  731. response_model=response_model,
  732. status_code=status_code,
  733. tags=tags,
  734. dependencies=dependencies,
  735. summary=summary,
  736. description=description,
  737. response_description=response_description,
  738. responses=responses,
  739. deprecated=deprecated,
  740. methods=["GET"],
  741. operation_id=operation_id,
  742. response_model_include=response_model_include,
  743. response_model_exclude=response_model_exclude,
  744. response_model_by_alias=response_model_by_alias,
  745. response_model_exclude_unset=response_model_exclude_unset,
  746. response_model_exclude_defaults=response_model_exclude_defaults,
  747. response_model_exclude_none=response_model_exclude_none,
  748. include_in_schema=include_in_schema,
  749. response_class=response_class,
  750. name=name,
  751. callbacks=callbacks,
  752. openapi_extra=openapi_extra,
  753. )
  def put(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``PUT`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["PUT"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["PUT"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def post(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``POST`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["POST"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["POST"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def delete(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``DELETE`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["DELETE"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["DELETE"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def options(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``OPTIONS`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["OPTIONS"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["OPTIONS"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def head(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``HEAD`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["HEAD"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["HEAD"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def patch(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``PATCH`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["PATCH"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["PATCH"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )
  def trace(
      self,
      path: str,
      *,
      response_model: Optional[Type[Any]] = None,
      status_code: Optional[int] = None,
      tags: Optional[List[str]] = None,
      dependencies: Optional[Sequence[params.Depends]] = None,
      summary: Optional[str] = None,
      description: Optional[str] = None,
      response_description: str = "Successful Response",
      responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
      deprecated: Optional[bool] = None,
      operation_id: Optional[str] = None,
      response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
      response_model_by_alias: bool = True,
      response_model_exclude_unset: bool = False,
      response_model_exclude_defaults: bool = False,
      response_model_exclude_none: bool = False,
      include_in_schema: bool = True,
      response_class: Type[Response] = Default(JSONResponse),
      name: Optional[str] = None,
      callbacks: Optional[List[BaseRoute]] = None,
      openapi_extra: Optional[Dict[str, Any]] = None,
  ) -> Callable[[DecoratedCallable], DecoratedCallable]:
      """Declare a route that only answers HTTP ``TRACE`` requests.

      Thin convenience wrapper: it delegates to ``self.api_route``,
      forwarding every argument by keyword unchanged and fixing ``methods``
      to ``["TRACE"]``. Any validation or registration happens inside
      ``api_route``, not here.

      Returns:
          The value returned by ``self.api_route``, typed as a decorator that
          accepts and returns the endpoint callable.
      """
      allowed_methods = ["TRACE"]
      # Keyword order differs from the signature; every value is a plain
      # name (or the fresh list above), so evaluation order is irrelevant.
      return self.api_route(
          path=path,
          methods=allowed_methods,
          name=name,
          operation_id=operation_id,
          tags=tags,
          summary=summary,
          description=description,
          deprecated=deprecated,
          include_in_schema=include_in_schema,
          openapi_extra=openapi_extra,
          callbacks=callbacks,
          dependencies=dependencies,
          status_code=status_code,
          responses=responses,
          response_description=response_description,
          response_class=response_class,
          response_model=response_model,
          response_model_include=response_model_include,
          response_model_exclude=response_model_exclude,
          response_model_by_alias=response_model_by_alias,
          response_model_exclude_unset=response_model_exclude_unset,
          response_model_exclude_defaults=response_model_exclude_defaults,
          response_model_exclude_none=response_model_exclude_none,
      )