You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
338 lines
11 KiB
338 lines
11 KiB
# Copyright 2009-present MongoDB, Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
"""Bits and pieces used by the driver that don't really fit elsewhere.""" |
|
from __future__ import annotations |
|
|
|
import sys |
|
import traceback |
|
from collections import abc |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Callable, |
|
Container, |
|
Iterable, |
|
List, |
|
Mapping, |
|
NoReturn, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
TypeVar, |
|
Union, |
|
cast, |
|
) |
|
|
|
from bson.son import SON |
|
from pymongo import ASCENDING |
|
from pymongo.errors import ( |
|
CursorNotFound, |
|
DuplicateKeyError, |
|
ExecutionTimeout, |
|
NotPrimaryError, |
|
OperationFailure, |
|
WriteConcernError, |
|
WriteError, |
|
WTimeoutError, |
|
_wtimeout_error, |
|
) |
|
from pymongo.hello import HelloCompat |
|
|
|
if TYPE_CHECKING: |
|
from pymongo.cursor import _Hint |
|
from pymongo.operations import _IndexList |
|
from pymongo.typings import _DocumentOut |
|
|
|
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES: frozenset = frozenset(
    {
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    }
)

# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = _SHUTDOWN_CODES.union(
    {
        10058,  # LegacyNotPrimary <=3.2 "not primary" error code
        10107,  # NotWritablePrimary
        13435,  # NotPrimaryNoSecondaryOk
        11602,  # InterruptedDueToReplStateChange
        13436,  # NotPrimaryOrSecondary
        189,  # PrimarySteppedDown
    }
)

# From the retryable writes spec: every "not primary" code plus the
# network-level and time-limit codes listed below.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES.union(
    {
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
    }
)

# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391
|
|
|
|
|
def _gen_index_name(keys: _IndexList) -> str:
    """Build an index name from the entries of ``keys``.

    Each entry is rendered as ``<first>_<second>`` and the rendered pieces
    are joined with underscores, e.g. ``[("a", 1), ("b", -1)]`` becomes
    ``"a_1_b_-1"``.
    """
    rendered = ("{}_{}".format(*entry) for entry in keys)
    return "_".join(rendered)
|
|
|
|
|
def _index_list(
    key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]:
    """Normalize an index/sort specification to a list of (key, direction) pairs.

    Accepted inputs:

    - a single key together with ``direction``;
    - a single key alone, which is paired with ``ASCENDING``;
    - an items view, which is copied into a list;
    - a list or tuple whose entries are pairs or bare string keys (bare
      strings are paired with ``ASCENDING``).

    Raises ``TypeError`` when ``direction`` is given with a non-string key,
    or when no direction is given and ``key_or_list`` is none of the above.
    """
    # An explicit direction only makes sense with a single string key.
    if direction is not None:
        if isinstance(key_or_list, str):
            return [(key_or_list, direction)]
        raise TypeError("Expected a string and a direction")

    if isinstance(key_or_list, str):
        return [(key_or_list, ASCENDING)]
    if isinstance(key_or_list, abc.ItemsView):
        return list(key_or_list)
    if not isinstance(key_or_list, (list, tuple)):
        raise TypeError("if no direction is specified, key_or_list must be an instance of list")

    # Non-string entries are passed through untouched; they are not validated here.
    return [(entry, ASCENDING) if isinstance(entry, str) else entry for entry in key_or_list]
|
|
|
|
|
def _index_document(index_list: _IndexList) -> SON[str, Any]:
    """Convert a list of (key, direction) pairs into an index-specifying SON.

    Bare string entries are paired with ``ASCENDING``. Keys are inserted in
    the order they appear in ``index_list``.

    Raises ``TypeError`` if ``index_list`` is a mapping or is not a list or
    tuple, if a key is not a ``str``, or if a direction is not a ``str``,
    ``int`` or mapping. Raises ``ValueError`` if ``index_list`` is empty.
    """
    if isinstance(index_list, abc.Mapping):
        # Show the caller the list-of-tuples form of what they passed.
        suggestion = list(index_list.items())
        raise TypeError(
            "passing a dict to sort/create_index/hint is not "
            "allowed - use a list of tuples instead. did you "
            "mean %r?" % suggestion
        )
    if not isinstance(index_list, (list, tuple)):
        raise TypeError("must use a list of (key, direction) pairs, not: " + repr(index_list))
    if len(index_list) == 0:
        raise ValueError("key_or_list must not be the empty list")

    document: SON[str, Any] = SON()
    for entry in index_list:
        key, value = (entry, ASCENDING) if isinstance(entry, str) else entry
        if not isinstance(key, str):
            raise TypeError("first item in each key pair must be an instance of str")
        if not isinstance(value, (str, int, abc.Mapping)):
            raise TypeError(
                "second item in each key pair must be 1, -1, "
                "'2d', or another valid MongoDB index specifier."
            )
        document[key] = value
    return document
