Non puoi selezionare più di 25 argomenti. Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 

136 righe
4.4 KiB

  1. import inspect
  2. import typing
  3. from starlette.requests import Request
  4. from starlette.responses import Response
  5. from starlette.routing import BaseRoute, Mount, Route
  6. try:
  7. import yaml
  8. except ImportError: # pragma: nocover
  9. yaml = None # type: ignore
  class OpenAPIResponse(Response):
      """Response whose body is an OpenAPI schema dictionary rendered as YAML.

      The media type is ``application/vnd.oai.openapi`` and the body is the
      block-style ``yaml.dump`` of the schema, encoded as UTF-8.
      """

      media_type = "application/vnd.oai.openapi"

      def render(self, content: typing.Any) -> bytes:
          """Serialize the schema dictionary *content* to UTF-8 encoded YAML.

          Raises ``AssertionError`` when ``pyyaml`` could not be imported or
          when *content* is not a dictionary.
          """
          assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
          schema_is_dict = isinstance(content, dict)
          assert (
              schema_is_dict
          ), "The schema passed to OpenAPIResponse should be a dictionary."
          # Block style (default_flow_style=False) keeps nested mappings on
          # separate lines instead of inline ``{...}`` flow syntax.
          document = yaml.dump(content, default_flow_style=False)
          return document.encode("utf-8")
  class EndpointInfo(typing.NamedTuple):
      """Immutable record describing one operation exposed by a route."""

      # URL path the operation is served under, e.g. "/users/".
      path: str
      # HTTP method name for the operation, e.g. "get".
      http_method: str
      # Callable whose docstring is read to document the operation.
      func: typing.Callable
  class BaseSchemaGenerator:
      """Base class for generators that build an API schema from routes.

      Subclasses implement ``get_schema``. This class supplies the shared
      pieces: endpoint discovery, docstring parsing, and a method that wraps
      the generated schema in an ``OpenAPIResponse``.
      """

      def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
          """Build the schema for *routes*; must be overridden by subclasses."""
          raise NotImplementedError()  # pragma: no cover

      def get_endpoints(
          self, routes: typing.List[BaseRoute]
      ) -> typing.List[EndpointInfo]:
          """
          Return one ``EndpointInfo`` per operation found in *routes*.

          - ``Mount`` routes are walked recursively and the mount path is
            prepended to every sub-endpoint path.
          - Anything that is not a ``Route``, or a ``Route`` with a falsy
            ``include_in_schema``, is skipped.
          - Function or method endpoints produce one entry per method in
            ``route.methods`` (``GET`` when that is empty), except ``HEAD``.
          - Any other endpoint object produces one entry for each of its
            ``get``/``post``/``put``/``patch``/``delete``/``options``
            attributes that exists.
          """
          collected: list = []
          handler_names = ["get", "post", "put", "patch", "delete", "options"]

          for route in routes:
              if isinstance(route, Mount):
                  # Recurse into the mounted routes, then re-root each result
                  # under the mount's own path.
                  mounted_routes = route.routes or []
                  for nested in self.get_endpoints(mounted_routes):
                      collected.append(
                          EndpointInfo(
                              path=route.path + nested.path,
                              http_method=nested.http_method,
                              func=nested.func,
                          )
                      )
                  continue

              if not (isinstance(route, Route) and route.include_in_schema):
                  continue

              handler = route.endpoint
              if inspect.isfunction(handler) or inspect.ismethod(handler):
                  # A plain callable serves every declared method itself.
                  for verb in route.methods or ["GET"]:
                      if verb != "HEAD":
                          collected.append(
                              EndpointInfo(route.path, verb.lower(), handler)
                          )
                  continue

              # Otherwise look for one handler attribute per HTTP method.
              for verb in handler_names:
                  if hasattr(handler, verb):
                      collected.append(
                          EndpointInfo(
                              route.path, verb.lower(), getattr(handler, verb)
                          )
                      )

          return collected

      def parse_docstring(self, func_or_method: typing.Callable) -> dict:
          """
          Parse the YAML part of a callable's docstring into a dictionary.

          Only the text after the last ``---`` separator is parsed, so an
          ordinary description may precede the schema. Returns ``{}`` when
          there is no docstring or when the parsed YAML is not a mapping.
          Raises ``AssertionError`` when ``pyyaml`` could not be imported.
          """
          raw_doc = func_or_method.__doc__
          if not raw_doc:
              return {}

          assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."

          # Keep only the trailing schema section of the docstring.
          *_, schema_text = raw_doc.split("---")
          loaded = yaml.safe_load(schema_text)

          # Free-form docstrings load as a plain string (or other scalar),
          # which cannot be used as a schema fragment.
          return loaded if isinstance(loaded, dict) else {}

      def OpenAPIResponse(self, request: Request) -> Response:
          """Return the schema of the requesting app's routes as a response."""
          app_routes = request.app.routes
          return OpenAPIResponse(self.get_schema(routes=app_routes))
  class SchemaGenerator(BaseSchemaGenerator):
      """Schema generator that merges endpoint docstrings into a base schema.

      The base schema supplies the static top-level content; ``get_schema``
      adds one entry under ``"paths"`` for every endpoint whose docstring
      parses to a non-empty dictionary.
      """

      def __init__(self, base_schema: dict) -> None:
          """Store *base_schema*, the template each generated schema starts from."""
          self.base_schema = base_schema

      def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
          """
          Return a new schema dictionary describing *routes*.

          The result starts as a copy of ``self.base_schema`` and gains
          ``schema["paths"][path][http_method] = parsed_docstring`` for each
          endpoint with a usable docstring. ``self.base_schema`` itself is
          never modified, so repeated calls do not accumulate entries.
          """
          schema = dict(self.base_schema)

          # ``dict()`` above is only a shallow copy. Copy the "paths" mapping
          # too, otherwise operations written below would land in the nested
          # dictionaries owned by ``self.base_schema``.
          paths = dict(schema.get("paths", {}))
          schema["paths"] = paths

          # Paths whose item dictionary is already private to this call.
          owned_paths = set()

          for endpoint in self.get_endpoints(routes):
              parsed = self.parse_docstring(endpoint.func)
              if not parsed:
                  continue

              if endpoint.path not in owned_paths:
                  # Copy (or create) the path item before the first write so
                  # a pre-existing base-schema entry is left untouched.
                  paths[endpoint.path] = dict(paths.get(endpoint.path, {}))
                  owned_paths.add(endpoint.path)

              paths[endpoint.path][endpoint.http_method] = parsed

          return schema