@mytec: pushed back before 1.1
This commit is contained in:
@@ -0,0 +1,364 @@
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from collections import abc
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Container,
|
||||
Iterable,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.errors import (
|
||||
CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
ExecutionTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WriteConcernError,
|
||||
WriteError,
|
||||
WTimeoutError,
|
||||
_wtimeout_error,
|
||||
)
|
||||
from pymongo.hello import HelloCompat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.cursor_shared import _Hint
|
||||
from pymongo.operations import _IndexList
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
|
||||
_SHUTDOWN_CODES: frozenset[int] = frozenset(
|
||||
[
|
||||
11600, # InterruptedAtShutdown
|
||||
91, # ShutdownInProgress
|
||||
]
|
||||
)
|
||||
# From the SDAM spec, the "not primary" error codes are combined with the
|
||||
# "node is recovering" error codes (of which the "node is shutting down"
|
||||
# errors are a subset).
|
||||
_NOT_PRIMARY_CODES: frozenset[int] = (
|
||||
frozenset(
|
||||
[
|
||||
10058, # LegacyNotPrimary <=3.2 "not primary" error code
|
||||
10107, # NotWritablePrimary
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13436, # NotPrimaryOrSecondary
|
||||
189, # PrimarySteppedDown
|
||||
]
|
||||
)
|
||||
| _SHUTDOWN_CODES
|
||||
)
|
||||
# From the retryable writes spec.
|
||||
_RETRYABLE_ERROR_CODES: frozenset[int] = _NOT_PRIMARY_CODES | frozenset(
|
||||
[
|
||||
7, # HostNotFound
|
||||
6, # HostUnreachable
|
||||
89, # NetworkTimeout
|
||||
9001, # SocketException
|
||||
262, # ExceededTimeLimit
|
||||
134, # ReadConcernMajorityNotAvailableYet
|
||||
]
|
||||
)
|
||||
|
||||
# Server code raised when re-authentication is required
|
||||
_REAUTHENTICATION_REQUIRED_CODE: int = 391
|
||||
|
||||
# Server code raised when authentication fails.
|
||||
_AUTHENTICATION_FAILURE_CODE: int = 18
|
||||
|
||||
# Note - to avoid bugs from forgetting which if these is all lowercase and
|
||||
# which are camelCase, and at the same time avoid having to add a test for
|
||||
# every command, use all lowercase here and test against command_name.lower().
|
||||
_SENSITIVE_COMMANDS: set[str] = {
|
||||
"authenticate",
|
||||
"saslstart",
|
||||
"saslcontinue",
|
||||
"getnonce",
|
||||
"createuser",
|
||||
"updateuser",
|
||||
"copydbgetnonce",
|
||||
"copydbsaslstart",
|
||||
"copydb",
|
||||
}
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
from pymongo import _csot
|
||||
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
socket_timeout = options.socket_timeout
|
||||
connect_timeout = options.connect_timeout
|
||||
if timeout:
|
||||
details["timeoutMS"] = timeout * 1000
|
||||
if socket_timeout and not timeout:
|
||||
details["socketTimeoutMS"] = socket_timeout * 1000
|
||||
if connect_timeout:
|
||||
details["connectTimeoutMS"] = connect_timeout * 1000
|
||||
return details
|
||||
|
||||
|
||||
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
|
||||
result = ""
|
||||
if details:
|
||||
result += " (configured timeouts:"
|
||||
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
|
||||
if timeout in details:
|
||||
result += f" {timeout}: {details[timeout]}ms,"
|
||||
result = result[:-1]
|
||||
result += ")"
|
||||
return result
|
||||
|
||||
|
||||
def _gen_index_name(keys: _IndexList) -> str:
|
||||
"""Generate an index name from the set of fields it is over."""
|
||||
return "_".join(["{}_{}".format(*item) for item in keys])
|
||||
|
||||
|
||||
def _index_list(
|
||||
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
|
||||
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
|
||||
"""Helper to generate a list of (key, direction) pairs.
|
||||
|
||||
Takes such a list, or a single key, or a single key and direction.
|
||||
"""
|
||||
if direction is not None:
|
||||
if not isinstance(key_or_list, str):
|
||||
raise TypeError(f"Expected a string and a direction, not {type(key_or_list)}")
|
||||
return [(key_or_list, direction)]
|
||||
else:
|
||||
if isinstance(key_or_list, str):
|
||||
return [(key_or_list, ASCENDING)]
|
||||
elif isinstance(key_or_list, abc.ItemsView):
|
||||
return list(key_or_list) # type: ignore[arg-type]
|
||||
elif isinstance(key_or_list, abc.Mapping):
|
||||
return list(key_or_list.items())
|
||||
elif not isinstance(key_or_list, (list, tuple)):
|
||||
raise TypeError(
|
||||
f"if no direction is specified, key_or_list must be an instance of list, not {type(key_or_list)}"
|
||||
)
|
||||
values: list[tuple[str, int]] = []
|
||||
for item in key_or_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
values.append(item)
|
||||
return values
|
||||
|
||||
|
||||
def _index_document(index_list: _IndexList) -> dict[str, Any]:
|
||||
"""Helper to generate an index specifying document.
|
||||
|
||||
Takes a list of (key, direction) pairs.
|
||||
"""
|
||||
if not isinstance(index_list, (list, tuple, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
|
||||
)
|
||||
if not len(index_list):
|
||||
raise ValueError("key_or_list must not be empty")
|
||||
|
||||
index: dict[str, Any] = {}
|
||||
|
||||
if isinstance(index_list, abc.Mapping):
|
||||
for key in index_list:
|
||||
value = index_list[key]
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
else:
|
||||
for item in index_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
key, value = item
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
return index
|
||||
|
||||
|
||||
def _validate_index_key_pair(key: Any, value: Any) -> None:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(f"first item in each key pair must be an instance of str, not {type(key)}")
|
||||
if not isinstance(value, (str, int, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"second item in each key pair must be 1, -1, "
|
||||
"'2d', or another valid MongoDB index specifier."
|
||||
f", not {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_command_response(
|
||||
response: _DocumentOut,
|
||||
max_wire_version: Optional[int],
|
||||
allowable_errors: Optional[Container[Union[int, str]]] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
pool_opts: Optional[PoolOptions] = None,
|
||||
) -> None:
|
||||
"""Check the response to a command for errors."""
|
||||
if "ok" not in response:
|
||||
# Server didn't recognize our message as a command.
|
||||
raise OperationFailure(
|
||||
response.get("$err"), # type: ignore[arg-type]
|
||||
response.get("code"),
|
||||
response,
|
||||
max_wire_version,
|
||||
)
|
||||
|
||||
if parse_write_concern_error and "writeConcernError" in response:
|
||||
_error = response["writeConcernError"]
|
||||
_labels = response.get("errorLabels")
|
||||
if _labels:
|
||||
_error.update({"errorLabels": _labels})
|
||||
_raise_write_concern_error(_error)
|
||||
|
||||
if response["ok"]:
|
||||
return
|
||||
|
||||
details = response
|
||||
# Mongos returns the error details in a 'raw' object
|
||||
# for some errors.
|
||||
if "raw" in response:
|
||||
for shard in response["raw"].values():
|
||||
# Grab the first non-empty raw error from a shard.
|
||||
if shard.get("errmsg") and not shard.get("ok"):
|
||||
details = shard
|
||||
break
|
||||
|
||||
errmsg = details["errmsg"]
|
||||
code = details.get("code")
|
||||
|
||||
# For allowable errors, only check for error messages when the code is not
|
||||
# included.
|
||||
if allowable_errors:
|
||||
if code is not None:
|
||||
if code in allowable_errors:
|
||||
return
|
||||
elif errmsg in allowable_errors:
|
||||
return
|
||||
|
||||
# Server is "not primary" or "recovering"
|
||||
if code is not None:
|
||||
if code in _NOT_PRIMARY_CODES:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
|
||||
# Other errors
|
||||
# findAndModify with upsert can raise duplicate key error
|
||||
if code in (11000, 11001, 12582):
|
||||
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
|
||||
elif code == 50:
|
||||
# Append timeout details to MaxTimeMSExpired responses.
|
||||
if pool_opts:
|
||||
timeout_details = _get_timeout_details(pool_opts)
|
||||
errmsg += format_timeout_details(timeout_details)
|
||||
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
|
||||
elif code == 43:
|
||||
raise CursorNotFound(errmsg, code, response, max_wire_version)
|
||||
|
||||
raise OperationFailure(errmsg, code, response, max_wire_version)
|
||||
|
||||
|
||||
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
|
||||
# If the last batch had multiple errors only report
|
||||
# the last error to emulate continue_on_error.
|
||||
error = write_errors[-1]
|
||||
if error.get("code") == 11000:
|
||||
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
|
||||
raise WriteError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _raise_write_concern_error(error: Any) -> NoReturn:
|
||||
if _wtimeout_error(error):
|
||||
# Make sure we raise WTimeoutError
|
||||
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
|
||||
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
"""Return the writeConcernError or None."""
|
||||
wce = result.get("writeConcernError")
|
||||
if wce:
|
||||
# The server reports errorLabels at the top level but it's more
|
||||
# convenient to attach it to the writeConcernError doc itself.
|
||||
error_labels = result.get("errorLabels")
|
||||
if error_labels:
|
||||
# Copy to avoid changing the original document.
|
||||
wce = wce.copy()
|
||||
wce["errorLabels"] = error_labels
|
||||
return wce
|
||||
|
||||
|
||||
def _check_write_command_response(result: Mapping[str, Any]) -> None:
|
||||
"""Backward compatibility helper for write command error handling."""
|
||||
# Prefer write errors over write concern errors
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
_raise_last_write_error(write_errors)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
_raise_write_concern_error(wce)
|
||||
|
||||
|
||||
def _fields_list_to_dict(
|
||||
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
|
||||
) -> Mapping[str, Any]:
|
||||
"""Takes a sequence of field names and returns a matching dictionary.
|
||||
|
||||
["a", "b"] becomes {"a": 1, "b": 1}
|
||||
|
||||
and
|
||||
|
||||
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
|
||||
"""
|
||||
if isinstance(fields, abc.Mapping):
|
||||
return fields
|
||||
|
||||
if isinstance(fields, (abc.Sequence, abc.Set)):
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
|
||||
return dict.fromkeys(fields, 1)
|
||||
|
||||
raise TypeError(f"{option_name} must be a mapping or list of key names")
|
||||
|
||||
|
||||
def _handle_exception() -> None:
|
||||
"""Print exceptions raised by subscribers to stderr."""
|
||||
# Heavily influenced by logging.Handler.handleError.
|
||||
|
||||
# See note here:
|
||||
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
|
||||
if sys.stderr:
|
||||
einfo = sys.exc_info()
|
||||
try:
|
||||
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
del einfo
|
||||
Reference in New Issue
Block a user