You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
4.0 KiB
109 lines
4.0 KiB
import typing |
|
|
|
from starlette._utils import is_async_callable |
|
from starlette.concurrency import run_in_threadpool |
|
from starlette.exceptions import HTTPException, WebSocketException |
|
from starlette.requests import Request |
|
from starlette.responses import PlainTextResponse, Response |
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send |
|
from starlette.websockets import WebSocket |
|
|
|
|
|
class ExceptionMiddleware:
    """ASGI middleware that routes raised exceptions to registered handlers.

    Handlers are keyed either by integer status code (consulted only for
    ``HTTPException``) or by exception class (matched along the raised
    exception type's MRO). ``HTTPException`` and ``WebSocketException`` are
    pre-registered with :meth:`http_exception` and
    :meth:`websocket_exception`; user-supplied handlers for the same keys
    replace those defaults.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Wrap ``app`` and register the default and user-supplied handlers.

        Args:
            app: The ASGI application to call.
            handlers: Optional mapping of status code or exception class to
                handler; each entry goes through
                :meth:`add_exception_handler`.
            debug: Stored on the instance; not otherwise used in this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        # Registered after the defaults so user handlers override them.
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        An ``int`` key is stored as a status-code handler; anything else must
        be a subclass of ``Exception`` and is stored as a class handler. An
        existing registration for the same key is replaced.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler for the nearest registered class in ``exc``'s MRO.

        Returns ``None`` when no class in the MRO has a registered handler.
        """
        # The MRO starts at the concrete type, so the most specific
        # registration wins over a handler registered for a base class.
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Call the wrapped app and dispatch any raised ``Exception``.

        Scopes whose type is neither ``"http"`` nor ``"websocket"`` are passed
        straight through to the wrapped app.

        Raises:
            Exception: The original exception, when no handler matches it.
            RuntimeError: When a handler matches but an
                ``http.response.start`` message has already been sent.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            # Record that the response has begun, then forward the message.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            # Status-code handlers take precedence over class handlers, but
            # only apply to HTTPException.
            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                raise exc

            # A handler exists, but a second response cannot be started once
            # the first one is under way.
            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            if scope["type"] == "http":
                request = Request(scope, receive=receive)
                # Non-async handlers are dispatched through run_in_threadpool.
                if is_async_callable(handler):
                    response = await handler(request, exc)
                else:
                    response = await run_in_threadpool(handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                websocket = WebSocket(scope, receive=receive, send=send)
                # The handler's return value is not used for websockets.
                if is_async_callable(handler):
                    await handler(websocket, exc)
                else:
                    await run_in_threadpool(handler, websocket, exc)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default ``HTTPException`` handler.

        For status codes 204 and 304 it returns a ``Response`` built without
        content; otherwise a ``PlainTextResponse`` carrying ``exc.detail``.
        The exception's status code and headers are passed on in both cases.
        """
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(
        self, websocket: WebSocket, exc: WebSocketException
    ) -> None:
        """Default ``WebSocketException`` handler: close with its code and reason."""
        await websocket.close(code=exc.code, reason=exc.reason)
|
|
|