You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
145 lines
4.8 KiB
145 lines
4.8 KiB
import inspect |
|
import re |
|
import typing |
|
|
|
from starlette.requests import Request |
|
from starlette.responses import Response |
|
from starlette.routing import BaseRoute, Mount, Route |
|
|
|
try: |
|
import yaml |
|
except ModuleNotFoundError: # pragma: nocover |
|
yaml = None # type: ignore[assignment] |
|
|
|
|
|
class OpenAPIResponse(Response):
    """Response whose body is an OpenAPI schema dictionary rendered as YAML."""

    media_type = "application/vnd.oai.openapi"

    def render(self, content: typing.Any) -> bytes:
        """
        Serialise the schema dictionary ``content`` to UTF-8 encoded YAML.

        Raises ``AssertionError`` when ``pyyaml`` is not installed or when
        ``content`` is not a dictionary.
        """
        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
        assert isinstance(content, dict), (
            "The schema passed to OpenAPIResponse should be a dictionary."
        )
        # Block style (one key per line) rather than inline flow style.
        document = yaml.dump(content, default_flow_style=False)
        return document.encode("utf-8")
|
|
|
|
|
class EndpointInfo(typing.NamedTuple):
    """One (path, HTTP method, callable) entry collected from the routes."""

    # Route path with converter suffixes removed, e.g. "/users/{id}".
    path: str
    # Lower-case HTTP method name, e.g. "get".
    http_method: str
    # Callable whose docstring is parsed for the schema fragment.
    func: typing.Callable
|
|
|
|
|
class BaseSchemaGenerator:
    """
    Shared machinery for schema generators: collects endpoints from a list
    of routes and parses their docstrings. Subclasses provide `get_schema`.
    """

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """Build the schema for the given routes. Must be overridden."""
        raise NotImplementedError()  # pragma: no cover

    def get_endpoints(
        self, routes: typing.List[BaseRoute]
    ) -> typing.List[EndpointInfo]:
        """
        Given the routes, returns a list of `EndpointInfo` entries holding:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
        - func
            method ready to extract the docstring
        """
        collected: typing.List[EndpointInfo] = []

        for route in routes:
            if isinstance(route, Mount):
                # Recurse into the mounted routes and prefix every nested
                # path with the mount's own (converter-free) path.
                prefix = self._remove_converter(route.path)
                for nested in self.get_endpoints(route.routes or []):
                    collected.append(
                        EndpointInfo(
                            path=prefix + nested.path,
                            http_method=nested.http_method,
                            func=nested.func,
                        )
                    )
                continue

            # Only plain routes that opted into the schema are described.
            if not (isinstance(route, Route) and route.include_in_schema):
                continue

            handler = route.endpoint
            path = self._remove_converter(route.path)

            if inspect.isfunction(handler) or inspect.ismethod(handler):
                # Function endpoint: one entry per declared method, skipping
                # HEAD; GET is assumed when no methods are declared.
                collected.extend(
                    EndpointInfo(path, verb.lower(), handler)
                    for verb in (route.methods or ["GET"])
                    if verb != "HEAD"
                )
            else:
                # Any other endpoint object: look for attributes named after
                # the HTTP methods and use each one found as the handler.
                for verb in ("get", "post", "put", "patch", "delete", "options"):
                    if hasattr(handler, verb):
                        collected.append(
                            EndpointInfo(path, verb, getattr(handler, verb))
                        )

        return collected

    def _remove_converter(self, path: str) -> str:
        """
        Strip converter suffixes from path parameters.

        A route declared as
            Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
        is represented as `/users/{id}` in the OpenAPI schema.
        """
        converter_suffix = r":\w+}"
        return re.sub(converter_suffix, "}", path)

    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
        """
        Parse the docstring of the given callable as YAML and return the
        resulting dictionary, or an empty dictionary when there is no
        docstring or it does not parse to a mapping.
        """
        raw = func_or_method.__doc__
        if not raw:
            return {}

        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."

        # A regular docstring may precede the schema definition; only the
        # text after the last "---" separator is treated as the schema.
        sections = raw.split("---")
        loaded = yaml.safe_load(sections[-1])

        # A plain (non-YAML) docstring can load as a simple string, which
        # would not follow the schema, so anything but a dict is discarded.
        return loaded if isinstance(loaded, dict) else {}

    def OpenAPIResponse(self, request: Request) -> Response:
        """Generate the schema for the requesting app's routes and wrap it
        in an `OpenAPIResponse`."""
        schema = self.get_schema(routes=request.app.routes)
        # Inside a method body this name resolves to the module-level
        # response class, not to this method.
        return OpenAPIResponse(schema)
|
|
|
|
|
class SchemaGenerator(BaseSchemaGenerator):
    """
    Schema generator that merges docstring-derived operations into a
    user-supplied base schema.
    """

    def __init__(self, base_schema: dict) -> None:
        """Store `base_schema`, the template every generated schema starts from."""
        self.base_schema = base_schema

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """
        Return a new schema dictionary: the base schema plus one
        `paths[path][http_method]` entry for every endpoint whose docstring
        parses to a non-empty dictionary.

        `self.base_schema` is left untouched. A base schema whose "paths"
        is missing or `None` is treated as having no predefined paths.
        """
        schema = dict(self.base_schema)

        # `dict(...)` above is only a shallow copy, so writing into
        # schema["paths"] directly would modify the caller's base schema
        # and leak entries between calls. Copy the paths mapping and each
        # path item before merging generated operations into them.
        paths: dict = {}
        for path, operations in (schema.get("paths") or {}).items():
            if isinstance(operations, dict):
                paths[path] = dict(operations)
            else:
                # Not a mapping; pass it through unchanged.
                paths[path] = operations
        schema["paths"] = paths

        for endpoint in self.get_endpoints(routes):
            parsed = self.parse_docstring(endpoint.func)

            # Endpoints without a usable schema docstring are omitted.
            if not parsed:
                continue

            paths.setdefault(endpoint.path, {})[endpoint.http_method] = parsed

        return schema
|
|
|