|
- import inspect
- import typing
-
- from starlette.requests import Request
- from starlette.responses import Response
- from starlette.routing import BaseRoute, Mount, Route
-
- try:
- import yaml
- except ImportError: # pragma: nocover
- yaml = None # type: ignore
-
-
class OpenAPIResponse(Response):
    """
    Response that serialises a schema dictionary as YAML.

    The body is produced by `yaml.dump` in block style and encoded as UTF-8.
    """

    media_type = "application/vnd.oai.openapi"

    def render(self, content: typing.Any) -> bytes:
        """
        Serialise `content` (which must be a dict) to UTF-8 encoded YAML bytes.

        Raises `AssertionError` if `pyyaml` is unavailable or if `content`
        is not a dictionary.
        """
        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
        assert isinstance(
            content, dict
        ), "The schema passed to OpenAPIResponse should be a dictionary."
        # default_flow_style=False asks for block-style YAML output.
        yaml_text = yaml.dump(content, default_flow_style=False)
        return yaml_text.encode("utf-8")
-
-
class EndpointInfo(typing.NamedTuple):
    """Immutable record describing one (path, HTTP method, callable) endpoint."""

    # URL path of the endpoint, e.g. "/users/".
    path: str
    # Lower-case HTTP method name, e.g. "get".
    http_method: str
    # Callable whose docstring is inspected for the schema definition.
    func: typing.Callable
-
-
class BaseSchemaGenerator:
    """
    Base class for schema generators.

    Subclasses implement `get_schema`; this class supplies the helpers for
    walking routes and for reading YAML out of endpoint docstrings.
    """

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """Return the schema dictionary for `routes`. Must be overridden."""
        raise NotImplementedError()  # pragma: no cover

    def get_endpoints(
        self, routes: typing.List[BaseRoute]
    ) -> typing.List[EndpointInfo]:
        """
        Given the routes, return a list of `EndpointInfo` entries, each with:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
        - func
            method ready to extract the docstring

        `Mount` routes are walked recursively and the mount path is prefixed
        to every nested endpoint path. Anything that is not a `Route`, or a
        `Route` with a falsy `include_in_schema`, is skipped.
        """
        collected: list = []

        for route in routes:
            if isinstance(route, Mount):
                # Recurse into the mounted routes, prefixing the mount path.
                nested_routes = route.routes or []
                for nested in self.get_endpoints(nested_routes):
                    collected.append(
                        EndpointInfo(
                            path=route.path + nested.path,
                            http_method=nested.http_method,
                            func=nested.func,
                        )
                    )
                continue

            if not (isinstance(route, Route) and route.include_in_schema):
                continue

            endpoint = route.endpoint
            if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
                # Plain function/method endpoint: one entry per declared
                # method, falling back to GET; HEAD is never reported.
                for verb in route.methods or ["GET"]:
                    if verb != "HEAD":
                        collected.append(
                            EndpointInfo(route.path, verb.lower(), endpoint)
                        )
                continue

            # Any other endpoint object: report each handler attribute that
            # is named after an HTTP method and present on the endpoint.
            for verb in ("get", "post", "put", "patch", "delete", "options"):
                if hasattr(endpoint, verb):
                    handler = getattr(endpoint, verb)
                    collected.append(EndpointInfo(route.path, verb.lower(), handler))

        return collected

    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
        """
        Parse the docstring of `func_or_method` as YAML and return it as a dict.

        Returns an empty dict when there is no docstring or when the parsed
        YAML is not a mapping.
        """
        doc_text = func_or_method.__doc__
        if not doc_text:
            return {}

        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."

        # A regular description may precede the schema definition; only the
        # text after the last "---" separator is treated as the schema.
        schema_text = doc_text.split("---")[-1]
        loaded = yaml.safe_load(schema_text)

        # A plain (non-YAML-mapping) docstring can load as a simple string,
        # which wouldn't follow the schema, so it is discarded.
        return loaded if isinstance(loaded, dict) else {}

    def OpenAPIResponse(self, request: Request) -> Response:
        """
        Build the schema from `request.app.routes` and return it wrapped in
        the module-level `OpenAPIResponse` class.
        """
        app_routes = request.app.routes
        return OpenAPIResponse(self.get_schema(routes=app_routes))
-
-
class SchemaGenerator(BaseSchemaGenerator):
    """
    Schema generator that merges endpoint docstring definitions into a
    user-supplied base schema under its "paths" key.
    """

    def __init__(self, base_schema: dict) -> None:
        """Store `base_schema`, the dictionary every generated schema starts from."""
        self.base_schema = base_schema

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """
        Return a new schema dict: the base schema plus one entry under
        `schema["paths"][path][http_method]` for every endpoint in `routes`
        whose docstring parses to a non-empty dictionary.

        `self.base_schema` is not modified: the "paths" mapping and each of
        its per-path dictionaries are copied before generated entries are
        written, so results do not leak between calls. Other top-level
        values are still shared with the base schema (shallow copy).
        """
        schema = dict(self.base_schema)

        # dict() above is only a shallow copy, so writing into the existing
        # "paths" value would mutate the caller's base schema. Copy it two
        # levels deep, which is exactly the depth written to below.
        # A missing or None "paths" value is treated as empty.
        base_paths = schema.get("paths") or {}
        paths = {
            path: (dict(operations) if isinstance(operations, dict) else operations)
            for path, operations in base_paths.items()
        }
        schema["paths"] = paths

        for endpoint in self.get_endpoints(routes):
            parsed = self.parse_docstring(endpoint.func)

            # Endpoints without a usable schema docstring are left out.
            if not parsed:
                continue

            paths.setdefault(endpoint.path, {})[endpoint.http_method] = parsed

        return schema
|