You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

193 lines
7.2 KiB

import enum
import json
import typing
from starlette.requests import HTTPConnection
from starlette.types import Message, Receive, Scope, Send
class WebSocketState(enum.Enum):
    """Connection phases tracked separately for the client and the application."""

    # Values are fixed: CONNECTING=0, CONNECTED=1, DISCONNECTED=2.
    CONNECTING, CONNECTED, DISCONNECTED = 0, 1, 2
class WebSocketDisconnect(Exception):
    """Exception carrying a close code and a reason string.

    Attributes are set directly on the instance:
    ``code`` is stored as given, and ``reason`` is stored as ``""`` when
    the supplied value is ``None`` or empty.
    """

    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
        # Normalise a missing/empty reason so ``reason`` is always a string.
        self.code, self.reason = code, (reason if reason else "")
class WebSocket(HTTPConnection):
    """ASGI websocket connection with validated state transitions.

    Two independent states are tracked:

    * ``client_state`` advances as messages are *received* from the client
      (``websocket.connect`` -> ``websocket.receive`` ... ->
      ``websocket.disconnect``).
    * ``application_state`` advances as messages are *sent* by the
      application (``websocket.accept`` -> ``websocket.send`` ... ->
      ``websocket.close``).

    ``receive`` and ``send`` raise ``RuntimeError`` when a message arrives
    or is sent out of order for the current state.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Wrap the ASGI ``scope``/``receive``/``send`` triple.

        The scope's ``"type"`` must be ``"websocket"``.
        """
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the received message type is not valid for the
                current client state, or if a disconnect message has
                already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connection handshake.
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is not valid for the current
                application state, or if a close message has already been
                sent.
        """
        if self.application_state == WebSocketState.CONNECTING:
            # Before accepting, the application may only accept or reject.
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept" or '
                    f'"websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError(
                'Cannot call "send" once a close message has been sent.'
            )

    async def accept(
        self,
        subprotocol: typing.Optional[str] = None,
        headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
    ) -> None:
        """Accept the connection by sending a ``websocket.accept`` message.

        Args:
            subprotocol: value for the message's ``"subprotocol"`` field.
            headers: value for the message's ``"headers"`` field; ``None``
                (or any empty value) is sent as an empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect event.

        The close code is taken from the message's ``"code"`` field. The
        ``"reason"`` field is forwarded when present so that handlers
        catching the exception can inspect it; a missing reason becomes
        ``""`` on the exception.
        """
        if message["type"] == "websocket.disconnect":
            # "reason" may be absent from the message, hence ``.get``.
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """Receive one message and return its ``"text"`` field.

        Raises:
            RuntimeError: if the application state is not CONNECTED.
            WebSocketDisconnect: if the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        """Receive one message and return its ``"bytes"`` field.

        Raises:
            RuntimeError: if the application state is not CONNECTED.
            WebSocketDisconnect: if the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """Receive one message and decode its payload as JSON.

        Args:
            mode: ``"text"`` reads the message's ``"text"`` field;
                ``"binary"`` reads ``"bytes"`` and decodes it as UTF-8.

        Raises:
            RuntimeError: if ``mode`` is invalid or the application state
                is not CONNECTED.
            WebSocketDisconnect: if the received message is a disconnect.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text payloads until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary payloads until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield JSON-decoded payloads (text mode) until the client disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` in the ``"text"`` field of a ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` in the ``"bytes"`` field of a ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """Serialise ``data`` to compact JSON and send it.

        Args:
            data: value passed to ``json.dumps``.
            mode: ``"text"`` sends the JSON string in the ``"text"`` field;
                ``"binary"`` sends it UTF-8 encoded in the ``"bytes"`` field.

        Raises:
            RuntimeError: if ``mode`` is neither ``"text"`` nor ``"binary"``.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        # Compact separators: no whitespace after "," or ":".
        text = json.dumps(data, separators=(",", ":"))
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(
        self, code: int = 1000, reason: typing.Optional[str] = None
    ) -> None:
        """Send a ``websocket.close`` message with ``code`` and ``reason``.

        A ``None`` (or empty) ``reason`` is sent as ``""``.
        """
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )
class WebSocketClose:
    """ASGI callable that sends a single ``websocket.close`` message.

    The close ``code`` and ``reason`` are fixed at construction time; a
    ``None`` or empty reason is stored as ``""``.
    """

    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
        self.code = code
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # ``scope`` and ``receive`` are accepted for the ASGI signature but unused.
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)