|
|
|
|
|
def _check_command_response(
    response: _DocumentOut,
    max_wire_version: Optional[int],
    allowable_errors: Optional[Container[Union[int, str]]] = None,
    parse_write_concern_error: bool = False,
) -> None:
    """Check the response to a command for errors.

    Returns ``None`` when ``response["ok"]`` is truthy, or when the failure's
    code (or, if it has no code, its message) is in ``allowable_errors``.
    Otherwise raises:

    - ``OperationFailure`` if the response has no ``"ok"`` field at all;
    - whatever ``_raise_write_concern_error`` raises, when
      ``parse_write_concern_error`` is set and the response carries a
      ``"writeConcernError"`` (checked before ``"ok"`` is consulted);
    - ``NotPrimaryError`` for "not primary" / "node is recovering" failures;
    - ``DuplicateKeyError``, ``ExecutionTimeout`` or ``CursorNotFound`` for
      their specific codes;
    - ``OperationFailure`` for everything else.
    """
    if "ok" not in response:
        # Server didn't recognize our message as a command.
        raise OperationFailure(
            response.get("$err"),  # type: ignore[arg-type]
            response.get("code"),
            response,
            max_wire_version,
        )

    if parse_write_concern_error and "writeConcernError" in response:
        wc_error = response["writeConcernError"]
        labels = response.get("errorLabels")
        if labels:
            # Copy the top-level labels onto the error document (in place).
            wc_error.update({"errorLabels": labels})
        _raise_write_concern_error(wc_error)

    if response["ok"]:
        return

    # Mongos returns the error details in a 'raw' object for some errors.
    # Use the first shard entry that has an error message and is not ok,
    # falling back to the top-level response.
    details = response
    if "raw" in response:
        details = next(
            (
                shard
                for shard in response["raw"].values()
                if shard.get("errmsg") and not shard.get("ok")
            ),
            response,
        )

    errmsg = details["errmsg"]
    code = details.get("code")

    # For allowable errors, only check for error messages when the code is
    # not included.
    if allowable_errors:
        candidate = errmsg if code is None else code
        if candidate in allowable_errors:
            return

    # Server is "not primary" or "recovering". The message is only inspected
    # when the server did not supply a code.
    if code is None:
        not_primary = HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg
    else:
        not_primary = code in _NOT_PRIMARY_CODES
    if not_primary:
        raise NotPrimaryError(errmsg, response)

    # Other errors
    if code in (11000, 11001, 12582):
        # findAndModify with upsert can raise duplicate key error
        error_cls = DuplicateKeyError
    elif code == 50:
        error_cls = ExecutionTimeout
    elif code == 43:
        error_cls = CursorNotFound
    else:
        error_cls = OperationFailure
    raise error_cls(errmsg, code, response, max_wire_version)
|
|
|
|
|
def _raise_last_write_error(write_errors: List[Any]) -> NoReturn:
    """Raise an exception for the final entry of ``write_errors``.

    Raises ``DuplicateKeyError`` when that entry's code is 11000 and
    ``WriteError`` otherwise. An empty list raises ``IndexError``.
    """
    # If the last batch had multiple errors only report
    # the last error to emulate continue_on_error.
    failure = write_errors[-1]
    code = failure.get("code")
    message = failure.get("errmsg")
    if code != 11000:
        raise WriteError(message, code, failure)
    raise DuplicateKeyError(message, 11000, failure)
|
|
|
|
|
def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise the exception matching a write concern error document.

    Raises ``WTimeoutError`` when ``_wtimeout_error(error)`` is truthy and
    ``WriteConcernError`` otherwise. Either exception is built from the
    document's ``"errmsg"`` and ``"code"`` entries and the document itself.
    """
    # Make sure we raise WTimeoutError
    error_cls = WTimeoutError if _wtimeout_error(error) else WriteConcernError
    raise error_cls(error.get("errmsg"), error.get("code"), error)
|
|
|
|
|
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
    """Return the writeConcernError or None.

    When both a write concern error and top-level ``"errorLabels"`` are
    present, the labels are written into the error document in place before
    it is returned.
    """
    wc_error = result.get("writeConcernError")
    if not wc_error:
        # Nothing to annotate; hand back whatever was stored (possibly None).
        return wc_error

    # The server reports errorLabels at the top level but it's more
    # convenient to attach it to the writeConcernError doc itself.
    labels = result.get("errorLabels")
    if labels:
        wc_error["errorLabels"] = labels
    return wc_error
|
|
|
|
|
def _check_write_command_response(result: Mapping[str, Any]) -> None:
    """Backward compatibility helper for write command error handling.

    Raises for the last entry of ``result["writeErrors"]`` when there are
    any; otherwise raises for the write concern error when one is present.
    Returns ``None`` when the result carries neither.
    """
    # Prefer write errors over write concern errors
    failed_writes = result.get("writeErrors")
    if failed_writes:
        _raise_last_write_error(failed_writes)

    concern_error = _get_wce_doc(result)
    if concern_error:
        _raise_write_concern_error(concern_error)
|
|
|
|
|
def _fields_list_to_dict(
    fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
    """Turn a collection of field names into a ``{name: 1}`` dictionary.

    ["a", "b"] becomes {"a": 1, "b": 1}

    and

    ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}

    A mapping is returned unchanged (the same object, not a copy).

    Raises ``TypeError``, naming ``option_name`` in the message, when
    ``fields`` is neither a mapping, a sequence nor a set, or when any of
    its entries is not a ``str``.
    """
    if isinstance(fields, abc.Mapping):
        return fields

    if not isinstance(fields, (abc.Sequence, abc.Set)):
        raise TypeError(f"{option_name} must be a mapping or list of key names")

    for name in fields:
        if not isinstance(name, str):
            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
    return dict.fromkeys(fields, 1)
|
|
|
|
|
def _handle_exception() -> None:
    """Print exceptions raised by subscribers to stderr.

    Writes the traceback of the exception currently being handled to
    ``sys.stderr``. Does nothing when ``sys.stderr`` is falsy, and ignores
    an ``OSError`` raised while printing.
    """
    # Heavily influenced by logging.Handler.handleError.

    # See note here:
    # https://docs.python.org/3.4/library/sys.html#sys.__stderr__
    if not sys.stderr:
        return

    exc_type, exc_value, exc_tb = sys.exc_info()
    try:
        traceback.print_exception(exc_type, exc_value, exc_tb, None, sys.stderr)
    except OSError:
        pass
    finally:
        # Drop the local references to the exception and its traceback.
        del exc_type, exc_value, exc_tb
|
|
|
|
|
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def _handle_reauth(func: F) -> F:
    """Decorate ``func`` so it is retried once after reauthenticating.

    The wrapper accepts an extra ``no_reauth`` keyword argument, which is
    removed from ``kwargs`` and never forwarded to ``func``. When ``func``
    raises ``OperationFailure``:

    - if ``no_reauth`` is truthy, or the failure's code is not
      ``_REAUTHENTICATION_REQUIRED_CODE``, the exception propagates;
    - otherwise the first positional argument that is a ``Connection``, or a
      ``_BulkWriteContext`` (whose ``conn`` attribute is used), supplies the
      connection to reauthenticate. If none is found the exception
      propagates; if one is found, ``conn.authenticate(reauthenticate=True)``
      is called and ``func`` is invoked a second time with the same
      arguments.

    The wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(func)
    def inner(*args: Any, **kwargs: Any) -> Any:
        no_reauth = kwargs.pop("no_reauth", False)
        # Imported at call time rather than at module import; presumably to
        # avoid a circular import between pymongo modules (not verifiable
        # from this file alone).
        from pymongo.message import _BulkWriteContext
        from pymongo.pool import Connection

        try:
            return func(*args, **kwargs)
        except OperationFailure as exc:
            if no_reauth or exc.code != _REAUTHENTICATION_REQUIRED_CODE:
                raise

            # Look for an argument that either is a Connection
            # or has a connection attribute, so we can trigger
            # a reauth.
            conn = None
            for arg in args:
                if isinstance(arg, Connection):
                    conn = arg
                    break
                if isinstance(arg, _BulkWriteContext):
                    conn = arg.conn
                    break
            if not conn:
                # Nothing to reauthenticate with; surface the original error.
                raise

            conn.authenticate(reauthenticate=True)
            return func(*args, **kwargs)

    return cast(F, inner)
|
|
|