You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

411 lines
17 KiB

  1. import http.client
  2. import inspect
  3. from enum import Enum
  4. from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
  5. from fastapi import routing
  6. from fastapi.datastructures import DefaultPlaceholder
  7. from fastapi.dependencies.models import Dependant
  8. from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
  9. from fastapi.encoders import jsonable_encoder
  10. from fastapi.openapi.constants import (
  11. METHODS_WITH_BODY,
  12. REF_PREFIX,
  13. STATUS_CODES_WITH_NO_BODY,
  14. )
  15. from fastapi.openapi.models import OpenAPI
  16. from fastapi.params import Body, Param
  17. from fastapi.responses import Response
  18. from fastapi.utils import (
  19. deep_dict_update,
  20. generate_operation_id_for_path,
  21. get_model_definitions,
  22. )
  23. from pydantic import BaseModel
  24. from pydantic.fields import ModelField, Undefined
  25. from pydantic.schema import (
  26. field_schema,
  27. get_flat_models_from_fields,
  28. get_model_name_map,
  29. )
  30. from pydantic.utils import lenient_issubclass
  31. from starlette.responses import JSONResponse
  32. from starlette.routing import BaseRoute
  33. from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
  34. validation_error_definition = {
  35. "title": "ValidationError",
  36. "type": "object",
  37. "properties": {
  38. "loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
  39. "msg": {"title": "Message", "type": "string"},
  40. "type": {"title": "Error Type", "type": "string"},
  41. },
  42. "required": ["loc", "msg", "type"],
  43. }
  44. validation_error_response_definition = {
  45. "title": "HTTPValidationError",
  46. "type": "object",
  47. "properties": {
  48. "detail": {
  49. "title": "Detail",
  50. "type": "array",
  51. "items": {"$ref": REF_PREFIX + "ValidationError"},
  52. }
  53. },
  54. }
  55. status_code_ranges: Dict[str, str] = {
  56. "1XX": "Information",
  57. "2XX": "Success",
  58. "3XX": "Redirection",
  59. "4XX": "Client Error",
  60. "5XX": "Server Error",
  61. "DEFAULT": "Default Response",
  62. }
  63. def get_openapi_security_definitions(
  64. flat_dependant: Dependant,
  65. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  66. security_definitions = {}
  67. operation_security = []
  68. for security_requirement in flat_dependant.security_requirements:
  69. security_definition = jsonable_encoder(
  70. security_requirement.security_scheme.model,
  71. by_alias=True,
  72. exclude_none=True,
  73. )
  74. security_name = security_requirement.security_scheme.scheme_name
  75. security_definitions[security_name] = security_definition
  76. operation_security.append({security_name: security_requirement.scopes})
  77. return security_definitions, operation_security
  78. def get_openapi_operation_parameters(
  79. *,
  80. all_route_params: Sequence[ModelField],
  81. model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
  82. ) -> List[Dict[str, Any]]:
  83. parameters = []
  84. for param in all_route_params:
  85. field_info = param.field_info
  86. field_info = cast(Param, field_info)
  87. parameter = {
  88. "name": param.alias,
  89. "in": field_info.in_.value,
  90. "required": param.required,
  91. "schema": field_schema(
  92. param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
  93. )[0],
  94. }
  95. if field_info.description:
  96. parameter["description"] = field_info.description
  97. if field_info.examples:
  98. parameter["examples"] = jsonable_encoder(field_info.examples)
  99. elif field_info.example != Undefined:
  100. parameter["example"] = jsonable_encoder(field_info.example)
  101. if field_info.deprecated:
  102. parameter["deprecated"] = field_info.deprecated
  103. parameters.append(parameter)
  104. return parameters
  105. def get_openapi_operation_request_body(
  106. *,
  107. body_field: Optional[ModelField],
  108. model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
  109. ) -> Optional[Dict[str, Any]]:
  110. if not body_field:
  111. return None
  112. assert isinstance(body_field, ModelField)
  113. body_schema, _, _ = field_schema(
  114. body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
  115. )
  116. field_info = cast(Body, body_field.field_info)
  117. request_media_type = field_info.media_type
  118. required = body_field.required
  119. request_body_oai: Dict[str, Any] = {}
  120. if required:
  121. request_body_oai["required"] = required
  122. request_media_content: Dict[str, Any] = {"schema": body_schema}
  123. if field_info.examples:
  124. request_media_content["examples"] = jsonable_encoder(field_info.examples)
  125. elif field_info.example != Undefined:
  126. request_media_content["example"] = jsonable_encoder(field_info.example)
  127. request_body_oai["content"] = {request_media_type: request_media_content}
  128. return request_body_oai
  129. def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
  130. if route.operation_id:
  131. return route.operation_id
  132. path: str = route.path_format
  133. return generate_operation_id_for_path(name=route.name, path=path, method=method)
  134. def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
  135. if route.summary:
  136. return route.summary
  137. return route.name.replace("_", " ").title()
  138. def get_openapi_operation_metadata(
  139. *, route: routing.APIRoute, method: str
  140. ) -> Dict[str, Any]:
  141. operation: Dict[str, Any] = {}
  142. if route.tags:
  143. operation["tags"] = route.tags
  144. operation["summary"] = generate_operation_summary(route=route, method=method)
  145. if route.description:
  146. operation["description"] = route.description
  147. operation["operationId"] = generate_operation_id(route=route, method=method)
  148. if route.deprecated:
  149. operation["deprecated"] = route.deprecated
  150. return operation
  151. def get_openapi_path(
  152. *, route: routing.APIRoute, model_name_map: Dict[type, str]
  153. ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
  154. path = {}
  155. security_schemes: Dict[str, Any] = {}
  156. definitions: Dict[str, Any] = {}
  157. assert route.methods is not None, "Methods must be a list"
  158. if isinstance(route.response_class, DefaultPlaceholder):
  159. current_response_class: Type[Response] = route.response_class.value
  160. else:
  161. current_response_class = route.response_class
  162. assert current_response_class, "A response class is needed to generate OpenAPI"
  163. route_response_media_type: Optional[str] = current_response_class.media_type
  164. if route.include_in_schema:
  165. for method in route.methods:
  166. operation = get_openapi_operation_metadata(route=route, method=method)
  167. parameters: List[Dict[str, Any]] = []
  168. flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
  169. security_definitions, operation_security = get_openapi_security_definitions(
  170. flat_dependant=flat_dependant
  171. )
  172. if operation_security:
  173. operation.setdefault("security", []).extend(operation_security)
  174. if security_definitions:
  175. security_schemes.update(security_definitions)
  176. all_route_params = get_flat_params(route.dependant)
  177. operation_parameters = get_openapi_operation_parameters(
  178. all_route_params=all_route_params, model_name_map=model_name_map
  179. )
  180. parameters.extend(operation_parameters)
  181. if parameters:
  182. operation["parameters"] = list(
  183. {param["name"]: param for param in parameters}.values()
  184. )
  185. if method in METHODS_WITH_BODY:
  186. request_body_oai = get_openapi_operation_request_body(
  187. body_field=route.body_field, model_name_map=model_name_map
  188. )
  189. if request_body_oai:
  190. operation["requestBody"] = request_body_oai
  191. if route.callbacks:
  192. callbacks = {}
  193. for callback in route.callbacks:
  194. if isinstance(callback, routing.APIRoute):
  195. (
  196. cb_path,
  197. cb_security_schemes,
  198. cb_definitions,
  199. ) = get_openapi_path(
  200. route=callback, model_name_map=model_name_map
  201. )
  202. callbacks[callback.name] = {callback.path: cb_path}
  203. operation["callbacks"] = callbacks
  204. if route.status_code is not None:
  205. status_code = str(route.status_code)
  206. else:
  207. # It would probably make more sense for all response classes to have an
  208. # explicit default status_code, and to extract it from them, instead of
  209. # doing this inspection tricks, that would probably be in the future
  210. # TODO: probably make status_code a default class attribute for all
  211. # responses in Starlette
  212. response_signature = inspect.signature(current_response_class.__init__)
  213. status_code_param = response_signature.parameters.get("status_code")
  214. if status_code_param is not None:
  215. if isinstance(status_code_param.default, int):
  216. status_code = str(status_code_param.default)
  217. operation.setdefault("responses", {}).setdefault(status_code, {})[
  218. "description"
  219. ] = route.response_description
  220. if (
  221. route_response_media_type
  222. and route.status_code not in STATUS_CODES_WITH_NO_BODY
  223. ):
  224. response_schema = {"type": "string"}
  225. if lenient_issubclass(current_response_class, JSONResponse):
  226. if route.response_field:
  227. response_schema, _, _ = field_schema(
  228. route.response_field,
  229. model_name_map=model_name_map,
  230. ref_prefix=REF_PREFIX,
  231. )
  232. else:
  233. response_schema = {}
  234. operation.setdefault("responses", {}).setdefault(
  235. status_code, {}
  236. ).setdefault("content", {}).setdefault(route_response_media_type, {})[
  237. "schema"
  238. ] = response_schema
  239. if route.responses:
  240. operation_responses = operation.setdefault("responses", {})
  241. for (
  242. additional_status_code,
  243. additional_response,
  244. ) in route.responses.items():
  245. process_response = additional_response.copy()
  246. process_response.pop("model", None)
  247. status_code_key = str(additional_status_code).upper()
  248. if status_code_key == "DEFAULT":
  249. status_code_key = "default"
  250. openapi_response = operation_responses.setdefault(
  251. status_code_key, {}
  252. )
  253. assert isinstance(
  254. process_response, dict
  255. ), "An additional response must be a dict"
  256. field = route.response_fields.get(additional_status_code)
  257. additional_field_schema: Optional[Dict[str, Any]] = None
  258. if field:
  259. additional_field_schema, _, _ = field_schema(
  260. field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
  261. )
  262. media_type = route_response_media_type or "application/json"
  263. additional_schema = (
  264. process_response.setdefault("content", {})
  265. .setdefault(media_type, {})
  266. .setdefault("schema", {})
  267. )
  268. deep_dict_update(additional_schema, additional_field_schema)
  269. status_text: Optional[str] = status_code_ranges.get(
  270. str(additional_status_code).upper()
  271. ) or http.client.responses.get(int(additional_status_code))
  272. description = (
  273. process_response.get("description")
  274. or openapi_response.get("description")
  275. or status_text
  276. or "Additional Response"
  277. )
  278. deep_dict_update(openapi_response, process_response)
  279. openapi_response["description"] = description
  280. http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
  281. if (all_route_params or route.body_field) and not any(
  282. [
  283. status in operation["responses"]
  284. for status in [http422, "4XX", "default"]
  285. ]
  286. ):
  287. operation["responses"][http422] = {
  288. "description": "Validation Error",
  289. "content": {
  290. "application/json": {
  291. "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
  292. }
  293. },
  294. }
  295. if "ValidationError" not in definitions:
  296. definitions.update(
  297. {
  298. "ValidationError": validation_error_definition,
  299. "HTTPValidationError": validation_error_response_definition,
  300. }
  301. )
  302. if route.openapi_extra:
  303. deep_dict_update(operation, route.openapi_extra)
  304. path[method.lower()] = operation
  305. return path, security_schemes, definitions
  306. def get_flat_models_from_routes(
  307. routes: Sequence[BaseRoute],
  308. ) -> Set[Union[Type[BaseModel], Type[Enum]]]:
  309. body_fields_from_routes: List[ModelField] = []
  310. responses_from_routes: List[ModelField] = []
  311. request_fields_from_routes: List[ModelField] = []
  312. callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
  313. for route in routes:
  314. if getattr(route, "include_in_schema", None) and isinstance(
  315. route, routing.APIRoute
  316. ):
  317. if route.body_field:
  318. assert isinstance(
  319. route.body_field, ModelField
  320. ), "A request body must be a Pydantic Field"
  321. body_fields_from_routes.append(route.body_field)
  322. if route.response_field:
  323. responses_from_routes.append(route.response_field)
  324. if route.response_fields:
  325. responses_from_routes.extend(route.response_fields.values())
  326. if route.callbacks:
  327. callback_flat_models |= get_flat_models_from_routes(route.callbacks)
  328. params = get_flat_params(route.dependant)
  329. request_fields_from_routes.extend(params)
  330. flat_models = callback_flat_models | get_flat_models_from_fields(
  331. body_fields_from_routes + responses_from_routes + request_fields_from_routes,
  332. known_models=set(),
  333. )
  334. return flat_models
  335. def get_openapi(
  336. *,
  337. title: str,
  338. version: str,
  339. openapi_version: str = "3.0.2",
  340. description: Optional[str] = None,
  341. routes: Sequence[BaseRoute],
  342. tags: Optional[List[Dict[str, Any]]] = None,
  343. servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
  344. terms_of_service: Optional[str] = None,
  345. contact: Optional[Dict[str, Union[str, Any]]] = None,
  346. license_info: Optional[Dict[str, Union[str, Any]]] = None,
  347. ) -> Dict[str, Any]:
  348. info: Dict[str, Any] = {"title": title, "version": version}
  349. if description:
  350. info["description"] = description
  351. if terms_of_service:
  352. info["termsOfService"] = terms_of_service
  353. if contact:
  354. info["contact"] = contact
  355. if license_info:
  356. info["license"] = license_info
  357. output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
  358. if servers:
  359. output["servers"] = servers
  360. components: Dict[str, Dict[str, Any]] = {}
  361. paths: Dict[str, Dict[str, Any]] = {}
  362. flat_models = get_flat_models_from_routes(routes)
  363. model_name_map = get_model_name_map(flat_models)
  364. definitions = get_model_definitions(
  365. flat_models=flat_models, model_name_map=model_name_map
  366. )
  367. for route in routes:
  368. if isinstance(route, routing.APIRoute):
  369. result = get_openapi_path(route=route, model_name_map=model_name_map)
  370. if result:
  371. path, security_schemes, path_definitions = result
  372. if path:
  373. paths.setdefault(route.path_format, {}).update(path)
  374. if security_schemes:
  375. components.setdefault("securitySchemes", {}).update(
  376. security_schemes
  377. )
  378. if path_definitions:
  379. definitions.update(path_definitions)
  380. if definitions:
  381. components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
  382. if components:
  383. output["components"] = components
  384. output["paths"] = paths
  385. if tags:
  386. output["tags"] = tags
  387. return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore