You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

136 lines
4.4 KiB

  1. import inspect
  2. import typing
  3. from starlette.requests import Request
  4. from starlette.responses import Response
  5. from starlette.routing import BaseRoute, Mount, Route
  6. try:
  7. import yaml
  8. except ImportError: # pragma: nocover
  9. yaml = None # type: ignore
  10. class OpenAPIResponse(Response):
  11. media_type = "application/vnd.oai.openapi"
  12. def render(self, content: typing.Any) -> bytes:
  13. assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
  14. assert isinstance(
  15. content, dict
  16. ), "The schema passed to OpenAPIResponse should be a dictionary."
  17. return yaml.dump(content, default_flow_style=False).encode("utf-8")
  18. class EndpointInfo(typing.NamedTuple):
  19. path: str
  20. http_method: str
  21. func: typing.Callable
  22. class BaseSchemaGenerator:
  23. def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
  24. raise NotImplementedError() # pragma: no cover
  25. def get_endpoints(
  26. self, routes: typing.List[BaseRoute]
  27. ) -> typing.List[EndpointInfo]:
  28. """
  29. Given the routes, yields the following information:
  30. - path
  31. eg: /users/
  32. - http_method
  33. one of 'get', 'post', 'put', 'patch', 'delete', 'options'
  34. - func
  35. method ready to extract the docstring
  36. """
  37. endpoints_info: list = []
  38. for route in routes:
  39. if isinstance(route, Mount):
  40. routes = route.routes or []
  41. sub_endpoints = [
  42. EndpointInfo(
  43. path="".join((route.path, sub_endpoint.path)),
  44. http_method=sub_endpoint.http_method,
  45. func=sub_endpoint.func,
  46. )
  47. for sub_endpoint in self.get_endpoints(routes)
  48. ]
  49. endpoints_info.extend(sub_endpoints)
  50. elif not isinstance(route, Route) or not route.include_in_schema:
  51. continue
  52. elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
  53. for method in route.methods or ["GET"]:
  54. if method == "HEAD":
  55. continue
  56. endpoints_info.append(
  57. EndpointInfo(route.path, method.lower(), route.endpoint)
  58. )
  59. else:
  60. for method in ["get", "post", "put", "patch", "delete", "options"]:
  61. if not hasattr(route.endpoint, method):
  62. continue
  63. func = getattr(route.endpoint, method)
  64. endpoints_info.append(
  65. EndpointInfo(route.path, method.lower(), func)
  66. )
  67. return endpoints_info
  68. def parse_docstring(self, func_or_method: typing.Callable) -> dict:
  69. """
  70. Given a function, parse the docstring as YAML and return a dictionary of info.
  71. """
  72. docstring = func_or_method.__doc__
  73. if not docstring:
  74. return {}
  75. assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
  76. # We support having regular docstrings before the schema
  77. # definition. Here we return just the schema part from
  78. # the docstring.
  79. docstring = docstring.split("---")[-1]
  80. parsed = yaml.safe_load(docstring)
  81. if not isinstance(parsed, dict):
  82. # A regular docstring (not yaml formatted) can return
  83. # a simple string here, which wouldn't follow the schema.
  84. return {}
  85. return parsed
  86. def OpenAPIResponse(self, request: Request) -> Response:
  87. routes = request.app.routes
  88. schema = self.get_schema(routes=routes)
  89. return OpenAPIResponse(schema)
  90. class SchemaGenerator(BaseSchemaGenerator):
  91. def __init__(self, base_schema: dict) -> None:
  92. self.base_schema = base_schema
  93. def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
  94. schema = dict(self.base_schema)
  95. schema.setdefault("paths", {})
  96. endpoints_info = self.get_endpoints(routes)
  97. for endpoint in endpoints_info:
  98. parsed = self.parse_docstring(endpoint.func)
  99. if not parsed:
  100. continue
  101. if endpoint.path not in schema["paths"]:
  102. schema["paths"][endpoint.path] = {}
  103. schema["paths"][endpoint.path][endpoint.http_method] = parsed
  104. return schema