You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
151 lines
4.4 KiB
151 lines
4.4 KiB
# Copyright 2022-present MongoDB, Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); you |
|
# may not use this file except in compliance with the License. You |
|
# may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
# implied. See the License for the specific language governing |
|
# permissions and limitations under the License. |
|
|
|
"""Internal helpers for CSOT.""" |
|
|
|
from __future__ import annotations |
|
|
|
import functools |
|
import time |
|
from collections import deque |
|
from contextvars import ContextVar, Token |
|
from typing import Any, Callable, Deque, MutableMapping, Optional, Tuple, TypeVar, cast |
|
|
|
from pymongo.write_concern import WriteConcern |
|
|
|
# Context-local CSOT state. Each variable declares a default, so reading it
# outside of any timeout block never raises LookupError.

# Timeout value for the current context; None means no timeout has been set.
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)

# RTT value recorded for the current context.
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)

# time.monotonic() reading at which the current timeout expires; infinity
# when there is no deadline.
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
|
|
|
|
|
def get_timeout() -> Optional[float]:
    """Return the timeout stored for the current context, or None if unset."""
    current_timeout = TIMEOUT.get(None)
    return current_timeout
|
|
|
|
|
def get_rtt() -> float:
    """Return the RTT value stored for the current context."""
    current_rtt = RTT.get()
    return current_rtt
|
|
|
|
|
def get_deadline() -> float:
    """Return the deadline stored for the current context."""
    current_deadline = DEADLINE.get()
    return current_deadline
|
|
|
|
|
def set_rtt(rtt: float) -> None:
    """Store ``rtt`` as the RTT value for the current context."""
    # The token returned by ContextVar.set() is intentionally discarded:
    # this setter never restores the previous value.
    RTT.set(rtt)
    return None
|
|
|
|
|
def remaining() -> Optional[float]:
    """Return the time left until the current deadline.

    Returns None when no timeout is in effect. A falsy timeout (None or 0)
    counts as "no timeout". The result is the deadline minus the current
    ``time.monotonic()`` reading, so it is negative once the deadline has
    already passed.
    """
    if get_timeout():
        return DEADLINE.get() - time.monotonic()
    return None
|
|
|
|
|
def clamp_remaining(max_timeout: float) -> float:
    """Return the remaining timeout, capped at ``max_timeout``.

    When ``remaining()`` reports no timeout (None), ``max_timeout`` itself is
    returned; otherwise the smaller of the two values is returned.
    """
    time_left = remaining()
    return max_timeout if time_left is None else min(time_left, max_timeout)
|
|
|
|
|
class _TimeoutContext:
    """Internal timeout context manager.

    Use :func:`pymongo.timeout` instead::

      with pymongo.timeout(0.5):
          client.test.test.insert_one({})

    Entering the block sets the TIMEOUT, DEADLINE and RTT context variables;
    leaving it resets them to the values they held before.
    """

    __slots__ = ("_timeout", "_tokens")

    def __init__(self, timeout: Optional[float]):
        self._timeout = timeout
        # Populated by __enter__ with the (TIMEOUT, DEADLINE, RTT) reset tokens.
        self._tokens: Optional[Tuple[Token, Token, Token]] = None

    def __enter__(self) -> _TimeoutContext:
        """Install this timeout and return the context manager itself."""
        tok_timeout = TIMEOUT.set(self._timeout)
        outer_deadline = DEADLINE.get()
        # A falsy timeout (None or 0) imposes no deadline of its own.
        if self._timeout:
            own_deadline = time.monotonic() + self._timeout
        else:
            own_deadline = float("inf")
        # A nested block can only shorten the enclosing deadline, never extend it.
        tok_deadline = DEADLINE.set(min(outer_deadline, own_deadline))
        tok_rtt = RTT.set(0.0)
        self._tokens = (tok_timeout, tok_deadline, tok_rtt)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Restore the context variables to their values from before __enter__."""
        if not self._tokens:
            return
        # Tokens were stored in (TIMEOUT, DEADLINE, RTT) order.
        for variable, token in zip((TIMEOUT, DEADLINE, RTT), self._tokens):
            variable.reset(token)
|
|
|
|
|
# ``F`` lets ``apply`` promise callers the same callable type they passed in.
# See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def apply(func: F) -> F:
    """Apply the client's timeoutMS to this operation.

    Decorator for methods whose instance exposes a ``_timeout`` attribute.
    If no timeout is already set in the current context and
    ``self._timeout`` is not None, the wrapped call runs inside a
    ``_TimeoutContext`` for that timeout. Otherwise the call runs unchanged,
    so a timeout that is already in effect takes precedence.

    :param func: The method to wrap; it is called as ``func(self, *args, **kwargs)``.
    :return: The wrapping function, typed as ``F`` for static checkers.
    """

    @functools.wraps(func)
    def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        # The wrapper returns whatever ``func`` returns, hence ``Any``
        # rather than ``F`` (which is the type of the callable itself).
        if get_timeout() is None:
            timeout = self._timeout
            if timeout is not None:
                with _TimeoutContext(timeout):
                    return func(self, *args, **kwargs)
        return func(self, *args, **kwargs)

    return cast(F, csot_wrapper)
|
|
|
|
|
def apply_write_concern(cmd: MutableMapping, write_concern: Optional[WriteConcern]) -> None:
    """Apply the given write concern to a command.

    Does nothing when ``write_concern`` is falsy or is the server default.
    Otherwise the write concern document is stored under ``"writeConcern"``
    in ``cmd``, unless the document ends up empty. While a timeout is set in
    the current context, the ``"wtimeout"`` field is removed first.
    """
    if write_concern and not write_concern.is_server_default:
        # NOTE(review): the mapping returned by ``document`` is mutated below;
        # presumably it is a copy owned by this call — confirm.
        concern_doc = write_concern.document
        if get_timeout() is not None:
            concern_doc.pop("wtimeout", None)
        if concern_doc:
            cmd["writeConcern"] = concern_doc
|
|
|
|
|
# Capacity of the sample window kept by MovingMinimum.
_MAX_RTT_SAMPLES: int = 10
# Fewest samples MovingMinimum needs before it reports a real minimum.
_MIN_RTT_SAMPLES: int = 2


class MovingMinimum:
    """Tracks a minimum RTT within the last 10 RTT samples."""

    # Bounded window of samples; once full, appending drops the oldest entry.
    samples: Deque[float]

    def __init__(self) -> None:
        self.samples = deque(maxlen=_MAX_RTT_SAMPLES)

    def add_sample(self, sample: float) -> None:
        """Record ``sample``, silently dropping negative values."""
        if sample < 0:
            # Likely system time change while waiting for hello response
            # and not using time.monotonic. Ignore it, the next one will
            # probably be valid.
            return
        self.samples.append(sample)

    def get(self) -> float:
        """Get the min, or 0.0 if there aren't enough samples yet."""
        if len(self.samples) < _MIN_RTT_SAMPLES:
            return 0.0
        return min(self.samples)

    def reset(self) -> None:
        """Discard every recorded sample."""
        self.samples.clear()
|
|
|