You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
4.1 KiB
120 lines
4.1 KiB
import typing |
|
from os import PathLike |
|
|
|
from starlette.background import BackgroundTask |
|
from starlette.datastructures import URL |
|
from starlette.requests import Request |
|
from starlette.responses import Response |
|
from starlette.types import Receive, Scope, Send |
|
|
|
try:
    import jinja2

    # Jinja 3.0 renamed @contextfunction to @pass_context and 3.1 dropped the
    # old name entirely. Prefer the new decorator (what most installs have)
    # and only fall back to the legacy one on older releases; the type-ignore
    # lets mypy accept an attribute that may not exist in the installed stubs.
    pass_context = (
        jinja2.pass_context
        if hasattr(jinja2, "pass_context")
        else jinja2.contextfunction  # type: ignore[attr-defined]
    )
except ModuleNotFoundError:  # pragma: nocover
    # jinja2 is optional: leave a None placeholder so that code using it can
    # detect the missing dependency later instead of failing at import time.
    jinja2 = None  # type: ignore[assignment]
|
|
|
|
|
class _TemplateResponse(Response):
    """Response whose body is the result of rendering a template.

    The template object and the context it was rendered with are stored on
    the instance (``template`` / ``context``) so that they can be reported in
    an ``http.response.debug`` message when the response is sent.
    """

    media_type = "text/html"

    def __init__(
        self,
        template: typing.Any,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ):
        """Render ``template`` with ``context`` and initialise the response.

        ``template`` only needs to expose ``render(context)``; the remaining
        arguments are forwarded unchanged to the parent initialiser.
        """
        self.template = template
        self.context = context
        # Rendering happens eagerly, so the parent receives the finished body.
        rendered = template.render(context)
        super().__init__(rendered, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the response, preceded by a debug message when requested.

        If the object stored under ``"request"`` in the context advertises the
        ``http.response.debug`` extension, a message carrying the template and
        its context is sent before the regular response is delegated to the
        parent class.
        """
        advertised = self.context.get("request", {}).get("extensions", {})
        if "http.response.debug" in advertised:
            debug_info = {"template": self.template, "context": self.context}
            await send({"type": "http.response.debug", "info": debug_info})
        await super().__call__(scope, receive, send)
|
|
|
|
|
class Jinja2Templates:
    """Render Jinja2 templates from a directory into template responses.

    Usage:

        templates = Jinja2Templates("templates")

        return templates.TemplateResponse("index.html", {"request": request})
    """

    def __init__(
        self,
        directory: typing.Union[str, PathLike],
        context_processors: typing.Optional[
            typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
        ] = None,
        **env_options: typing.Any,
    ) -> None:
        """Build the Jinja2 environment and remember the context processors.

        ``directory`` is handed to the file-system loader, ``env_options`` are
        passed on to the environment, and ``context_processors`` (if any) are
        callables that receive the request and return extra context entries.
        """
        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
        self.env = self._create_env(directory, **env_options)
        self.context_processors = context_processors if context_processors else []

    def _create_env(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> "jinja2.Environment":
        """Create the environment and register the ``url_for`` template global.

        A file-system loader for ``directory`` and ``autoescape=True`` are
        applied as defaults only; values supplied in ``env_options`` win.
        """

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
            # Delegate to the request found in the active template context.
            return context["request"].url_for(name, **path_params)

        env_options.setdefault("loader", jinja2.FileSystemLoader(directory))
        env_options.setdefault("autoescape", True)

        environment = jinja2.Environment(**env_options)
        environment.globals["url_for"] = url_for
        return environment

    def get_template(self, name: str) -> "jinja2.Template":
        """Look up the template called ``name`` in the environment."""
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> _TemplateResponse:
        """Return a response that renders template ``name`` with ``context``.

        ``context`` must contain a ``"request"`` key. Each configured context
        processor is called with that request and its result is merged into
        ``context`` in place before the template is loaded.

        Raises:
            ValueError: if ``context`` has no ``"request"`` key.
        """
        if "request" not in context:
            raise ValueError('context must include a "request" key')

        request = typing.cast(Request, context["request"])
        for processor in self.context_processors:
            extra = processor(request)
            context.update(extra)

        # The template is loaded only after every processor has run.
        return _TemplateResponse(
            self.get_template(name),
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )
|
|
|