You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

570 lines
23 KiB

  1. import http.client
  2. import inspect
  3. import warnings
  4. from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
  5. from fastapi import routing
  6. from fastapi._compat import (
  7. GenerateJsonSchema,
  8. JsonSchemaValue,
  9. ModelField,
  10. Undefined,
  11. get_compat_model_name_map,
  12. get_definitions,
  13. get_schema_from_model_field,
  14. lenient_issubclass,
  15. )
  16. from fastapi.datastructures import DefaultPlaceholder
  17. from fastapi.dependencies.models import Dependant
  18. from fastapi.dependencies.utils import (
  19. _get_flat_fields_from_params,
  20. get_flat_dependant,
  21. get_flat_params,
  22. )
  23. from fastapi.encoders import jsonable_encoder
  24. from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
  25. from fastapi.openapi.models import OpenAPI
  26. from fastapi.params import Body, ParamTypes
  27. from fastapi.responses import Response
  28. from fastapi.types import ModelNameMap
  29. from fastapi.utils import (
  30. deep_dict_update,
  31. generate_operation_id_for_path,
  32. is_body_allowed_for_status_code,
  33. )
  34. from pydantic import BaseModel
  35. from starlette.responses import JSONResponse
  36. from starlette.routing import BaseRoute
  37. from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
  38. from typing_extensions import Literal
  39. validation_error_definition = {
  40. "title": "ValidationError",
  41. "type": "object",
  42. "properties": {
  43. "loc": {
  44. "title": "Location",
  45. "type": "array",
  46. "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
  47. },
  48. "msg": {"title": "Message", "type": "string"},
  49. "type": {"title": "Error Type", "type": "string"},
  50. },
  51. "required": ["loc", "msg", "type"],
  52. }
  53. validation_error_response_definition = {
  54. "title": "HTTPValidationError",
  55. "type": "object",
  56. "properties": {
  57. "detail": {
  58. "title": "Detail",
  59. "type": "array",
  60. "items": {"$ref": REF_PREFIX + "ValidationError"},
  61. }
  62. },
  63. }
  64. status_code_ranges: Dict[str, str] = {
  65. "1XX": "Information",
  66. "2XX": "Success",
  67. "3XX": "Redirection",
  68. "4XX": "Client Error",
  69. "5XX": "Server Error",
  70. "DEFAULT": "Default Response",
  71. }
  72. def get_openapi_security_definitions(
  73. flat_dependant: Dependant,
  74. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  75. security_definitions = {}
  76. operation_security = []
  77. for security_requirement in flat_dependant.security_requirements:
  78. security_definition = jsonable_encoder(
  79. security_requirement.security_scheme.model,
  80. by_alias=True,
  81. exclude_none=True,
  82. )
  83. security_name = security_requirement.security_scheme.scheme_name
  84. security_definitions[security_name] = security_definition
  85. operation_security.append({security_name: security_requirement.scopes})
  86. return security_definitions, operation_security
  87. def _get_openapi_operation_parameters(
  88. *,
  89. dependant: Dependant,
  90. schema_generator: GenerateJsonSchema,
  91. model_name_map: ModelNameMap,
  92. field_mapping: Dict[
  93. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  94. ],
  95. separate_input_output_schemas: bool = True,
  96. ) -> List[Dict[str, Any]]:
  97. parameters = []
  98. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  99. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  100. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  101. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  102. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  103. parameter_groups = [
  104. (ParamTypes.path, path_params),
  105. (ParamTypes.query, query_params),
  106. (ParamTypes.header, header_params),
  107. (ParamTypes.cookie, cookie_params),
  108. ]
  109. default_convert_underscores = True
  110. if len(flat_dependant.header_params) == 1:
  111. first_field = flat_dependant.header_params[0]
  112. if lenient_issubclass(first_field.type_, BaseModel):
  113. default_convert_underscores = getattr(
  114. first_field.field_info, "convert_underscores", True
  115. )
  116. for param_type, param_group in parameter_groups:
  117. for param in param_group:
  118. field_info = param.field_info
  119. # field_info = cast(Param, field_info)
  120. if not getattr(field_info, "include_in_schema", True):
  121. continue
  122. param_schema = get_schema_from_model_field(
  123. field=param,
  124. schema_generator=schema_generator,
  125. model_name_map=model_name_map,
  126. field_mapping=field_mapping,
  127. separate_input_output_schemas=separate_input_output_schemas,
  128. )
  129. name = param.alias
  130. convert_underscores = getattr(
  131. param.field_info,
  132. "convert_underscores",
  133. default_convert_underscores,
  134. )
  135. if (
  136. param_type == ParamTypes.header
  137. and param.alias == param.name
  138. and convert_underscores
  139. ):
  140. name = param.name.replace("_", "-")
  141. parameter = {
  142. "name": name,
  143. "in": param_type.value,
  144. "required": param.required,
  145. "schema": param_schema,
  146. }
  147. if field_info.description:
  148. parameter["description"] = field_info.description
  149. openapi_examples = getattr(field_info, "openapi_examples", None)
  150. example = getattr(field_info, "example", None)
  151. if openapi_examples:
  152. parameter["examples"] = jsonable_encoder(openapi_examples)
  153. elif example != Undefined:
  154. parameter["example"] = jsonable_encoder(example)
  155. if getattr(field_info, "deprecated", None):
  156. parameter["deprecated"] = True
  157. parameters.append(parameter)
  158. return parameters
  159. def get_openapi_operation_request_body(
  160. *,
  161. body_field: Optional[ModelField],
  162. schema_generator: GenerateJsonSchema,
  163. model_name_map: ModelNameMap,
  164. field_mapping: Dict[
  165. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  166. ],
  167. separate_input_output_schemas: bool = True,
  168. ) -> Optional[Dict[str, Any]]:
  169. if not body_field:
  170. return None
  171. assert isinstance(body_field, ModelField)
  172. body_schema = get_schema_from_model_field(
  173. field=body_field,
  174. schema_generator=schema_generator,
  175. model_name_map=model_name_map,
  176. field_mapping=field_mapping,
  177. separate_input_output_schemas=separate_input_output_schemas,
  178. )
  179. field_info = cast(Body, body_field.field_info)
  180. request_media_type = field_info.media_type
  181. required = body_field.required
  182. request_body_oai: Dict[str, Any] = {}
  183. if required:
  184. request_body_oai["required"] = required
  185. request_media_content: Dict[str, Any] = {"schema": body_schema}
  186. if field_info.openapi_examples:
  187. request_media_content["examples"] = jsonable_encoder(
  188. field_info.openapi_examples
  189. )
  190. elif field_info.example != Undefined:
  191. request_media_content["example"] = jsonable_encoder(field_info.example)
  192. request_body_oai["content"] = {request_media_type: request_media_content}
  193. return request_body_oai
  194. def generate_operation_id(
  195. *, route: routing.APIRoute, method: str
  196. ) -> str: # pragma: nocover
  197. warnings.warn(
  198. "fastapi.openapi.utils.generate_operation_id() was deprecated, "
  199. "it is not used internally, and will be removed soon",
  200. DeprecationWarning,
  201. stacklevel=2,
  202. )
  203. if route.operation_id:
  204. return route.operation_id
  205. path: str = route.path_format
  206. return generate_operation_id_for_path(name=route.name, path=path, method=method)
  207. def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
  208. if route.summary:
  209. return route.summary
  210. return route.name.replace("_", " ").title()
  211. def get_openapi_operation_metadata(
  212. *, route: routing.APIRoute, method: str, operation_ids: Set[str]
  213. ) -> Dict[str, Any]:
  214. operation: Dict[str, Any] = {}
  215. if route.tags:
  216. operation["tags"] = route.tags
  217. operation["summary"] = generate_operation_summary(route=route, method=method)
  218. if route.description:
  219. operation["description"] = route.description
  220. operation_id = route.operation_id or route.unique_id
  221. if operation_id in operation_ids:
  222. message = (
  223. f"Duplicate Operation ID {operation_id} for function "
  224. + f"{route.endpoint.__name__}"
  225. )
  226. file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
  227. if file_name:
  228. message += f" at {file_name}"
  229. warnings.warn(message, stacklevel=1)
  230. operation_ids.add(operation_id)
  231. operation["operationId"] = operation_id
  232. if route.deprecated:
  233. operation["deprecated"] = route.deprecated
  234. return operation
  235. def get_openapi_path(
  236. *,
  237. route: routing.APIRoute,
  238. operation_ids: Set[str],
  239. schema_generator: GenerateJsonSchema,
  240. model_name_map: ModelNameMap,
  241. field_mapping: Dict[
  242. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  243. ],
  244. separate_input_output_schemas: bool = True,
  245. ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
  246. path = {}
  247. security_schemes: Dict[str, Any] = {}
  248. definitions: Dict[str, Any] = {}
  249. assert route.methods is not None, "Methods must be a list"
  250. if isinstance(route.response_class, DefaultPlaceholder):
  251. current_response_class: Type[Response] = route.response_class.value
  252. else:
  253. current_response_class = route.response_class
  254. assert current_response_class, "A response class is needed to generate OpenAPI"
  255. route_response_media_type: Optional[str] = current_response_class.media_type
  256. if route.include_in_schema:
  257. for method in route.methods:
  258. operation = get_openapi_operation_metadata(
  259. route=route, method=method, operation_ids=operation_ids
  260. )
  261. parameters: List[Dict[str, Any]] = []
  262. flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
  263. security_definitions, operation_security = get_openapi_security_definitions(
  264. flat_dependant=flat_dependant
  265. )
  266. if operation_security:
  267. operation.setdefault("security", []).extend(operation_security)
  268. if security_definitions:
  269. security_schemes.update(security_definitions)
  270. operation_parameters = _get_openapi_operation_parameters(
  271. dependant=route.dependant,
  272. schema_generator=schema_generator,
  273. model_name_map=model_name_map,
  274. field_mapping=field_mapping,
  275. separate_input_output_schemas=separate_input_output_schemas,
  276. )
  277. parameters.extend(operation_parameters)
  278. if parameters:
  279. all_parameters = {
  280. (param["in"], param["name"]): param for param in parameters
  281. }
  282. required_parameters = {
  283. (param["in"], param["name"]): param
  284. for param in parameters
  285. if param.get("required")
  286. }
  287. # Make sure required definitions of the same parameter take precedence
  288. # over non-required definitions
  289. all_parameters.update(required_parameters)
  290. operation["parameters"] = list(all_parameters.values())
  291. if method in METHODS_WITH_BODY:
  292. request_body_oai = get_openapi_operation_request_body(
  293. body_field=route.body_field,
  294. schema_generator=schema_generator,
  295. model_name_map=model_name_map,
  296. field_mapping=field_mapping,
  297. separate_input_output_schemas=separate_input_output_schemas,
  298. )
  299. if request_body_oai:
  300. operation["requestBody"] = request_body_oai
  301. if route.callbacks:
  302. callbacks = {}
  303. for callback in route.callbacks:
  304. if isinstance(callback, routing.APIRoute):
  305. (
  306. cb_path,
  307. cb_security_schemes,
  308. cb_definitions,
  309. ) = get_openapi_path(
  310. route=callback,
  311. operation_ids=operation_ids,
  312. schema_generator=schema_generator,
  313. model_name_map=model_name_map,
  314. field_mapping=field_mapping,
  315. separate_input_output_schemas=separate_input_output_schemas,
  316. )
  317. callbacks[callback.name] = {callback.path: cb_path}
  318. operation["callbacks"] = callbacks
  319. if route.status_code is not None:
  320. status_code = str(route.status_code)
  321. else:
  322. # It would probably make more sense for all response classes to have an
  323. # explicit default status_code, and to extract it from them, instead of
  324. # doing this inspection tricks, that would probably be in the future
  325. # TODO: probably make status_code a default class attribute for all
  326. # responses in Starlette
  327. response_signature = inspect.signature(current_response_class.__init__)
  328. status_code_param = response_signature.parameters.get("status_code")
  329. if status_code_param is not None:
  330. if isinstance(status_code_param.default, int):
  331. status_code = str(status_code_param.default)
  332. operation.setdefault("responses", {}).setdefault(status_code, {})[
  333. "description"
  334. ] = route.response_description
  335. if route_response_media_type and is_body_allowed_for_status_code(
  336. route.status_code
  337. ):
  338. response_schema = {"type": "string"}
  339. if lenient_issubclass(current_response_class, JSONResponse):
  340. if route.response_field:
  341. response_schema = get_schema_from_model_field(
  342. field=route.response_field,
  343. schema_generator=schema_generator,
  344. model_name_map=model_name_map,
  345. field_mapping=field_mapping,
  346. separate_input_output_schemas=separate_input_output_schemas,
  347. )
  348. else:
  349. response_schema = {}
  350. operation.setdefault("responses", {}).setdefault(
  351. status_code, {}
  352. ).setdefault("content", {}).setdefault(route_response_media_type, {})[
  353. "schema"
  354. ] = response_schema
  355. if route.responses:
  356. operation_responses = operation.setdefault("responses", {})
  357. for (
  358. additional_status_code,
  359. additional_response,
  360. ) in route.responses.items():
  361. process_response = additional_response.copy()
  362. process_response.pop("model", None)
  363. status_code_key = str(additional_status_code).upper()
  364. if status_code_key == "DEFAULT":
  365. status_code_key = "default"
  366. openapi_response = operation_responses.setdefault(
  367. status_code_key, {}
  368. )
  369. assert isinstance(process_response, dict), (
  370. "An additional response must be a dict"
  371. )
  372. field = route.response_fields.get(additional_status_code)
  373. additional_field_schema: Optional[Dict[str, Any]] = None
  374. if field:
  375. additional_field_schema = get_schema_from_model_field(
  376. field=field,
  377. schema_generator=schema_generator,
  378. model_name_map=model_name_map,
  379. field_mapping=field_mapping,
  380. separate_input_output_schemas=separate_input_output_schemas,
  381. )
  382. media_type = route_response_media_type or "application/json"
  383. additional_schema = (
  384. process_response.setdefault("content", {})
  385. .setdefault(media_type, {})
  386. .setdefault("schema", {})
  387. )
  388. deep_dict_update(additional_schema, additional_field_schema)
  389. status_text: Optional[str] = status_code_ranges.get(
  390. str(additional_status_code).upper()
  391. ) or http.client.responses.get(int(additional_status_code))
  392. description = (
  393. process_response.get("description")
  394. or openapi_response.get("description")
  395. or status_text
  396. or "Additional Response"
  397. )
  398. deep_dict_update(openapi_response, process_response)
  399. openapi_response["description"] = description
  400. http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
  401. all_route_params = get_flat_params(route.dependant)
  402. if (all_route_params or route.body_field) and not any(
  403. status in operation["responses"]
  404. for status in [http422, "4XX", "default"]
  405. ):
  406. operation["responses"][http422] = {
  407. "description": "Validation Error",
  408. "content": {
  409. "application/json": {
  410. "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
  411. }
  412. },
  413. }
  414. if "ValidationError" not in definitions:
  415. definitions.update(
  416. {
  417. "ValidationError": validation_error_definition,
  418. "HTTPValidationError": validation_error_response_definition,
  419. }
  420. )
  421. if route.openapi_extra:
  422. deep_dict_update(operation, route.openapi_extra)
  423. path[method.lower()] = operation
  424. return path, security_schemes, definitions
  425. def get_fields_from_routes(
  426. routes: Sequence[BaseRoute],
  427. ) -> List[ModelField]:
  428. body_fields_from_routes: List[ModelField] = []
  429. responses_from_routes: List[ModelField] = []
  430. request_fields_from_routes: List[ModelField] = []
  431. callback_flat_models: List[ModelField] = []
  432. for route in routes:
  433. if getattr(route, "include_in_schema", None) and isinstance(
  434. route, routing.APIRoute
  435. ):
  436. if route.body_field:
  437. assert isinstance(route.body_field, ModelField), (
  438. "A request body must be a Pydantic Field"
  439. )
  440. body_fields_from_routes.append(route.body_field)
  441. if route.response_field:
  442. responses_from_routes.append(route.response_field)
  443. if route.response_fields:
  444. responses_from_routes.extend(route.response_fields.values())
  445. if route.callbacks:
  446. callback_flat_models.extend(get_fields_from_routes(route.callbacks))
  447. params = get_flat_params(route.dependant)
  448. request_fields_from_routes.extend(params)
  449. flat_models = callback_flat_models + list(
  450. body_fields_from_routes + responses_from_routes + request_fields_from_routes
  451. )
  452. return flat_models
  453. def get_openapi(
  454. *,
  455. title: str,
  456. version: str,
  457. openapi_version: str = "3.1.0",
  458. summary: Optional[str] = None,
  459. description: Optional[str] = None,
  460. routes: Sequence[BaseRoute],
  461. webhooks: Optional[Sequence[BaseRoute]] = None,
  462. tags: Optional[List[Dict[str, Any]]] = None,
  463. servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
  464. terms_of_service: Optional[str] = None,
  465. contact: Optional[Dict[str, Union[str, Any]]] = None,
  466. license_info: Optional[Dict[str, Union[str, Any]]] = None,
  467. separate_input_output_schemas: bool = True,
  468. ) -> Dict[str, Any]:
  469. info: Dict[str, Any] = {"title": title, "version": version}
  470. if summary:
  471. info["summary"] = summary
  472. if description:
  473. info["description"] = description
  474. if terms_of_service:
  475. info["termsOfService"] = terms_of_service
  476. if contact:
  477. info["contact"] = contact
  478. if license_info:
  479. info["license"] = license_info
  480. output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
  481. if servers:
  482. output["servers"] = servers
  483. components: Dict[str, Dict[str, Any]] = {}
  484. paths: Dict[str, Dict[str, Any]] = {}
  485. webhook_paths: Dict[str, Dict[str, Any]] = {}
  486. operation_ids: Set[str] = set()
  487. all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
  488. model_name_map = get_compat_model_name_map(all_fields)
  489. schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
  490. field_mapping, definitions = get_definitions(
  491. fields=all_fields,
  492. schema_generator=schema_generator,
  493. model_name_map=model_name_map,
  494. separate_input_output_schemas=separate_input_output_schemas,
  495. )
  496. for route in routes or []:
  497. if isinstance(route, routing.APIRoute):
  498. result = get_openapi_path(
  499. route=route,
  500. operation_ids=operation_ids,
  501. schema_generator=schema_generator,
  502. model_name_map=model_name_map,
  503. field_mapping=field_mapping,
  504. separate_input_output_schemas=separate_input_output_schemas,
  505. )
  506. if result:
  507. path, security_schemes, path_definitions = result
  508. if path:
  509. paths.setdefault(route.path_format, {}).update(path)
  510. if security_schemes:
  511. components.setdefault("securitySchemes", {}).update(
  512. security_schemes
  513. )
  514. if path_definitions:
  515. definitions.update(path_definitions)
  516. for webhook in webhooks or []:
  517. if isinstance(webhook, routing.APIRoute):
  518. result = get_openapi_path(
  519. route=webhook,
  520. operation_ids=operation_ids,
  521. schema_generator=schema_generator,
  522. model_name_map=model_name_map,
  523. field_mapping=field_mapping,
  524. separate_input_output_schemas=separate_input_output_schemas,
  525. )
  526. if result:
  527. path, security_schemes, path_definitions = result
  528. if path:
  529. webhook_paths.setdefault(webhook.path_format, {}).update(path)
  530. if security_schemes:
  531. components.setdefault("securitySchemes", {}).update(
  532. security_schemes
  533. )
  534. if path_definitions:
  535. definitions.update(path_definitions)
  536. if definitions:
  537. components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
  538. if components:
  539. output["components"] = components
  540. output["paths"] = paths
  541. if webhook_paths:
  542. output["webhooks"] = webhook_paths
  543. if tags:
  544. output["tags"] = tags
  545. return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore