@mytec: pushed back before 1.1
@@ -0,0 +1,254 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Perform aggregation operations on a collection or database."""
from __future__ import annotations

from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref

if TYPE_CHECKING:
    from pymongo.read_preferences import _ServerMode
    from pymongo.synchronous.client_session import ClientSession
    from pymongo.synchronous.collection import Collection
    from pymongo.synchronous.command_cursor import CommandCursor
    from pymongo.synchronous.database import Database
    from pymongo.synchronous.pool import Connection
    from pymongo.synchronous.server import Server
    from pymongo.typings import _DocumentType, _Pipeline

_IS_SYNC = True


class _AggregationCommand:
    """The internal abstract base class for aggregation cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.collection.Collection.aggregate`, or
    :meth:`pymongo.database.Database.aggregate` instead.
    """

    def __init__(
        self,
        target: Union[Database[Any], Collection[Any]],
        cursor_class: type[CommandCursor[Any]],
        pipeline: _Pipeline,
        options: MutableMapping[str, Any],
        let: Optional[Mapping[str, Any]] = None,
        user_fields: Optional[MutableMapping[str, Any]] = None,
        result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
        comment: Any = None,
    ) -> None:
        if "explain" in options:
            raise ConfigurationError(
                "The explain option is not supported. Use Database.command instead."
            )

        self._target = target

        pipeline = common.validate_list("pipeline", pipeline)
        self._pipeline = pipeline
        self._performs_write = False
        if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
            self._performs_write = True

        common.validate_is_mapping("options", options)
        if let is not None:
            common.validate_is_mapping("let", let)
            options["let"] = let
        if comment is not None:
            options["comment"] = comment

        self._options = options

        # This is the batchSize that will be used for setting the initial
        # batchSize for the cursor, as well as the subsequent getMores.
        self._batch_size = common.validate_non_negative_integer_or_none(
            "batchSize", self._options.pop("batchSize", None)
        )

        # If the cursor option is already specified, avoid overriding it.
        self._options.setdefault("cursor", {})
        # If the pipeline performs a write, we ignore the initial batchSize
        # since the server doesn't return results in this case.
        if self._batch_size is not None and not self._performs_write:
            self._options["cursor"]["batchSize"] = self._batch_size

        self._cursor_class = cursor_class
        self._user_fields = user_fields
        self._result_processor = result_processor

        self._collation = validate_collation_or_none(options.pop("collation", None))

        self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
        self._write_preference: Optional[_AggWritePref] = None

    @property
    def _aggregation_target(self) -> Union[str, int]:
        """The argument to pass to the aggregate command."""
        raise NotImplementedError

    @property
    def _cursor_namespace(self) -> str:
        """The namespace in which the aggregate command is run."""
        raise NotImplementedError

    def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection[Any]:
        """The Collection used for the aggregate command cursor."""
        raise NotImplementedError

    @property
    def _database(self) -> Database[Any]:
        """The database against which the aggregation command is run."""
        raise NotImplementedError

    def get_read_preference(
        self, session: Optional[ClientSession]
    ) -> Union[_AggWritePref, _ServerMode]:
        if self._write_preference:
            return self._write_preference
        pref = self._target._read_preference_for(session)
        if self._performs_write and pref != ReadPreference.PRIMARY:
            self._write_preference = pref = _AggWritePref(pref)  # type: ignore[assignment]
        return pref

    def get_cursor(
        self,
        session: Optional[ClientSession],
        server: Server,
        conn: Connection,
        read_preference: _ServerMode,
    ) -> CommandCursor[_DocumentType]:
        # Serialize command.
        cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
        cmd.update(self._options)

        # Apply this target's read concern if:
        # readConcern has not been specified as a kwarg and either
        # - server version is >= 4.2 or
        # - server version is >= 3.2 and pipeline doesn't use $out
        if ("readConcern" not in cmd) and (
            not self._performs_write or (conn.max_wire_version >= 8)
        ):
            read_concern = self._target.read_concern
        else:
            read_concern = None

        # Apply this target's write concern if:
        # writeConcern has not been specified as a kwarg and pipeline doesn't
        # perform a write operation
        if "writeConcern" not in cmd and self._performs_write:
            write_concern = self._target._write_concern_for(session)
        else:
            write_concern = None

        # Run command.
        result = conn.command(
            self._database.name,
            cmd,
            read_preference,
            self._target.codec_options,
            parse_write_concern_error=True,
            read_concern=read_concern,
            write_concern=write_concern,
            collation=self._collation,
            session=session,
            client=self._database.client,
            user_fields=self._user_fields,
        )

        if self._result_processor:
            self._result_processor(result, conn)

        # Extract cursor from result or mock/fake one if necessary.
        if "cursor" in result:
            cursor = result["cursor"]
        else:
            # Unacknowledged $out/$merge write. Fake a cursor.
            cursor = {
                "id": 0,
                "firstBatch": result.get("result", []),
                "ns": self._cursor_namespace,
            }

        # Create and return cursor instance.
        cmd_cursor = self._cursor_class(
            self._cursor_collection(cursor),
            cursor,
            conn.address,
            batch_size=self._batch_size or 0,
            max_await_time_ms=self._max_await_time_ms,
            session=session,
            comment=self._options.get("comment"),
        )
        cmd_cursor._maybe_pin_connection(conn)
        return cmd_cursor


class _CollectionAggregationCommand(_AggregationCommand):
    _target: Collection[Any]

    @property
    def _aggregation_target(self) -> str:
        return self._target.name

    @property
    def _cursor_namespace(self) -> str:
        return self._target.full_name

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
        """The Collection used for the aggregate command cursor."""
        return self._target

    @property
    def _database(self) -> Database[Any]:
        return self._target.database


class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # For raw-batches, we set the initial batchSize for the cursor to 0.
        if not self._performs_write:
            self._options["cursor"]["batchSize"] = 0


class _DatabaseAggregationCommand(_AggregationCommand):
    _target: Database[Any]

    @property
    def _aggregation_target(self) -> int:
        return 1

    @property
    def _cursor_namespace(self) -> str:
        return f"{self._target.name}.$cmd.aggregate"

    @property
    def _database(self) -> Database[Any]:
        return self._target

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
        """The Collection used for the aggregate command cursor."""
        # Collection level aggregate may not always return the "ns" field
        # according to our MockupDB tests. Let's handle that case for db level
        # aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
        _, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
        return self._database[collname]
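
The classes above are driven entirely through the public API. A minimal sketch of how Collection.aggregate exercises _CollectionAggregationCommand (the URI, database, and collection names are placeholders, and a reachable mongod is assumed):

from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")
coll = client["example_db"]["events"]

# "batchSize" is popped from the options and nested under {"cursor": ...};
# "let" and "comment" are passed through to the aggregate command document.
cursor = coll.aggregate(
    [
        {"$match": {"$expr": {"$gte": ["$qty", "$$threshold"]}}},
        {"$group": {"_id": "$sku", "total": {"$sum": "$qty"}}},
    ],
    let={"threshold": 10},
    batchSize=100,
    comment="inventory rollup",
)
for doc in cursor:
    print(doc)
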
@@ -0,0 +1,450 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Authentication helpers."""
from __future__ import annotations

import functools
import hashlib
import hmac
import socket
from base64 import standard_b64decode, standard_b64encode
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Mapping,
    MutableMapping,
    Optional,
    cast,
)
from urllib.parse import quote

from bson.binary import Binary
from pymongo.auth_shared import (
    MongoCredential,
    _authenticate_scram_start,
    _parse_scram_response,
    _xor,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
from pymongo.synchronous.auth_aws import _authenticate_aws
from pymongo.synchronous.auth_oidc import (
    _authenticate_oidc,
    _get_authenticator,
)
from pymongo.synchronous.helpers import _getaddrinfo

if TYPE_CHECKING:
    from pymongo.hello import Hello
    from pymongo.synchronous.pool import Connection

HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
    import winkerberos as kerberos  # type:ignore[import]

    if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
        _USE_PRINCIPAL = True
except ImportError:
    try:
        import kerberos  # type:ignore[import]
    except ImportError:
        HAVE_KERBEROS = False


_IS_SYNC = True


def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
    """Authenticate using SCRAM."""
    username = credentials.username
    if mechanism == "SCRAM-SHA-256":
        digest = "sha256"
        digestmod = hashlib.sha256
        data = saslprep(credentials.password).encode("utf-8")
    else:
        digest = "sha1"
        digestmod = hashlib.sha1
        data = _password_digest(username, credentials.password).encode("utf-8")
    source = credentials.source
    cache = credentials.cache

    # Make local
    _hmac = hmac.HMAC

    ctx = conn.auth_ctx
    if ctx and ctx.speculate_succeeded():
        assert isinstance(ctx, _ScramContext)
        assert ctx.scram_data is not None
        nonce, first_bare = ctx.scram_data
        res = ctx.speculative_authenticate
    else:
        nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
        res = conn.command(source, cmd)

    assert res is not None
    server_first = res["payload"]
    parsed = _parse_scram_response(server_first)
    iterations = int(parsed[b"i"])
    if iterations < 4096:
        raise OperationFailure("Server returned an invalid iteration count.")
    salt = parsed[b"s"]
    rnonce = parsed[b"r"]
    if not rnonce.startswith(nonce):
        raise OperationFailure("Server returned an invalid nonce.")

    without_proof = b"c=biws,r=" + rnonce
    if cache.data:
        client_key, server_key, csalt, citerations = cache.data
    else:
        client_key, server_key, csalt, citerations = None, None, None, None

    # Salt and / or iterations could change for a number of different
    # reasons. Either changing invalidates the cache.
    if not client_key or salt != csalt or iterations != citerations:
        salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
        client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
        server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
        cache.data = (client_key, server_key, salt, iterations)
    stored_key = digestmod(client_key).digest()
    auth_msg = b",".join((first_bare, server_first, without_proof))
    client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
    client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
    client_final = b",".join((without_proof, client_proof))

    server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())

    cmd = {
        "saslContinue": 1,
        "conversationId": res["conversationId"],
        "payload": Binary(client_final),
    }
    res = conn.command(source, cmd)

    parsed = _parse_scram_response(res["payload"])
    if not hmac.compare_digest(parsed[b"v"], server_sig):
        raise OperationFailure("Server returned an invalid signature.")

    # A third empty challenge may be required if the server does not support
    # skipEmptyExchange: SERVER-44857.
    if not res["done"]:
        cmd = {
            "saslContinue": 1,
            "conversationId": res["conversationId"],
            "payload": Binary(b""),
        }
        res = conn.command(source, cmd)
        if not res["done"]:
            raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username: str, password: str) -> str:
    """Get a password digest to use for authentication."""
    if not isinstance(password, str):
        raise TypeError("password must be an instance of str")
    if len(password) == 0:
        raise ValueError("password can't be empty")
    if not isinstance(username, str):
        raise TypeError(f"username must be an instance of str, not {type(username)}")

    md5hash = hashlib.md5()  # noqa: S324
    data = f"{username}:mongo:{password}"
    md5hash.update(data.encode("utf-8"))
    return md5hash.hexdigest()
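
The digest is simply MD5 over "<username>:mongo:<password>"; for example, with placeholder credentials:

import hashlib

# Same value as _password_digest("user", "pencil").
print(hashlib.md5(b"user:mongo:pencil").hexdigest())
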
def _auth_key(nonce: str, username: str, password: str) -> str:
    """Get an auth key to use for authentication."""
    digest = _password_digest(username, password)
    md5hash = hashlib.md5()  # noqa: S324
    data = f"{nonce}{username}{digest}"
    md5hash.update(data.encode("utf-8"))
    return md5hash.hexdigest()


def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
    """Canonicalize hostname following MIT-krb5 behavior."""
    # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
    if option in [False, "none"]:
        return hostname

    af, socktype, proto, canonname, sockaddr = (
        _getaddrinfo(
            hostname,
            None,
            family=0,
            type=0,
            proto=socket.IPPROTO_TCP,
            flags=socket.AI_CANONNAME,
        )
    )[0]  # type: ignore[index]

    # For the "forward" option, just resolve the CNAME, since dns.lookup()
    # will not return it.
    if option == "forward":
        return canonname.lower()

    try:
        name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
    except socket.gaierror:
        return canonname.lower()

    return name[0].lower()
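
The same two resolution steps can be reproduced with the standard library alone. A sketch with a placeholder hostname (working DNS for that name is assumed):

import socket

info = socket.getaddrinfo(
    "mongo.example.com", None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)[0]
canonname, sockaddr = info[3], info[4]
print(canonname.lower())  # "forward" mode stops at the forward-resolved name
print(socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)[0].lower())  # plus reverse lookup
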
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using GSSAPI."""
    if not HAVE_KERBEROS:
        raise ConfigurationError(
            'The "kerberos" module must be installed to use GSSAPI authentication.'
        )

    try:
        username = credentials.username
        password = credentials.password
        props = credentials.mechanism_properties
        # Starting here and continuing through the while loop below - establish
        # the security context. See RFC 4752, Section 3.1, first paragraph.
        host = props.service_host or conn.address[0]
        host = _canonicalize_hostname(host, props.canonicalize_host_name)
        service = props.service_name + "@" + host
        if props.service_realm is not None:
            service = service + "@" + props.service_realm

        if password is not None:
            if _USE_PRINCIPAL:
                # Note that, though we use unquote_plus for unquoting URI
                # options, we use quote here. Microsoft's UrlUnescape (used
                # by WinKerberos) doesn't support +.
                principal = ":".join((quote(username), quote(password)))
                result, ctx = kerberos.authGSSClientInit(
                    service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
                )
            else:
                if "@" in username:
                    user, domain = username.split("@", 1)
                else:
                    user, domain = username, None
                result, ctx = kerberos.authGSSClientInit(
                    service,
                    gssflags=kerberos.GSS_C_MUTUAL_FLAG,
                    user=user,
                    domain=domain,
                    password=password,
                )
        else:
            result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)

        if result != kerberos.AUTH_GSS_COMPLETE:
            raise OperationFailure("Kerberos context failed to initialize.")

        try:
            # pykerberos uses a weird mix of exceptions and return values
            # to indicate errors.
            # 0 == continue, 1 == complete, -1 == error
            # Only authGSSClientStep can return 0.
            if kerberos.authGSSClientStep(ctx, "") != 0:
                raise OperationFailure("Unknown kerberos failure in step function.")

            # Start a SASL conversation with mongod/s
            # Note: pykerberos deals with base64 encoded byte strings.
            # Since mongo accepts base64 strings as the payload we don't
            # have to use bson.binary.Binary.
            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslStart": 1,
                "mechanism": "GSSAPI",
                "payload": payload,
                "autoAuthorize": 1,
            }
            response = conn.command("$external", cmd)

            # Limit how many times we loop to catch protocol / library issues
            for _ in range(10):
                result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
                if result == -1:
                    raise OperationFailure("Unknown kerberos failure in step function.")

                payload = kerberos.authGSSClientResponse(ctx) or ""

                cmd = {
                    "saslContinue": 1,
                    "conversationId": response["conversationId"],
                    "payload": payload,
                }
                response = conn.command("$external", cmd)

                if result == kerberos.AUTH_GSS_COMPLETE:
                    break
            else:
                raise OperationFailure("Kerberos authentication failed to complete.")

            # Once the security context is established actually authenticate.
            # See RFC 4752, Section 3.1, last two paragraphs.
            if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")

            if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")

            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslContinue": 1,
                "conversationId": response["conversationId"],
                "payload": payload,
            }
            conn.command("$external", cmd)

        finally:
            kerberos.authGSSClientClean(ctx)

    except kerberos.KrbError as exc:
        raise OperationFailure(str(exc)) from None


def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using SASL PLAIN (RFC 4616)"""
    source = credentials.source
    username = credentials.username
    password = credentials.password
    payload = (f"\x00{username}\x00{password}").encode()
    cmd = {
        "saslStart": 1,
        "mechanism": "PLAIN",
        "payload": Binary(payload),
        "autoAuthorize": 1,
    }
    conn.command(source, cmd)


def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-X509."""
    ctx = conn.auth_ctx
    if ctx and ctx.speculate_succeeded():
        # MONGODB-X509 is done after the speculative auth step.
        return

    cmd = _X509Context(credentials, conn.address).speculate_command()
    conn.command("$external", cmd)


def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
    if conn.max_wire_version >= 7:
        if conn.negotiated_mechs:
            mechs = conn.negotiated_mechs
        else:
            source = credentials.source
            cmd = conn.hello_cmd()
            cmd["saslSupportedMechs"] = source + "." + credentials.username
            mechs = (conn.command(source, cmd, publish_events=False)).get("saslSupportedMechs", [])
        if "SCRAM-SHA-256" in mechs:
            return _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
        else:
            return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
    else:
        return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")


_AUTH_MAP: Mapping[str, Callable[..., None]] = {
    "GSSAPI": _authenticate_gssapi,
    "MONGODB-X509": _authenticate_x509,
    "MONGODB-AWS": _authenticate_aws,
    "MONGODB-OIDC": _authenticate_oidc,  # type:ignore[dict-item]
    "PLAIN": _authenticate_plain,
    "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
    "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
    "DEFAULT": _authenticate_default,
}


class _AuthContext:
    def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
        self.credentials = credentials
        self.speculative_authenticate: Optional[Mapping[str, Any]] = None
        self.address = address

    @staticmethod
    def from_credentials(
        creds: MongoCredential, address: tuple[str, int]
    ) -> Optional[_AuthContext]:
        spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
        if spec_cls:
            return cast(_AuthContext, spec_cls(creds, address))
        return None

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        raise NotImplementedError

    def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
        self.speculative_authenticate = hello.speculative_authenticate

    def speculate_succeeded(self) -> bool:
        return bool(self.speculative_authenticate)


class _ScramContext(_AuthContext):
    def __init__(
        self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
    ) -> None:
        super().__init__(credentials, address)
        self.scram_data: Optional[tuple[bytes, bytes]] = None
        self.mechanism = mechanism

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
        # The 'db' field is included only on the speculative command.
        cmd["db"] = self.credentials.source
        # Save for later use.
        self.scram_data = (nonce, first_bare)
        return cmd


class _X509Context(_AuthContext):
    def speculate_command(self) -> MutableMapping[str, Any]:
        cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
        if self.credentials.username is not None:
            cmd["user"] = self.credentials.username
        return cmd


class _OIDCContext(_AuthContext):
    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        authenticator = _get_authenticator(self.credentials, self.address)
        cmd = authenticator.get_spec_auth_cmd()
        if cmd is None:
            return None
        cmd["db"] = self.credentials.source
        return cmd


_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
    "MONGODB-X509": _X509Context,
    "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
    "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
    "MONGODB-OIDC": _OIDCContext,
    "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}


def authenticate(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
    """Authenticate connection."""
    mechanism = credentials.mechanism
    auth_func = _AUTH_MAP[mechanism]
    if mechanism == "MONGODB-OIDC":
        _authenticate_oidc(credentials, conn, reauthenticate)
    else:
        auth_func(credentials, conn)
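
Which of these helpers runs is selected per connection from the credential's mechanism; from application code that is just the authMechanism URI option or keyword argument. A minimal sketch with placeholder host and credentials:

from pymongo import MongoClient

# Explicit mechanism; omitting authMechanism selects "DEFAULT", which
# negotiates SCRAM-SHA-256 vs. SCRAM-SHA-1 via saslSupportedMechs.
client = MongoClient(
    "mongodb://user:pencil@localhost:27017/?authMechanism=SCRAM-SHA-256"
)
client.admin.command("ping")
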
@@ -0,0 +1,100 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MONGODB-AWS Authentication helpers."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Mapping, Type

import bson
from bson.binary import Binary
from pymongo.errors import ConfigurationError, OperationFailure

if TYPE_CHECKING:
    from bson.typings import _ReadableBuffer
    from pymongo.auth_shared import MongoCredential
    from pymongo.synchronous.pool import Connection

_IS_SYNC = True


def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-AWS."""
    try:
        import pymongo_auth_aws  # type:ignore[import]
    except ImportError as e:
        raise ConfigurationError(
            "MONGODB-AWS authentication requires pymongo-auth-aws: "
            "install with: python -m pip install 'pymongo[aws]'"
        ) from e
    # Delayed import.
    from pymongo_auth_aws.auth import (  # type:ignore[import]
        set_cached_credentials,
        set_use_cached_credentials,
    )

    set_use_cached_credentials(True)

    if conn.max_wire_version < 9:
        raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")

    class AwsSaslContext(pymongo_auth_aws.AwsSaslContext):  # type: ignore
        # Dependency injection:
        def binary_type(self) -> Type[Binary]:
            """Return the bson.binary.Binary type."""
            return Binary

        def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
            """Encode a dictionary to BSON."""
            return bson.encode(doc)

        def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
            """Decode BSON to a dictionary."""
            return bson.decode(data)

    try:
        ctx = AwsSaslContext(
            pymongo_auth_aws.AwsCredential(
                credentials.username,
                credentials.password,
                credentials.mechanism_properties.aws_session_token,
            )
        )
        client_payload = ctx.step(None)
        client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload}
        server_first = conn.command("$external", client_first)
        res = server_first
        # Limit how many times we loop to catch protocol / library issues
        for _ in range(10):
            client_payload = ctx.step(res["payload"])
            cmd = {
                "saslContinue": 1,
                "conversationId": server_first["conversationId"],
                "payload": client_payload,
            }
            res = conn.command("$external", cmd)
            if res["done"]:
                # SASL complete.
                break
    except pymongo_auth_aws.PyMongoAuthAwsError as exc:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        # Convert to OperationFailure and include pymongo-auth-aws version.
        raise OperationFailure(
            f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
        ) from None
    except Exception:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        raise
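
From application code, MONGODB-AWS is selected through the URI; pymongo-auth-aws resolves credentials from the connection string, environment variables such as AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, or instance metadata. A sketch with a placeholder host:

from pymongo import MongoClient

# Requires: python -m pip install 'pymongo[aws]'
client = MongoClient(
    "mongodb+srv://cluster0.example.mongodb.net/"
    "?authMechanism=MONGODB-AWS&authSource=%24external"
)
client.admin.command("ping")
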
@@ -0,0 +1,303 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations

import asyncio
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union

import bson
from bson.binary import Binary
from pymongo._csot import remaining
from pymongo.auth_oidc_shared import (
    CALLBACK_VERSION,
    HUMAN_CALLBACK_TIMEOUT_SECONDS,
    MACHINE_CALLBACK_TIMEOUT_SECONDS,
    TIME_BETWEEN_CALLS_SECONDS,
    OIDCCallback,
    OIDCCallbackContext,
    OIDCCallbackResult,
    OIDCIdPInfo,
    _OIDCProperties,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
from pymongo.lock import Lock, _create_lock

if TYPE_CHECKING:
    from pymongo.auth_shared import MongoCredential
    from pymongo.synchronous.pool import Connection

_IS_SYNC = True


def _get_authenticator(
    credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
    if credentials.cache.data:
        return credentials.cache.data

    # Extract values.
    principal_name = credentials.username
    properties = credentials.mechanism_properties

    # Validate that the address is allowed.
    if properties.human_callback is not None:
        found = False
        allowed_hosts = properties.allowed_hosts
        for patt in allowed_hosts:
            if patt == address[0]:
                found = True
            elif patt.startswith("*.") and address[0].endswith(patt[1:]):
                found = True
        if not found:
            raise ConfigurationError(
                f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
            )

    # Get or create the cache data.
    credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
    return credentials.cache.data


@dataclass
class _OIDCAuthenticator:
    username: str
    properties: _OIDCProperties
    refresh_token: Optional[str] = field(default=None)
    access_token: Optional[str] = field(default=None)
    idp_info: Optional[OIDCIdPInfo] = field(default=None)
    token_gen_id: int = field(default=0)
    if not _IS_SYNC:
        lock: Lock = field(default_factory=_create_lock)  # type: ignore[assignment]
    else:
        lock: threading.Lock = field(default_factory=_create_lock)  # type: ignore[assignment, no-redef]

    last_call_time: float = field(default=0)

    def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle a reauthenticate from the server."""
        # Invalidate the token for the connection.
        self._invalidate(conn)
        # Call the appropriate auth logic for the callback type.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle an initial authenticate request."""
        # First handle speculative auth.
        # If it succeeded, we are done.
        ctx = conn.auth_ctx
        if ctx and ctx.speculate_succeeded():
            resp = ctx.speculative_authenticate
            if resp and resp["done"]:
                conn.oidc_token_gen_id = self.token_gen_id
                return resp

        # If spec auth failed, call the appropriate auth logic for the callback type.
        # We cannot assume that the token is invalid, because a proxy may have been
        # involved that stripped the speculative auth information.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
        """Get the appropriate speculative auth command."""
        if not self.access_token:
            return None
        return self._get_start_command({"jwt": self.access_token})

    def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
        # If there is a cached access token, try to authenticate with it. If
        # authentication fails with error code 18, invalidate the access token,
        # fetch a new access token, and try to authenticate again. If authentication
        # fails for any other reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_machine(conn)
                raise
        return self._sasl_start_jwt(conn)

    def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        # If we have a cached access token, try a JwtStepRequest. If
        # authentication fails with error code 18, invalidate the access token,
        # and try to authenticate again. If authentication fails for any other
        # reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_human(conn)
                raise

        # If we have a cached refresh token, try a JwtStepRequest with that.
        # If authentication fails with error code 18, invalidate the access and
        # refresh tokens, and try to authenticate again. If authentication fails for
        # any other reason, raise the error to the user.
        if self.refresh_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    self.refresh_token = None
                    return self._authenticate_human(conn)
                raise

        # Start a new Two-Step SASL conversation.
        # Run a PrincipalStepRequest to get the IdpInfo.
        cmd = self._get_start_command(None)
        start_resp = self._run_command(conn, cmd)
        # Attempt to authenticate with a JwtStepRequest.
        return self._sasl_continue_jwt(conn, start_resp)

    def _get_access_token(self) -> Optional[str]:
        properties = self.properties
        cb: Union[None, OIDCCallback]
        resp: OIDCCallbackResult

        is_human = properties.human_callback is not None
        if is_human and self.idp_info is None:
            return None

        if properties.callback:
            cb = properties.callback
        if properties.human_callback:
            cb = properties.human_callback

        prev_token = self.access_token
        if prev_token:
            return prev_token

        if cb is None and not prev_token:
            return None

        if not prev_token and cb is not None:
            with self.lock:  # type: ignore[attr-defined]
                # See if the token was changed while we were waiting for the
                # lock.
                new_token = self.access_token
                if new_token != prev_token:
                    return new_token

                # Ensure that we are waiting a min time between callback invocations.
                delta = time.time() - self.last_call_time
                if delta < TIME_BETWEEN_CALLS_SECONDS:
                    time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
                self.last_call_time = time.time()

                if is_human:
                    timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
                    assert self.idp_info is not None
                else:
                    timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
                context = OIDCCallbackContext(
                    timeout_seconds=timeout,
                    version=CALLBACK_VERSION,
                    refresh_token=self.refresh_token,
                    idp_info=self.idp_info,
                    username=self.properties.username,
                )
                if not _IS_SYNC:
                    resp = asyncio.get_running_loop().run_in_executor(None, cb.fetch, context)  # type: ignore[assignment]
                else:
                    resp = cb.fetch(context)
                if not isinstance(resp, OIDCCallbackResult):
                    raise ValueError(
                        f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
                    )
                self.refresh_token = resp.refresh_token
                self.access_token = resp.access_token
                self.token_gen_id += 1

        return self.access_token

    def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
        try:
            return conn.command("$external", cmd, no_reauth=True)  # type: ignore[call-arg]
        except OperationFailure as e:
            if self._is_auth_error(e):
                self._invalidate(conn)
            raise

    def _is_auth_error(self, err: Exception) -> bool:
        if not isinstance(err, OperationFailure):
            return False
        return err.code == _AUTHENTICATION_FAILURE_CODE

    def _invalidate(self, conn: Connection) -> None:
        # Ignore the invalidation if a token gen id is given and is less than our
        # current token gen id.
        token_gen_id = conn.oidc_token_gen_id or 0
        if token_gen_id is not None and token_gen_id < self.token_gen_id:
            return
        self.access_token = None

    def _sasl_continue_jwt(
        self, conn: Connection, start_resp: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        self.access_token = None
        self.refresh_token = None
        start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
        if "issuer" in start_payload:
            self.idp_info = OIDCIdPInfo(**start_payload)
        access_token = self._get_access_token()
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_continue_command({"jwt": access_token}, start_resp)
        return self._run_command(conn, cmd)

    def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
        access_token = self._get_access_token()
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_start_command({"jwt": access_token})
        return self._run_command(conn, cmd)

    def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
        if payload is None:
            principal_name = self.username
            if principal_name:
                payload = {"n": principal_name}
            else:
                payload = {}
        bin_payload = Binary(bson.encode(payload))
        return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}

    def _get_continue_command(
        self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
    ) -> MutableMapping[str, Any]:
        bin_payload = Binary(bson.encode(payload))
        return {
            "saslContinue": 1,
            "payload": bin_payload,
            "conversationId": start_resp["conversationId"],
        }


def _authenticate_oidc(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
    """Authenticate using MONGODB-OIDC."""
    authenticator = _get_authenticator(credentials, conn.address)
    if reauthenticate:
        return authenticator.reauthenticate(conn)
    else:
        return authenticator.authenticate(conn)
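
A minimal sketch of supplying a machine (workload) callback from application code, reusing the OIDCCallback types this module imports; the token path and host are placeholders:

from pymongo import MongoClient
from pymongo.auth_oidc_shared import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult

class FileTokenCallback(OIDCCallback):
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        # Hypothetical mounted-token path, e.g. in a Kubernetes pod.
        with open("/var/run/secrets/tokens/mongodb") as f:
            return OIDCCallbackResult(access_token=f.read().strip())

client = MongoClient(
    "mongodb://cluster.example.net/?authMechanism=MONGODB-OIDC",
    authMechanismProperties={"OIDC_CALLBACK": FileTokenCallback()},
)
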
@@ -0,0 +1,751 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The bulk write operations interface.

.. versionadded:: 2.7
"""
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    Mapping,
    Optional,
    Type,
    Union,
)

from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.bulk_shared import (
    _COMMANDS,
    _DELETE_ALL,
    _merge_command,
    _raise_bulk_write_error,
    _Run,
)
from pymongo.common import (
    validate_is_document_type,
    validate_ok_for_replace,
    validate_ok_for_update,
)
from pymongo.errors import (
    ConfigurationError,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
    _DELETE,
    _INSERT,
    _UPDATE,
    _BulkWriteContext,
    _convert_exception,
    _convert_write_result,
    _EncryptedBulkWriteContext,
    _randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
    from pymongo.synchronous.collection import Collection
    from pymongo.synchronous.mongo_client import MongoClient
    from pymongo.synchronous.pool import Connection
    from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline

_IS_SYNC = True


class _Bulk:
    """The private guts of the bulk write API."""

    def __init__(
        self,
        collection: Collection[_DocumentType],
        ordered: bool,
        bypass_document_validation: Optional[bool],
        comment: Optional[str] = None,
        let: Optional[Any] = None,
    ) -> None:
        """Initialize a _Bulk instance."""
        self.collection = collection.with_options(
            codec_options=collection.codec_options._replace(
                unicode_decode_error_handler="replace", document_class=dict
            )
        )
        self.let = let
        if self.let is not None:
            common.validate_is_document_type("let", self.let)
        self.comment: Optional[str] = comment
        self.ordered = ordered
        self.ops: list[tuple[int, Mapping[str, Any]]] = []
        self.executed = False
        self.bypass_doc_val = bypass_document_validation
        self.uses_collation = False
        self.uses_array_filters = False
        self.uses_hint_update = False
        self.uses_hint_delete = False
        self.uses_sort = False
        self.is_retryable = True
        self.retrying = False
        self.started_retryable_write = False
        # Extra state so that we know where to pick up on a retry attempt.
        self.current_run = None
        self.next_run = None
        self.is_encrypted = False

    @property
    def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
        encrypter = self.collection.database.client._encrypter
        if encrypter and not encrypter._bypass_auto_encryption:
            self.is_encrypted = True
            return _EncryptedBulkWriteContext
        else:
            self.is_encrypted = False
            return _BulkWriteContext

    def add_insert(self, document: _DocumentOut) -> None:
        """Add an insert document to the list of ops."""
        validate_is_document_type("document", document)
        # Generate ObjectId client side.
        if not (isinstance(document, RawBSONDocument) or "_id" in document):
            document["_id"] = ObjectId()
        self.ops.append((_INSERT, document))

    def add_update(
        self,
        selector: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        multi: bool,
        upsert: Optional[bool],
        collation: Optional[Mapping[str, Any]] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Union[str, dict[str, Any], None] = None,
        sort: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Create an update document and add it to the list of ops."""
        validate_ok_for_update(update)
        cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
        if upsert is not None:
            cmd["upsert"] = upsert
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if array_filters is not None:
            self.uses_array_filters = True
            cmd["arrayFilters"] = array_filters
        if hint is not None:
            self.uses_hint_update = True
            cmd["hint"] = hint
        if sort is not None:
            self.uses_sort = True
            cmd["sort"] = sort
        if multi:
            # A bulk_write containing an update_many is not retryable.
            self.is_retryable = False
        self.ops.append((_UPDATE, cmd))

    def add_replace(
        self,
        selector: Mapping[str, Any],
        replacement: Mapping[str, Any],
        upsert: Optional[bool],
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
        sort: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Create a replace document and add it to the list of ops."""
        validate_ok_for_replace(replacement)
        cmd: dict[str, Any] = {"q": selector, "u": replacement}
        if upsert is not None:
            cmd["upsert"] = upsert
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if hint is not None:
            self.uses_hint_update = True
            cmd["hint"] = hint
        if sort is not None:
            self.uses_sort = True
            cmd["sort"] = sort
        self.ops.append((_UPDATE, cmd))

    def add_delete(
        self,
        selector: Mapping[str, Any],
        limit: int,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Create a delete document and add it to the list of ops."""
        cmd: dict[str, Any] = {"q": selector, "limit": limit}
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if hint is not None:
            self.uses_hint_delete = True
            cmd["hint"] = hint
        if limit == _DELETE_ALL:
            # A bulk_write containing a delete_many is not retryable.
            self.is_retryable = False
        self.ops.append((_DELETE, cmd))
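
These add_* methods are populated by the public bulk API; a minimal sketch of the entry point (placeholder names, local mongod assumed):

from pymongo import DeleteOne, InsertOne, MongoClient, UpdateOne

coll = MongoClient("mongodb://localhost:27017")["example_db"]["items"]
result = coll.bulk_write(
    [
        InsertOne({"_id": 1, "qty": 5}),              # -> add_insert
        UpdateOne({"_id": 1}, {"$inc": {"qty": 1}}),  # -> add_update(multi=False)
        DeleteOne({"_id": 2}),                        # -> add_delete(limit=1)
    ],
    ordered=True,
)
print(result.bulk_api_result)
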
    def gen_ordered(self) -> Iterator[Optional[_Run]]:
        """Generate batches of operations, batched by type of
        operation, in the order **provided**.
        """
        run = None
        for idx, (op_type, operation) in enumerate(self.ops):
            if run is None:
                run = _Run(op_type)
            elif run.op_type != op_type:
                yield run
                run = _Run(op_type)
            run.add(idx, operation)
        yield run

    def gen_unordered(self) -> Iterator[_Run]:
        """Generate batches of operations, batched by type of
        operation, in arbitrary order.
        """
        operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
        for idx, (op_type, operation) in enumerate(self.ops):
            operations[op_type].add(idx, operation)

        for run in operations:
            if run.ops:
                yield run
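
gen_ordered coalesces only adjacent operations of the same type, while gen_unordered buckets everything into at most three runs. A toy illustration of the two strategies over op-type tags:

from itertools import groupby

ops = ["insert", "insert", "update", "insert", "delete"]

# Ordered: one run per maximal same-type streak, order preserved.
print([(k, len(list(g))) for k, g in groupby(ops)])
# [('insert', 2), ('update', 1), ('insert', 1), ('delete', 1)]

# Unordered: one run per type, regardless of position.
print({k: ops.count(k) for k in ("insert", "update", "delete")})
# {'insert': 3, 'update': 1, 'delete': 1}
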
|
||||
@_handle_reauth
|
||||
def write_command(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: MongoClient[Any],
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
|
||||
client._process_response(reply, bwc.session) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
|
||||
if bwc.publish:
|
||||
bwc._fail(request_id, failure, duration)
|
||||
# Process the response from the server.
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
|
||||
raise
|
||||
return reply # type: ignore[return-value]
|
||||
|
||||
    def unack_write(
        self,
        bwc: _BulkWriteContext,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        max_doc_size: int,
        docs: list[Mapping[str, Any]],
        client: MongoClient[Any],
    ) -> Optional[Mapping[str, Any]]:
        """A proxy for Connection.unack_write that handles event publishing."""
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.STARTED,
                clientId=client._topology_settings._topology_id,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=bwc.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=bwc.conn.id,
                serverConnectionId=bwc.conn.server_connection_id,
                serverHost=bwc.conn.address[0],
                serverPort=bwc.conn.address[1],
                serviceId=bwc.conn.service_id,
            )
        if bwc.publish:
            cmd = bwc._start(cmd, request_id, docs)
        try:
            result = bwc.conn.unack_write(msg, max_doc_size)  # type: ignore[func-returns-value, misc, override]
            duration = datetime.datetime.now() - bwc.start_time
            if result is not None:
                reply = _convert_write_result(bwc.name, cmd, result)  # type: ignore[arg-type]
            else:
                # Comply with APM spec.
                reply = {"ok": 1}
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.SUCCEEDED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                )
            if bwc.publish:
                bwc._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - bwc.start_time
            if isinstance(exc, OperationFailure):
                failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details)  # type: ignore[arg-type]
            elif isinstance(exc, NotPrimaryError):
                failure = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.FAILED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if bwc.publish:
                assert bwc.start_time is not None
                bwc._fail(request_id, failure, duration)
            raise
        return result  # type: ignore[return-value]

    def _execute_batch_unack(
        self,
        bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
        cmd: dict[str, Any],
        ops: list[Mapping[str, Any]],
        client: MongoClient[Any],
    ) -> list[Mapping[str, Any]]:
        if self.is_encrypted:
            _, batched_cmd, to_send = bwc.batch_command(cmd, ops)
            bwc.conn.command(  # type: ignore[misc]
                bwc.db_name,
                batched_cmd,  # type: ignore[arg-type]
                write_concern=WriteConcern(w=0),
                session=bwc.session,  # type: ignore[arg-type]
                client=client,  # type: ignore[arg-type]
            )
        else:
            request_id, msg, to_send = bwc.batch_command(cmd, ops)
            # Though this isn't strictly a "legacy" write, the helper
            # handles publishing commands and sending our message
            # without receiving a result. Send 0 for max_doc_size
            # to disable size checking. Size checking is handled while
            # the documents are encoded to BSON.
            self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client)  # type: ignore[arg-type]

        return to_send

    def _execute_batch(
        self,
        bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
        cmd: dict[str, Any],
        ops: list[Mapping[str, Any]],
        client: MongoClient[Any],
    ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
        if self.is_encrypted:
            _, batched_cmd, to_send = bwc.batch_command(cmd, ops)
            result = bwc.conn.command(  # type: ignore[misc]
                bwc.db_name,
                batched_cmd,  # type: ignore[arg-type]
                codec_options=bwc.codec,
                session=bwc.session,  # type: ignore[arg-type]
                client=client,  # type: ignore[arg-type]
            )
        else:
            request_id, msg, to_send = bwc.batch_command(cmd, ops)
            result = self.write_command(bwc, cmd, request_id, msg, to_send, client)  # type: ignore[arg-type]

        return result, to_send  # type: ignore[return-value]

    def _execute_command(
        self,
        generator: Iterator[Any],
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        conn: Connection,
        op_id: int,
        retryable: bool,
        full_result: MutableMapping[str, Any],
        final_write_concern: Optional[WriteConcern] = None,
    ) -> None:
        db_name = self.collection.database.name
        client = self.collection.database.client
        listeners = client._event_listeners

        if not self.current_run:
            self.current_run = next(generator)
            self.next_run = None
        run = self.current_run

        # Connection.command validates the session, but we use
        # Connection.write_command
        conn.validate_session(client, session)
        last_run = False

        while run:
            if not self.retrying:
                self.next_run = next(generator, None)
                if self.next_run is None:
                    last_run = True

            cmd_name = _COMMANDS[run.op_type]
            bwc = self.bulk_ctx_class(
                db_name,
                cmd_name,
                conn,
                op_id,
                listeners,
                session,
                run.op_type,
                self.collection.codec_options,
            )

            while run.idx_offset < len(run.ops):
                # If this is the last possible operation, use the
                # final write concern.
                if last_run and (len(run.ops) - run.idx_offset) == 1:
                    write_concern = final_write_concern or write_concern

                cmd = {cmd_name: self.collection.name, "ordered": self.ordered}
                if self.comment:
                    cmd["comment"] = self.comment
                _csot.apply_write_concern(cmd, write_concern)
                if self.bypass_doc_val is not None:
                    cmd["bypassDocumentValidation"] = self.bypass_doc_val
                if self.let is not None and run.op_type in (_DELETE, _UPDATE):
                    cmd["let"] = self.let
                if session:
                    # Start a new retryable write unless one was already
                    # started for this command.
                    if retryable and not self.started_retryable_write:
                        session._start_retryable_write()
                        self.started_retryable_write = True
                    session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
                conn.send_cluster_time(cmd, session, client)
                conn.add_server_api(cmd)
                # CSOT: apply timeout before encoding the command.
                conn.apply_timeout(client, cmd)
                ops = islice(run.ops, run.idx_offset, None)

                # Run as many ops as possible in one command.
                if write_concern.acknowledged:
                    result, to_send = self._execute_batch(bwc, cmd, ops, client)

                    # Retryable writeConcernErrors halt the execution of this run.
                    wce = result.get("writeConcernError", {})
                    if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
                        # Synthesize the full bulk result without modifying the
                        # current one because this write operation may be retried.
                        full = copy.deepcopy(full_result)
                        _merge_command(run, full, run.idx_offset, result)
                        _raise_bulk_write_error(full)

                    _merge_command(run, full_result, run.idx_offset, result)

                    # We're no longer in a retry once a command succeeds.
                    self.retrying = False
                    self.started_retryable_write = False

                    if self.ordered and "writeErrors" in result:
                        break
                else:
                    to_send = self._execute_batch_unack(bwc, cmd, ops, client)

                run.idx_offset += len(to_send)

            # We're supposed to continue if errors are
            # at the write concern level (e.g. wtimeout)
            if self.ordered and full_result["writeErrors"]:
                break
            # Reset our state
            self.current_run = run = self.next_run

    def execute_command(
        self,
        generator: Iterator[Any],
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        operation: str,
    ) -> dict[str, Any]:
        """Execute using write commands."""
        # nModified is only reported for write commands, not legacy ops.
        full_result = {
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nRemoved": 0,
            "upserted": [],
        }
        op_id = _randint()

        def retryable_bulk(
            session: Optional[ClientSession], conn: Connection, retryable: bool
        ) -> None:
            self._execute_command(
                generator,
                write_concern,
                session,
                conn,
                op_id,
                retryable,
                full_result,
            )

        client = self.collection.database.client
        _ = client._retryable_write(
            self.is_retryable,
            retryable_bulk,
            session,
            operation,
            bulk=self,  # type: ignore[arg-type]
            operation_id=op_id,
        )

        if full_result["writeErrors"] or full_result["writeConcernErrors"]:
            _raise_bulk_write_error(full_result)
        return full_result

    def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
        """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
        db_name = self.collection.database.name
        client = self.collection.database.client
        listeners = client._event_listeners
        op_id = _randint()

        if not self.current_run:
            self.current_run = next(generator)
        run = self.current_run

        while run:
            cmd_name = _COMMANDS[run.op_type]
            bwc = self.bulk_ctx_class(
                db_name,
                cmd_name,
                conn,
                op_id,
                listeners,
                None,
                run.op_type,
                self.collection.codec_options,
            )

            while run.idx_offset < len(run.ops):
                cmd = {
                    cmd_name: self.collection.name,
                    "ordered": False,
                    "writeConcern": {"w": 0},
                }
                conn.add_server_api(cmd)
                ops = islice(run.ops, run.idx_offset, None)
                # Run as many ops as possible.
                to_send = self._execute_batch_unack(bwc, cmd, ops, client)
                run.idx_offset += len(to_send)
            self.current_run = run = next(generator, None)

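Each unacknowledged batch above is an ordinary OP_MSG write command; for an insert run it looks roughly like the dict below (a sketch; the collection name is illustrative, and the operation documents themselves travel in the OP_MSG document-sequence section built by bwc.batch_command):

    {
        "insert": "docs",          # cmd_name: self.collection.name
        "ordered": False,
        "writeConcern": {"w": 0},  # tells the server not to reply
    }
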
    def execute_command_no_results(
        self,
        conn: Connection,
        generator: Iterator[Any],
        write_concern: WriteConcern,
    ) -> None:
        """Execute write commands with OP_MSG and w=0 writeConcern, ordered."""
        full_result = {
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nRemoved": 0,
            "upserted": [],
        }
        # Ordered bulk writes have to be acknowledged so that we stop
        # processing at the first error, even when the application
        # specified an unacknowledged writeConcern.
        initial_write_concern = WriteConcern()
        op_id = _randint()
        try:
            self._execute_command(
                generator,
                initial_write_concern,
                None,
                conn,
                op_id,
                False,
                full_result,
                write_concern,
            )
        except OperationFailure:
            pass

    def execute_no_results(
        self,
        conn: Connection,
        generator: Iterator[Any],
        write_concern: WriteConcern,
    ) -> None:
        """Execute all operations, returning no results (w=0)."""
        if self.uses_collation:
            raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
        if self.uses_array_filters:
            raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
        # Guard against unsupported unacknowledged writes.
        unack = write_concern and not write_concern.acknowledged
        if unack and self.uses_hint_delete and conn.max_wire_version < 9:
            raise ConfigurationError(
                "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
            )
        if unack and self.uses_hint_update and conn.max_wire_version < 8:
            raise ConfigurationError(
                "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
            )
        if unack and self.uses_sort and conn.max_wire_version < 25:
            raise ConfigurationError(
                "Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands."
            )
        # Cannot have both unacknowledged writes and bypass document validation.
        if self.bypass_doc_val:
            raise OperationFailure(
                "Cannot set bypass_document_validation with unacknowledged write concern"
            )

        if self.ordered:
            return self.execute_command_no_results(conn, generator, write_concern)
        return self.execute_op_msg_no_results(conn, generator)

    def execute(
        self,
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        operation: str,
    ) -> Any:
        """Execute operations."""
        if not self.ops:
            raise InvalidOperation("No operations to execute")
        if self.executed:
            raise InvalidOperation("Bulk operations can only be executed once.")
        self.executed = True
        write_concern = write_concern or self.collection.write_concern
        session = _validate_session_write_concern(session, write_concern)

        if self.ordered:
            generator = self.gen_ordered()
        else:
            generator = self.gen_unordered()

        client = self.collection.database.client
        if not write_concern.acknowledged:
            with client._conn_for_writes(session, operation) as connection:
                self.execute_no_results(connection, generator, write_concern)
                return None
        else:
            return self.execute_command(generator, write_concern, session, operation)
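A quick usage sketch for the execution paths above, using only the public bulk API (illustrative; assumes a mongod reachable at localhost:27017):

    from pymongo import DeleteOne, InsertOne, MongoClient
    from pymongo.write_concern import WriteConcern

    client = MongoClient("mongodb://localhost:27017")
    coll = client.test.get_collection("docs", write_concern=WriteConcern(w=1))

    # Acknowledged + ordered: routes through execute_command/_execute_command.
    result = coll.bulk_write([InsertOne({"x": 1}), DeleteOne({"x": 1})], ordered=True)
    print(result.inserted_count, result.deleted_count)

    # Unacknowledged (w=0) + unordered: routes through execute_no_results ->
    # execute_op_msg_no_results, which returns no per-operation results.
    unack = client.test.get_collection("docs", write_concern=WriteConcern(w=0))
    unack.bulk_write([InsertOne({"x": 2})], ordered=False)
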
@@ -0,0 +1,494 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Watch changes on a collection, a database, or the entire cluster."""
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union

from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot, common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (
    ConnectionFailure,
    CursorNotFound,
    InvalidOperation,
    OperationFailure,
    PyMongoError,
)
from pymongo.operations import _Op
from pymongo.synchronous.aggregation import (
    _AggregationCommand,
    _CollectionAggregationCommand,
    _DatabaseAggregationCommand,
)
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline

_IS_SYNC = True

# The change streams spec considers the following server errors from the
# getMore command resumable. All other getMore errors are non-resumable.
_RESUMABLE_GETMORE_ERRORS = frozenset(
    [
        6,  # HostUnreachable
        7,  # HostNotFound
        89,  # NetworkTimeout
        91,  # ShutdownInProgress
        189,  # PrimarySteppedDown
        262,  # ExceededTimeLimit
        9001,  # SocketException
        10107,  # NotWritablePrimary
        11600,  # InterruptedAtShutdown
        11602,  # InterruptedDueToReplStateChange
        13435,  # NotPrimaryNoSecondaryOk
        13436,  # NotPrimaryOrSecondary
        63,  # StaleShardVersion
        150,  # StaleEpoch
        13388,  # StaleConfig
        234,  # RetryChangeStream
        133,  # FailedToSatisfyReadPreference
    ]
)


if TYPE_CHECKING:
    from pymongo.synchronous.client_session import ClientSession
    from pymongo.synchronous.collection import Collection
    from pymongo.synchronous.database import Database
    from pymongo.synchronous.mongo_client import MongoClient
    from pymongo.synchronous.pool import Connection


def _resumable(exc: PyMongoError) -> bool:
    """Return True if given a resumable change stream error."""
    if isinstance(exc, (ConnectionFailure, CursorNotFound)):
        return True
    if isinstance(exc, OperationFailure):
        if exc._max_wire_version is None:
            return False
        return (
            exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
        ) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
    return False

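A sketch of how the helper above classifies errors (illustrative only; it pokes at internal attributes that normal application code never sets by hand):

    from pymongo.errors import AutoReconnect, OperationFailure

    # Network errors are always resumable.
    assert _resumable(AutoReconnect("connection closed"))

    # Server errors depend on the wire version: 4.4+ servers (wire version 9)
    # attach the "ResumableChangeStreamError" label, while older servers fall
    # back to the code whitelist above (e.g. 189 PrimarySteppedDown).
    exc = OperationFailure("not primary", code=189, max_wire_version=8)
    assert _resumable(exc)
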
class ChangeStream(Generic[_DocumentType]):
    """The internal abstract base class for change stream cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.collection.Collection.watch`,
    :meth:`pymongo.database.Database.watch`, or
    :meth:`pymongo.mongo_client.MongoClient.watch` instead.

    .. versionadded:: 3.6
    .. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
    """

    def __init__(
        self,
        target: Union[
            MongoClient[_DocumentType],
            Database[_DocumentType],
            Collection[_DocumentType],
        ],
        pipeline: Optional[_Pipeline],
        full_document: Optional[str],
        resume_after: Optional[Mapping[str, Any]],
        max_await_time_ms: Optional[int],
        batch_size: Optional[int],
        collation: Optional[_CollationIn],
        start_at_operation_time: Optional[Timestamp],
        session: Optional[ClientSession],
        start_after: Optional[Mapping[str, Any]],
        comment: Optional[Any] = None,
        full_document_before_change: Optional[str] = None,
        show_expanded_events: Optional[bool] = None,
    ) -> None:
        if pipeline is None:
            pipeline = []
        pipeline = common.validate_list("pipeline", pipeline)
        common.validate_string_or_none("full_document", full_document)
        validate_collation_or_none(collation)
        common.validate_non_negative_integer_or_none("batchSize", batch_size)

        self._decode_custom = False
        self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
        if target.codec_options.type_registry._decoder_map:
            self._decode_custom = True
            # Keep the type registry so that we support encoding custom types
            # in the pipeline.
            self._target = target.with_options(  # type: ignore
                codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
            )
        else:
            self._target = target

        self._pipeline = copy.deepcopy(pipeline)
        self._full_document = full_document
        self._full_document_before_change = full_document_before_change
        self._uses_start_after = start_after is not None
        self._uses_resume_after = resume_after is not None
        self._resume_token = copy.deepcopy(start_after or resume_after)
        self._max_await_time_ms = max_await_time_ms
        self._batch_size = batch_size
        self._collation = collation
        self._start_at_operation_time = start_at_operation_time
        self._session = session
        self._comment = comment
        self._closed = False
        self._timeout = self._target._timeout
        self._show_expanded_events = show_expanded_events

    def _initialize_cursor(self) -> None:
        # Initialize cursor.
        self._cursor = self._create_cursor()

    @property
    def _aggregation_command_class(self) -> Type[_AggregationCommand]:
        """The aggregation command class to be used."""
        raise NotImplementedError

    @property
    def _client(self) -> MongoClient:  # type: ignore[type-arg]
        """The client against which the aggregation commands for
        this ChangeStream will be run.
        """
        raise NotImplementedError

    def _change_stream_options(self) -> dict[str, Any]:
        """Return the options dict for the $changeStream pipeline stage."""
        options: dict[str, Any] = {}
        if self._full_document is not None:
            options["fullDocument"] = self._full_document

        if self._full_document_before_change is not None:
            options["fullDocumentBeforeChange"] = self._full_document_before_change

        resume_token = self.resume_token
        if resume_token is not None:
            if self._uses_start_after:
                options["startAfter"] = resume_token
            else:
                options["resumeAfter"] = resume_token
        elif self._start_at_operation_time is not None:
            options["startAtOperationTime"] = self._start_at_operation_time

        if self._show_expanded_events:
            options["showExpandedEvents"] = self._show_expanded_events

        return options

    def _command_options(self) -> dict[str, Any]:
        """Return the options dict for the aggregation command."""
        options = {}
        if self._max_await_time_ms is not None:
            options["maxAwaitTimeMS"] = self._max_await_time_ms
        if self._batch_size is not None:
            options["batchSize"] = self._batch_size
        return options

    def _aggregation_pipeline(self) -> list[dict[str, Any]]:
        """Return the full aggregation pipeline for this ChangeStream."""
        options = self._change_stream_options()
        full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
        full_pipeline.extend(self._pipeline)
        return full_pipeline

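Concretely, for a stream opened with resume_after and fullDocument="updateLookup" plus one user $match stage, the two helpers above assemble a pipeline like this (token bytes fabricated for illustration):

    [
        {"$changeStream": {"fullDocument": "updateLookup",
                           "resumeAfter": {"_data": "8263AA..."}}},
        {"$match": {"operationType": "insert"}},
    ]
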
    def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
        """Callback that caches the postBatchResumeToken or
        startAtOperationTime from a changeStream aggregate command response
        containing an empty batch of change documents.

        This is implemented as a callback because we need access to the wire
        version in order to determine whether to cache this value.
        """
        if not result["cursor"]["firstBatch"]:
            if "postBatchResumeToken" in result["cursor"]:
                self._resume_token = result["cursor"]["postBatchResumeToken"]
            elif (
                self._start_at_operation_time is None
                and self._uses_resume_after is False
                and self._uses_start_after is False
                and conn.max_wire_version >= 7
            ):
                self._start_at_operation_time = result.get("operationTime")
                # PYTHON-2181: informative error on missing operationTime.
                if self._start_at_operation_time is None:
                    raise OperationFailure(
                        "Expected field 'operationTime' missing from command "
                        f"response: {result!r}"
                    )

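A trimmed example of the aggregate reply this callback inspects: an empty firstBatch whose postBatchResumeToken gets cached (shape only; values fabricated):

    {
        "cursor": {"id": 633, "ns": "test.docs", "firstBatch": [],
                   "postBatchResumeToken": {"_data": "8263AA..."}},
        "operationTime": Timestamp(1671023303, 1),
        "ok": 1.0,
    }
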
    def _run_aggregation_cmd(self, session: Optional[ClientSession]) -> CommandCursor:  # type: ignore[type-arg]
        """Run the full aggregation pipeline for this ChangeStream and return
        the corresponding CommandCursor.
        """
        cmd = self._aggregation_command_class(
            self._target,
            CommandCursor,
            self._aggregation_pipeline(),
            self._command_options(),
            result_processor=self._process_result,
            comment=self._comment,
        )
        return self._client._retryable_read(
            cmd.get_cursor,
            self._target._read_preference_for(session),
            session,
            operation=_Op.AGGREGATE,
        )

    def _create_cursor(self) -> CommandCursor:  # type: ignore[type-arg]
        with self._client._tmp_session(self._session) as s:
            return self._run_aggregation_cmd(session=s)

    def _resume(self) -> None:
        """Reestablish this change stream after a resumable error."""
        try:
            self._cursor.close()
        except PyMongoError:
            pass
        self._cursor = self._create_cursor()

    def close(self) -> None:
        """Close this ChangeStream."""
        self._closed = True
        self._cursor.close()

    def __iter__(self) -> ChangeStream[_DocumentType]:
        return self

    @property
    def resume_token(self) -> Optional[Mapping[str, Any]]:
        """The cached resume token that will be used to resume after the most
        recently returned change.

        .. versionadded:: 3.9
        """
        return copy.deepcopy(self._resume_token)

    @_csot.apply
    def next(self) -> _DocumentType:
        """Advance the cursor.

        This method blocks until the next change document is returned or an
        unrecoverable error is raised. This method is used when iterating over
        all changes in the cursor. For example::

            try:
                resume_token = None
                pipeline = [{'$match': {'operationType': 'insert'}}]
                with db.collection.watch(pipeline) as stream:
                    for insert_change in stream:
                        print(insert_change)
                        resume_token = stream.resume_token
            except pymongo.errors.PyMongoError:
                # The ChangeStream encountered an unrecoverable error or the
                # resume attempt failed to recreate the cursor.
                if resume_token is None:
                    # There is no usable resume token because there was a
                    # failure during ChangeStream initialization.
                    logging.error('...')
                else:
                    # Use the interrupted ChangeStream's resume token to create
                    # a new ChangeStream. The new stream will continue from the
                    # last seen insert change without missing any events.
                    with db.collection.watch(
                            pipeline, resume_after=resume_token) as stream:
                        for insert_change in stream:
                            print(insert_change)

        Raises :exc:`StopIteration` if this ChangeStream is closed.
        """
        while self.alive:
            doc = self.try_next()
            if doc is not None:
                return doc

        raise StopIteration

    __next__ = next

    @property
    def alive(self) -> bool:
        """Does this cursor have the potential to return more data?

        .. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
            :exc:`StopIteration` and :meth:`try_next` can return ``None``.

        .. versionadded:: 3.8
        """
        return not self._closed

    @_csot.apply
    def try_next(self) -> Optional[_DocumentType]:
        """Advance the cursor without blocking indefinitely.

        This method returns the next change document without waiting
        indefinitely for the next change. For example::

            with db.collection.watch() as stream:
                while stream.alive:
                    change = stream.try_next()
                    # Note that the ChangeStream's resume token may be updated
                    # even when no changes are returned.
                    print("Current resume token: %r" % (stream.resume_token,))
                    if change is not None:
                        print("Change document: %r" % (change,))
                        continue
                    # We end up here when there are no recent changes.
                    # Sleep for a while before trying again to avoid flooding
                    # the server with getMore requests when no changes are
                    # available.
                    time.sleep(10)

        If no change document is cached locally then this method runs a single
        getMore command. If the getMore yields any documents, the next
        document is returned; otherwise, if the getMore returns no documents
        (because there have been no changes), then ``None`` is returned.

        :return: The next change document or ``None`` when no document is available
            after running a single getMore or when the cursor is closed.

        .. versionadded:: 3.8
        """
        if not self._closed and not self._cursor.alive:
            self._resume()

        # Attempt to get the next change with at most one getMore and at most
        # one resume attempt.
        try:
            try:
                change = self._cursor._try_next(True)
            except PyMongoError as exc:
                if not _resumable(exc):
                    raise
                self._resume()
                change = self._cursor._try_next(False)
        except PyMongoError as exc:
            # Close the stream after a fatal error.
            if not _resumable(exc) and not exc.timeout:
                self.close()
            raise
        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
        except BaseException:
            self.close()
            raise

        # Check if the cursor was invalidated.
        if not self._cursor.alive:
            self._closed = True

        # If no changes are available.
        if change is None:
            # We have either iterated over all documents in the cursor,
            # OR the most-recently returned batch is empty. In either case,
            # update the cached resume token with the postBatchResumeToken if
            # one was returned. We also clear the startAtOperationTime.
            if self._cursor._post_batch_resume_token is not None:
                self._resume_token = self._cursor._post_batch_resume_token
                self._start_at_operation_time = None
            return change

        # Else, changes are available.
        try:
            resume_token = change["_id"]
        except KeyError:
            self.close()
            raise InvalidOperation(
                "Cannot provide resume functionality when the resume token is missing."
            ) from None

        # If this is the last change document from the current batch, cache the
        # postBatchResumeToken.
        if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
            resume_token = self._cursor._post_batch_resume_token

        # Hereafter, don't use startAfter; instead use resumeAfter.
        self._uses_start_after = False
        self._uses_resume_after = True

        # Cache the resume token and clear startAtOperationTime.
        self._resume_token = resume_token
        self._start_at_operation_time = None

        if self._decode_custom:
            return _bson_to_dict(change.raw, self._orig_codec_options)
        return change

    def __enter__(self) -> ChangeStream[_DocumentType]:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()


class CollectionChangeStream(ChangeStream[_DocumentType]):
    """A change stream that watches changes on a single collection.

    Should not be called directly by application developers. Use
    helper method :meth:`pymongo.collection.Collection.watch` instead.

    .. versionadded:: 3.7
    """

    _target: Collection[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
        return _CollectionAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        return self._target.database.client


class DatabaseChangeStream(ChangeStream[_DocumentType]):
    """A change stream that watches changes on all collections in a database.

    Should not be called directly by application developers. Use
    helper method :meth:`pymongo.database.Database.watch` instead.

    .. versionadded:: 3.7
    """

    _target: Database[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
        return _DatabaseAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        return self._target.client


class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
    """A change stream that watches changes on all collections in the cluster.

    Should not be called directly by application developers. Use
    helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead.

    .. versionadded:: 3.7
    """

    def _change_stream_options(self) -> dict[str, Any]:
        options = super()._change_stream_options()
        options["allChangesForCluster"] = True
        return options
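To see the subclasses above in action, watch at the client level (sketch; assumes a replica set or mongos at localhost, since change streams require an oplog):

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017/?replicaSet=rs0")
    # client.watch() yields a ClusterChangeStream, so the $changeStream
    # stage carries allChangesForCluster=True.
    with client.watch() as stream:
        for change in stream:
            print(change["operationType"], change.get("ns"))
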
@@ -0,0 +1,754 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The client-level bulk write operations interface.

.. versionadded:: 4.9
"""
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
    TYPE_CHECKING,
    Any,
    Mapping,
    Optional,
    Type,
    Union,
)

from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.helpers import _handle_reauth

if TYPE_CHECKING:
    from pymongo.synchronous.mongo_client import MongoClient
    from pymongo.synchronous.pool import Connection
from pymongo._client_bulk_shared import (
    _merge_command,
    _throw_client_bulk_write_exception,
)
from pymongo.common import (
    validate_is_document_type,
    validate_ok_for_replace,
    validate_ok_for_update,
)
from pymongo.errors import (
    ConfigurationError,
    ConnectionFailure,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
    WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
    _ClientBulkWriteContext,
    _convert_client_bulk_exception,
    _convert_exception,
    _convert_write_result,
    _randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.results import (
    ClientBulkWriteResult,
    DeleteResult,
    InsertOneResult,
    UpdateResult,
)
from pymongo.typings import _DocumentOut, _Pipeline
from pymongo.write_concern import WriteConcern

_IS_SYNC = True


class _ClientBulk:
    """The private guts of the client-level bulk write API."""

    def __init__(
        self,
        client: MongoClient[Any],
        write_concern: WriteConcern,
        ordered: bool = True,
        bypass_document_validation: Optional[bool] = None,
        comment: Optional[str] = None,
        let: Optional[Any] = None,
        verbose_results: bool = False,
    ) -> None:
        """Initialize a _ClientBulk instance."""
        self.client = client
        self.write_concern = write_concern
        self.let = let
        if self.let is not None:
            common.validate_is_document_type("let", self.let)
        self.ordered = ordered
        self.bypass_doc_val = bypass_document_validation
        self.comment = comment
        self.verbose_results = verbose_results
        self.ops: list[tuple[str, Mapping[str, Any]]] = []
        self.namespaces: list[str] = []
        self.idx_offset: int = 0
        self.total_ops: int = 0
        self.executed = False
        self.uses_collation = False
        self.uses_array_filters = False
        self.is_retryable = self.client.options.retry_writes
        self.retrying = False
        self.started_retryable_write = False

    @property
    def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
        return _ClientBulkWriteContext

    def add_insert(self, namespace: str, document: _DocumentOut) -> None:
        """Add an insert document to the list of ops."""
        validate_is_document_type("document", document)
        # Generate ObjectId client side.
        if not (isinstance(document, RawBSONDocument) or "_id" in document):
            document["_id"] = ObjectId()
        cmd = {"insert": -1, "document": document}
        self.ops.append(("insert", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

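A sketch of the bookkeeping one queued insert produces. The -1 is a placeholder; based on the batching code below, it appears to be replaced at encode time with the operation's index into the command's nsInfo array (internal API, shown for illustration only):

    bulk = _ClientBulk(client, WriteConcern(w=1))
    bulk.add_insert("test.docs", {"x": 1})
    # bulk.ops        -> [("insert", {"insert": -1,
    #                                 "document": {"x": 1, "_id": ObjectId(...)}})]
    # bulk.namespaces -> ["test.docs"]
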
    def add_update(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        multi: bool,
        upsert: Optional[bool] = None,
        collation: Optional[Mapping[str, Any]] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Union[str, dict[str, Any], None] = None,
        sort: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Create an update document and add it to the list of ops."""
        validate_ok_for_update(update)
        cmd = {
            "update": -1,
            "filter": selector,
            "updateMods": update,
            "multi": multi,
        }
        if upsert is not None:
            cmd["upsert"] = upsert
        if array_filters is not None:
            self.uses_array_filters = True
            cmd["arrayFilters"] = array_filters
        if hint is not None:
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if sort is not None:
            cmd["sort"] = sort
        if multi:
            # A bulk_write containing an update_many is not retryable.
            self.is_retryable = False
        self.ops.append(("update", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    def add_replace(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        replacement: Mapping[str, Any],
        upsert: Optional[bool] = None,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
        sort: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Create a replace document and add it to the list of ops."""
        validate_ok_for_replace(replacement)
        cmd = {
            "update": -1,
            "filter": selector,
            "updateMods": replacement,
            "multi": False,
        }
        if upsert is not None:
            cmd["upsert"] = upsert
        if hint is not None:
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if sort is not None:
            cmd["sort"] = sort
        self.ops.append(("replace", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    def add_delete(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        multi: bool,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Create a delete document and add it to the list of ops."""
        cmd = {"delete": -1, "filter": selector, "multi": multi}
        if hint is not None:
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if multi:
            # A bulk_write containing a delete_many is not retryable.
            self.is_retryable = False
        self.ops.append(("delete", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

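Note that any multi=True operation disables retryability for the whole bulk. A small sketch of that bookkeeping (internal API; assumes the client was created with the default retryWrites=True):

    bulk = _ClientBulk(client, WriteConcern(w=1))
    assert bulk.is_retryable
    bulk.add_delete("test.docs", {"x": 1}, multi=False)
    assert bulk.is_retryable
    bulk.add_delete("test.docs", {}, multi=True)  # a delete_many
    assert not bulk.is_retryable
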
    @_handle_reauth
    def write_command(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: Union[bytes, dict[str, Any]],
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
        client: MongoClient[Any],
    ) -> dict[str, Any]:
        """A proxy for Connection.write_command that handles event publishing."""
        cmd["ops"] = op_docs
        cmd["nsInfo"] = ns_docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.STARTED,
                clientId=client._topology_settings._topology_id,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=bwc.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=bwc.conn.id,
                serverConnectionId=bwc.conn.server_connection_id,
                serverHost=bwc.conn.address[0],
                serverPort=bwc.conn.address[1],
                serviceId=bwc.conn.service_id,
            )
        if bwc.publish:
            bwc._start(cmd, request_id, op_docs, ns_docs)
        try:
            reply = bwc.conn.write_command(request_id, msg, bwc.codec)  # type: ignore[misc, arg-type]
            duration = datetime.datetime.now() - bwc.start_time
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.SUCCEEDED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                )
            if bwc.publish:
                bwc._succeed(request_id, reply, duration)  # type: ignore[arg-type]
            # Process the response from the server.
            self.client._process_response(reply, bwc.session)  # type: ignore[arg-type]
        except Exception as exc:
            duration = datetime.datetime.now() - bwc.start_time
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.FAILED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )

            if bwc.publish:
                bwc._fail(request_id, failure, duration)
            # Top-level error will be embedded in ClientBulkWriteException.
            reply = {"error": exc}
            # Process the response from the server.
            if isinstance(exc, OperationFailure):
                self.client._process_response(exc.details, bwc.session)  # type: ignore[arg-type]
            else:
                self.client._process_response({}, bwc.session)  # type: ignore[arg-type]
        return reply  # type: ignore[return-value]

    def unack_write(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
        client: MongoClient[Any],
    ) -> Optional[Mapping[str, Any]]:
        """A proxy for Connection.unack_write that handles event publishing."""
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.STARTED,
                clientId=client._topology_settings._topology_id,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=bwc.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=bwc.conn.id,
                serverConnectionId=bwc.conn.server_connection_id,
                serverHost=bwc.conn.address[0],
                serverPort=bwc.conn.address[1],
                serviceId=bwc.conn.service_id,
            )
        if bwc.publish:
            cmd = bwc._start(cmd, request_id, op_docs, ns_docs)
        try:
            result = bwc.conn.unack_write(msg, bwc.max_bson_size)  # type: ignore[func-returns-value, misc, override]
            duration = datetime.datetime.now() - bwc.start_time
            if result is not None:
                reply = _convert_write_result(bwc.name, cmd, result)  # type: ignore[arg-type]
            else:
                # Comply with APM spec.
                reply = {"ok": 1}
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.SUCCEEDED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                )
            if bwc.publish:
                bwc._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - bwc.start_time
            if isinstance(exc, OperationFailure):
                failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details)  # type: ignore[arg-type]
            elif isinstance(exc, NotPrimaryError):
                failure = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.FAILED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if bwc.publish:
                assert bwc.start_time is not None
                bwc._fail(request_id, failure, duration)
            # Top-level error will be embedded in ClientBulkWriteException.
            reply = {"error": exc}
        return reply

    def _execute_batch_unack(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: dict[str, Any],
        ops: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        """Executes a batch of bulkWrite server commands (unack)."""
        request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
        self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client)  # type: ignore[arg-type]
        return to_send_ops, to_send_ns

    def _execute_batch(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: dict[str, Any],
        ops: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        """Executes a batch of bulkWrite server commands (ack)."""
        request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
        result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client)  # type: ignore[arg-type]
        return result, to_send_ops, to_send_ns  # type: ignore[return-value]

    def _process_results_cursor(
        self,
        full_result: MutableMapping[str, Any],
        result: MutableMapping[str, Any],
        conn: Connection,
        session: Optional[ClientSession],
    ) -> None:
        """Internal helper for processing the server reply command cursor."""
        if result.get("cursor"):
            if session:
                session._leave_alive = True
            coll = Collection(
                database=Database(self.client, "admin"),
                name="$cmd.bulkWrite",
            )
            cmd_cursor = CommandCursor(
                coll,
                result["cursor"],
                conn.address,
                session=session,
                comment=self.comment,
            )
            cmd_cursor._maybe_pin_connection(conn)

            # Iterate the cursor to get individual write results.
            try:
                for doc in cmd_cursor:
                    original_index = doc["idx"] + self.idx_offset
                    op_type, op = self.ops[original_index]

                    if not doc["ok"]:
                        result["writeErrors"].append(doc)
                        if self.ordered:
                            return

                    # Record individual write result.
                    if doc["ok"] and self.verbose_results:
                        if op_type == "insert":
                            inserted_id = op["document"]["_id"]
                            res = InsertOneResult(inserted_id, acknowledged=True)  # type: ignore[assignment]
                        if op_type in ["update", "replace"]:
                            op_type = "update"
                            res = UpdateResult(doc, acknowledged=True, in_client_bulk=True)  # type: ignore[assignment]
                        if op_type == "delete":
                            res = DeleteResult(doc, acknowledged=True)  # type: ignore[assignment]
                        full_result[f"{op_type}Results"][original_index] = res
            except Exception as exc:
                # Attempt to close the cursor, then raise top-level error.
                if cmd_cursor.alive:
                    cmd_cursor.close()
                result["error"] = _convert_client_bulk_exception(exc)

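The per-operation documents this helper consumes look roughly like the two below: a successful update at ops index 3, then a failed insert at index 5 (shapes only; values fabricated):

    {"ok": 1, "idx": 3, "n": 1, "nModified": 1}
    {"ok": 0, "idx": 5, "code": 11000, "errmsg": "E11000 duplicate key error ..."}
    # With verbose_results=True, the first one lands in
    #   full_result["updateResults"][3 + self.idx_offset]
    # wrapped as an UpdateResult; the second is appended to writeErrors.
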
    def _execute_command(
        self,
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        conn: Connection,
        op_id: int,
        retryable: bool,
        full_result: MutableMapping[str, Any],
        final_write_concern: Optional[WriteConcern] = None,
    ) -> None:
        """Internal helper for executing batches of bulkWrite commands."""
        db_name = "admin"
        cmd_name = "bulkWrite"
        listeners = self.client._event_listeners

        # Connection.command validates the session, but we use
        # Connection.write_command
        conn.validate_session(self.client, session)

        bwc = self.bulk_ctx_class(
            db_name,
            cmd_name,
            conn,
            op_id,
            listeners,  # type: ignore[arg-type]
            session,
            self.client.codec_options,
        )

        while self.idx_offset < self.total_ops:
            # If this is the last possible batch, use the
            # final write concern.
            if self.total_ops - self.idx_offset <= bwc.max_write_batch_size:
                write_concern = final_write_concern or write_concern

            # Construct the server command, specifying the relevant options.
            cmd = {"bulkWrite": 1}
            cmd["errorsOnly"] = not self.verbose_results
            cmd["ordered"] = self.ordered  # type: ignore[assignment]
            not_in_transaction = session and not session.in_transaction
            if not_in_transaction or not session:
                _csot.apply_write_concern(cmd, write_concern)
            if self.bypass_doc_val is not None:
                cmd["bypassDocumentValidation"] = self.bypass_doc_val
            if self.comment:
                cmd["comment"] = self.comment  # type: ignore[assignment]
            if self.let:
                cmd["let"] = self.let

            if session:
                # Start a new retryable write unless one was already
                # started for this command.
                if retryable and not self.started_retryable_write:
                    session._start_retryable_write()
                    self.started_retryable_write = True
                session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
            conn.send_cluster_time(cmd, session, self.client)
            conn.add_server_api(cmd)
            # CSOT: apply timeout before encoding the command.
            conn.apply_timeout(self.client, cmd)
            ops = islice(self.ops, self.idx_offset, None)
            namespaces = islice(self.namespaces, self.idx_offset, None)

            # Run as many ops as possible in one server command.
            if write_concern.acknowledged:
                raw_result, to_send_ops, _ = self._execute_batch(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]
                result = raw_result

                # Top-level server/network error.
                if result.get("error"):
                    error = result["error"]
                    retryable_top_level_error = (
                        hasattr(error, "details")
                        and isinstance(error.details, dict)
                        and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
                    )
                    retryable_network_error = isinstance(
                        error, ConnectionFailure
                    ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))

                    # Synthesize the full bulk result without modifying the
                    # current one because this write operation may be retried.
                    if retryable and (retryable_top_level_error or retryable_network_error):
                        full = copy.deepcopy(full_result)
                        _merge_command(self.ops, self.idx_offset, full, result)
                        _throw_client_bulk_write_exception(full, self.verbose_results)
                    else:
                        _merge_command(self.ops, self.idx_offset, full_result, result)
                        _throw_client_bulk_write_exception(full_result, self.verbose_results)

                result["error"] = None
                result["writeErrors"] = []
                if result.get("nErrors", 0) < len(to_send_ops):
                    full_result["anySuccessful"] = True

                # Top-level command error.
                if not result["ok"]:
                    result["error"] = raw_result
                    _merge_command(self.ops, self.idx_offset, full_result, result)
                    break

                if retryable:
                    # Retryable writeConcernErrors halt the execution of this batch.
                    wce = result.get("writeConcernError", {})
                    if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
                        # Synthesize the full bulk result without modifying the
                        # current one because this write operation may be retried.
                        full = copy.deepcopy(full_result)
                        _merge_command(self.ops, self.idx_offset, full, result)
                        _throw_client_bulk_write_exception(full, self.verbose_results)

                # Process the server reply as a command cursor.
                self._process_results_cursor(full_result, result, conn, session)

                # Merge this batch's results with the full results.
                _merge_command(self.ops, self.idx_offset, full_result, result)

                # We're no longer in a retry once a command succeeds.
                self.retrying = False
                self.started_retryable_write = False

            else:
                to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]

            self.idx_offset += len(to_send_ops)

            # We halt execution if we hit a top-level error,
            # or an individual error in an ordered bulk write.
            if full_result["error"] or (self.ordered and full_result["writeErrors"]):
                break

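For reference, one acknowledged batch produced by the loop above looks roughly like this on the wire; the ops and nsInfo arrays are attached by write_command, and all values here are illustrative:

    {
        "bulkWrite": 1,
        "errorsOnly": True,   # not verbose_results
        "ordered": True,
        "writeConcern": {"w": "majority"},
        "ops": [{"insert": 0, "document": {"_id": ObjectId(...), "x": 1}}],
        "nsInfo": [{"ns": "test.docs"}],
    }
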
    def execute_command(
        self,
        session: Optional[ClientSession],
        operation: str,
    ) -> MutableMapping[str, Any]:
        """Execute commands with w=1 WriteConcern."""
        full_result: MutableMapping[str, Any] = {
            "anySuccessful": False,
            "error": None,
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nDeleted": 0,
            "insertResults": {},
            "updateResults": {},
            "deleteResults": {},
        }
        op_id = _randint()

        def retryable_bulk(
            session: Optional[ClientSession],
            conn: Connection,
            retryable: bool,
        ) -> None:
            if conn.max_wire_version < 25:
                raise InvalidOperation(
                    "MongoClient.bulk_write requires MongoDB server version 8.0+."
                )
            self._execute_command(
                self.write_concern,
                session,
                conn,
                op_id,
                retryable,
                full_result,
            )

        self.client._retryable_write(
            self.is_retryable,
            retryable_bulk,
            session,
            operation,
            bulk=self,
            operation_id=op_id,
        )

        if full_result["error"] or full_result["writeErrors"] or full_result["writeConcernErrors"]:
            _throw_client_bulk_write_exception(full_result, self.verbose_results)
        return full_result

    def execute_command_unack(
        self,
        conn: Connection,
    ) -> None:
        """Execute commands with OP_MSG and w=0 writeConcern. Always unordered."""
        db_name = "admin"
        cmd_name = "bulkWrite"
        listeners = self.client._event_listeners
        op_id = _randint()

        bwc = self.bulk_ctx_class(
            db_name,
            cmd_name,
            conn,
            op_id,
            listeners,  # type: ignore[arg-type]
            None,
            self.client.codec_options,
        )

        while self.idx_offset < self.total_ops:
            # Construct the server command, specifying the relevant options.
            cmd = {"bulkWrite": 1}
            cmd["errorsOnly"] = True
            cmd["ordered"] = False
            if self.bypass_doc_val is not None:
                cmd["bypassDocumentValidation"] = self.bypass_doc_val
            cmd["writeConcern"] = {"w": 0}  # type: ignore[assignment]
            if self.comment:
                cmd["comment"] = self.comment  # type: ignore[assignment]
            if self.let:
                cmd["let"] = self.let

            conn.add_server_api(cmd)
            ops = islice(self.ops, self.idx_offset, None)
            namespaces = islice(self.namespaces, self.idx_offset, None)

            # Run as many ops as possible in one server command.
            to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]

            self.idx_offset += len(to_send_ops)

def execute_no_results(
|
||||
self,
|
||||
conn: Connection,
|
||||
) -> None:
|
||||
"""Execute all operations, returning no results (w=0)."""
|
||||
if self.uses_collation:
|
||||
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
|
||||
if self.uses_array_filters:
|
||||
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
|
||||
# Cannot have both unacknowledged writes and bypass document validation.
|
||||
if self.bypass_doc_val is not None:
|
||||
raise OperationFailure(
|
||||
"Cannot set bypass_document_validation with unacknowledged write concern"
|
||||
)
|
||||
|
||||
return self.execute_command_unack(conn)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
session: Optional[ClientSession],
|
||||
operation: str,
|
||||
) -> Any:
|
||||
"""Execute operations."""
|
||||
if not self.ops:
|
||||
raise InvalidOperation("No operations to execute")
|
||||
if self.executed:
|
||||
raise InvalidOperation("Bulk operations can only be executed once.")
|
||||
self.executed = True
|
||||
session = _validate_session_write_concern(session, self.write_concern)
|
||||
|
||||
if not self.write_concern.acknowledged:
|
||||
with self.client._conn_for_writes(session, operation) as connection:
|
||||
if connection.max_wire_version < 25:
|
||||
raise InvalidOperation(
|
||||
"MongoClient.bulk_write requires MongoDB server version 8.0+."
|
||||
)
|
||||
self.execute_no_results(connection)
|
||||
return ClientBulkWriteResult(None, False, False) # type: ignore[arg-type]
|
||||
|
||||
result = self.execute_command(session, operation)
|
||||
return ClientBulkWriteResult(
|
||||
result,
|
||||
self.write_concern.acknowledged,
|
||||
self.verbose_results,
|
||||
)
|
||||
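
# A minimal usage sketch for the client-level bulk API implemented above
# (assumptions: a MongoDB 8.0+ server at a placeholder URI, and the
# `namespace` argument that PyMongo's InsertOne/UpdateOne accept for
# MongoClient.bulk_write):
#
#   from pymongo import MongoClient
#   from pymongo.operations import InsertOne, UpdateOne
#
#   client = MongoClient("mongodb://localhost:27017")
#   result = client.bulk_write(
#       [
#           InsertOne({"x": 1}, namespace="test.coll"),
#           UpdateOne({"x": 1}, {"$inc": {"x": 1}}, namespace="test.coll"),
#       ],
#       ordered=True,
#       verbose_results=True,  # populate per-operation results in the result object
#   )
#   print(result.inserted_count, result.modified_count)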
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,472 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CommandCursor class to iterate over command results."""
from __future__ import annotations

from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Iterator,
    Mapping,
    NoReturn,
    Optional,
    Sequence,
    Union,
)

from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo import _csot
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import (
    _CursorAddress,
    _GetMore,
    _OpMsg,
    _OpReply,
    _RawBatchGetMore,
)
from pymongo.response import PinnedResponse
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.typings import _Address, _DocumentOut, _DocumentType

if TYPE_CHECKING:
    from pymongo.synchronous.client_session import ClientSession
    from pymongo.synchronous.collection import Collection
    from pymongo.synchronous.pool import Connection

_IS_SYNC = True


class CommandCursor(Generic[_DocumentType]):
    """A cursor / iterator over command cursors."""

    _getmore_class = _GetMore

    def __init__(
        self,
        collection: Collection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        comment: Any = None,
    ) -> None:
        """Create a new command cursor."""
        self._sock_mgr: Any = None
        self._collection: Collection[_DocumentType] = collection
        self._id = cursor_info["id"]
        self._data = deque(cursor_info["firstBatch"])
        self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
            "postBatchResumeToken"
        )
        self._address = address
        self._batch_size = batch_size
        self._max_await_time_ms = max_await_time_ms
        self._timeout = self._collection.database.client.options.timeout
        self._session = session
        if self._session is not None:
            self._session._attached_to_cursor = True
        self._killed = self._id == 0
        self._comment = comment
        if self._killed:
            self._end_session()

        if "ns" in cursor_info:  # noqa: SIM401
            self._ns = cursor_info["ns"]
        else:
            self._ns = collection.full_name

        self.batch_size(batch_size)

        if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
            raise TypeError(
                f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
            )

    def __del__(self) -> None:
        self._die_no_lock()

    def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
        """Limits the number of documents returned in one batch. Each batch
        requires a round trip to the server. It can be adjusted to optimize
        performance and limit data transfer.

        .. note:: batch_size can not override MongoDB's internal limits on the
          amount of data it will return to the client in a single batch (i.e.
          if you set batch size to 1,000,000,000, MongoDB will currently only
          return 4-16MB of results per batch).

        Raises :exc:`TypeError` if `batch_size` is not an integer.
        Raises :exc:`ValueError` if `batch_size` is less than ``0``.

        :param batch_size: The size of each batch of results requested.
        """
        if not isinstance(batch_size, int):
            raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
        if batch_size < 0:
            raise ValueError("batch_size must be >= 0")

        # A batch_size of 1 is bumped to 2: historically, a value of 1 told
        # the server to close the cursor after returning the first batch.
        self._batch_size = batch_size == 1 and 2 or batch_size
        return self
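
    # A usage sketch for batch_size (names `db` and `pipeline` are
    # placeholders): the method returns the cursor itself, so it can be
    # chained, and the requested size applies to every subsequent getMore.
    #
    #   cursor = db.coll.aggregate(pipeline).batch_size(100)
    #   for doc in cursor:
    #       ...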

    def _has_next(self) -> bool:
        """Returns `True` if the cursor has documents remaining from the
        previous batch.
        """
        return len(self._data) > 0

    @property
    def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
        """Retrieve the postBatchResumeToken from the response to a
        changeStream aggregate or getMore.
        """
        return self._postbatchresumetoken

    def _maybe_pin_connection(self, conn: Connection) -> None:
        client = self._collection.database.client
        if not client._should_pin_cursor(self._session):
            return
        if not self._sock_mgr:
            conn.pin_cursor()
            conn_mgr = _ConnectionManager(conn, False)
            # Ensure the connection gets returned when the entire result is
            # returned in the first batch.
            if self._id == 0:
                conn_mgr.close()
            else:
                self._sock_mgr = conn_mgr

    def _unpack_response(
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions[Mapping[str, Any]],
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> Sequence[_DocumentOut]:
        return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)

    @property
    def alive(self) -> bool:
        """Does this cursor have the potential to return more data?

        Even if :attr:`alive` is ``True``, :meth:`next` can raise
        :exc:`StopIteration`. Best to use a for loop::

            for doc in collection.aggregate(pipeline):
                print(doc)

        .. note:: :attr:`alive` can be True while iterating a cursor from
          a failed server. In this case :attr:`alive` will return False after
          :meth:`next` fails to retrieve the next batch of results from the
          server.
        """
        return bool(len(self._data) or (not self._killed))

    @property
    def cursor_id(self) -> int:
        """Returns the id of the cursor."""
        return self._id

    @property
    def address(self) -> Optional[_Address]:
        """The (host, port) of the server used, or None.

        .. versionadded:: 3.0
        """
        return self._address

    @property
    def session(self) -> Optional[ClientSession]:
        """The cursor's :class:`~pymongo.client_session.ClientSession`, or None.

        .. versionadded:: 3.6
        """
        if self._session and not self._session._implicit:
            return self._session
        return None

    def _prepare_to_die(self) -> tuple[int, Optional[_CursorAddress]]:
        already_killed = self._killed
        self._killed = True
        if self._id and not already_killed:
            cursor_id = self._id
            assert self._address is not None
            address = _CursorAddress(self._address, self._ns)
        else:
            # Skip killCursors.
            cursor_id = 0
            address = None
        return cursor_id, address

    def _die_no_lock(self) -> None:
        """Closes this cursor without acquiring a lock."""
        cursor_id, address = self._prepare_to_die()
        self._collection.database.client._cleanup_cursor_no_lock(
            cursor_id, address, self._sock_mgr, self._session
        )
        if self._session and self._session._implicit:
            self._session._attached_to_cursor = False
            self._session = None
        self._sock_mgr = None

    def _die_lock(self) -> None:
        """Closes this cursor."""
        cursor_id, address = self._prepare_to_die()
        self._collection.database.client._cleanup_cursor_lock(
            cursor_id,
            address,
            self._sock_mgr,
            self._session,
        )
        if self._session and self._session._implicit:
            self._session._attached_to_cursor = False
            self._session = None
        self._sock_mgr = None

    def _end_session(self) -> None:
        if self._session and self._session._implicit:
            self._session._attached_to_cursor = False
            self._session._end_implicit_session()
            self._session = None

    def close(self) -> None:
        """Explicitly close / kill this cursor."""
        self._die_lock()

    def _send_message(self, operation: _GetMore) -> None:
        """Send a getmore message and handle the response."""
        client = self._collection.database.client
        try:
            response = client._run_operation(
                operation, self._unpack_response, address=self._address
            )
        except OperationFailure as exc:
            if exc.code in _CURSOR_CLOSED_ERRORS:
                # Don't send killCursors because the cursor is already closed.
                self._killed = True
            if exc.timeout:
                self._die_no_lock()
            else:
                # Return the session and pinned connection, if necessary.
                self.close()
            raise
        except ConnectionFailure:
            # Don't send killCursors because the cursor is already closed.
            self._killed = True
            # Return the session and pinned connection, if necessary.
            self.close()
            raise
        except Exception:
            self.close()
            raise

        if isinstance(response, PinnedResponse):
            if not self._sock_mgr:
                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)  # type: ignore[arg-type]
        if response.from_command:
            cursor = response.docs[0]["cursor"]
            documents = cursor["nextBatch"]
            self._postbatchresumetoken = cursor.get("postBatchResumeToken")
            self._id = cursor["id"]
        else:
            documents = response.docs
            assert isinstance(response.data, _OpReply)
            self._id = response.data.cursor_id

        if self._id == 0:
            self.close()
        self._data = deque(documents)

    def _refresh(self) -> int:
        """Refreshes the cursor with more data from the server.

        Returns the length of self._data after refresh. Will exit early if
        self._data is already non-empty. Raises OperationFailure when the
        cursor cannot be refreshed due to an error on the query.
        """
        if len(self._data) or self._killed:
            return len(self._data)

        if self._id:  # Get More
            dbname, collname = self._ns.split(".", 1)
            read_pref = self._collection._read_preference_for(self.session)
            self._send_message(
                self._getmore_class(
                    dbname,
                    collname,
                    self._batch_size,
                    self._id,
                    self._collection.codec_options,
                    read_pref,
                    self._session,
                    self._collection.database.client,
                    self._max_await_time_ms,
                    self._sock_mgr,
                    False,
                    self._comment,
                )
            )
        else:  # Cursor id is zero, nothing else to return.
            self._die_lock()

        return len(self._data)

    def __iter__(self) -> Iterator[_DocumentType]:
        return self

    def next(self) -> _DocumentType:
        """Advance the cursor."""
        # Block until a document is returnable.
        while self.alive:
            doc = self._try_next(True)
            if doc is not None:
                return doc

        raise StopIteration

    def __next__(self) -> _DocumentType:
        return self.next()

    def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
        """Advance the cursor blocking for at most one getMore command."""
        if not len(self._data) and not self._killed and get_more_allowed:
            self._refresh()
        if len(self._data):
            return self._data.popleft()
        else:
            return None

    def _next_batch(self, result: list, total: Optional[int] = None) -> bool:  # type: ignore[type-arg]
        """Get all or some available documents from the cursor."""
        if not len(self._data) and not self._killed:
            self._refresh()
        if len(self._data):
            if total is None:
                result.extend(self._data)
                self._data.clear()
            else:
                for _ in range(min(len(self._data), total)):
                    result.append(self._data.popleft())
            return True
        else:
            return False

    def try_next(self) -> Optional[_DocumentType]:
        """Advance the cursor without blocking indefinitely.

        This method returns the next document without waiting
        indefinitely for data.

        If no document is cached locally then this method runs a single
        getMore command. If the getMore yields any documents, the next
        document is returned, otherwise, if the getMore returns no documents
        (because there is no additional data) then ``None`` is returned.

        :return: The next document or ``None`` when no document is available
          after running a single getMore or when the cursor is closed.

        .. versionadded:: 4.5
        """
        return self._try_next(get_more_allowed=True)
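
    # A usage sketch for try_next (assuming a change stream, whose cursor
    # builds on this class): poll without blocking for longer than one
    # getMore, and do other work when no document has arrived yet.
    #
    #   with db.coll.watch() as stream:
    #       while stream.alive:
    #           change = stream.try_next()
    #           if change is None:
    #               continue  # No new events yet; do other work here.
    #           print(change)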

    def __enter__(self) -> CommandCursor[_DocumentType]:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()

    @_csot.apply
    def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
        """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.

        To use::

          >>> cursor.to_list()

        Or, to read at most n items from the cursor::

          >>> cursor.to_list(n)

        If the cursor is empty or has no more results, an empty list will be returned.

        .. versionadded:: 4.9
        """
        res: list[_DocumentType] = []
        remaining = length
        if isinstance(length, int) and length < 1:
            raise ValueError("to_list() length must be greater than 0")
        while self.alive:
            if not self._next_batch(res, remaining):
                break
            if length is not None:
                remaining = length - len(res)
                if remaining == 0:
                    break
        return res


class RawBatchCommandCursor(CommandCursor[_DocumentType]):
    _getmore_class = _RawBatchGetMore

    def __init__(
        self,
        collection: Collection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        comment: Any = None,
    ) -> None:
        """Create a new cursor / iterator over raw batches of BSON data.

        Should not be called directly by application developers -
        see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
        instead.

        .. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
        """
        assert not cursor_info.get("firstBatch")
        super().__init__(
            collection,
            cursor_info,
            address,
            batch_size,
            max_await_time_ms,
            session,
            comment,
        )

    def _unpack_response(  # type: ignore[override]
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions[dict[str, Any]],
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[Mapping[str, Any]]:
        raw_response = response.raw_response(cursor_id, user_fields=user_fields)
        if not legacy_response:
            # OP_MSG returns firstBatch/nextBatch documents as a BSON array.
            # Re-assemble the array of documents into a document stream.
            _convert_raw_document_lists_to_streams(raw_response[0])
        return raw_response  # type: ignore[return-value]

    def __getitem__(self, index: int) -> NoReturn:
        raise InvalidOperation("Cannot call __getitem__ on RawBatchCommandCursor")
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,102 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations

import asyncio
import builtins
import socket
import sys
from typing import (
    Any,
    Callable,
    TypeVar,
    cast,
)

from pymongo.errors import (
    OperationFailure,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE

_IS_SYNC = True

# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def _handle_reauth(func: F) -> F:
    def inner(*args: Any, **kwargs: Any) -> Any:
        no_reauth = kwargs.pop("no_reauth", False)
        from pymongo.message import _BulkWriteContext
        from pymongo.synchronous.pool import Connection

        try:
            return func(*args, **kwargs)
        except OperationFailure as exc:
            if no_reauth:
                raise
            if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
                # Look for an argument that either is a Connection
                # or has a connection attribute, so we can trigger
                # a reauth.
                conn = None
                for arg in args:
                    if isinstance(arg, Connection):
                        conn = arg
                        break
                    if isinstance(arg, _BulkWriteContext):
                        conn = arg.conn  # type: ignore[assignment]
                        break
                if conn:
                    conn.authenticate(reauthenticate=True)
                else:
                    raise
                return func(*args, **kwargs)
            raise

    return cast(F, inner)
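
# A sketch of how the decorator above is applied inside the driver
# (`_example_op` and its body are hypothetical): the wrapped callable must
# receive a Connection (or a _BulkWriteContext) positionally so that `inner`
# can find it, reauthenticate once, and retry the call.
#
#   @_handle_reauth
#   def _example_op(conn: Connection, cmd: dict) -> dict:
#       return conn.command("admin", cmd)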


def _getaddrinfo(
    host: Any, port: Any, **kwargs: Any
) -> list[
    tuple[
        socket.AddressFamily,
        socket.SocketKind,
        int,
        str,
        tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
    ]
]:
    if not _IS_SYNC:
        loop = asyncio.get_running_loop()
        return loop.getaddrinfo(host, port, **kwargs)  # type: ignore[return-value]
    else:
        return socket.getaddrinfo(host, port, **kwargs)


if sys.version_info >= (3, 10):
    next = builtins.next
    iter = builtins.iter
else:

    def next(cls: Any) -> Any:
        """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
        return cls.__next__()

    def iter(cls: Any) -> Any:
        """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#iter."""
        return cls.__iter__()
File diff suppressed because it is too large
@@ -0,0 +1,543 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Class to monitor a MongoDB server on a background thread."""

from __future__ import annotations

import asyncio
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Optional

from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.errors import NetworkTimeout, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.synchronous.srv_resolver import _SrvResolver

if TYPE_CHECKING:
    from pymongo.synchronous.pool import (  # type: ignore[attr-defined]
        Connection,
        Pool,
        _CancellationContext,
    )
    from pymongo.synchronous.settings import TopologySettings
    from pymongo.synchronous.topology import Topology

_IS_SYNC = True


def _sanitize(error: Exception) -> None:
    """PYTHON-2433 Clear error traceback info."""
    error.__traceback__ = None
    error.__context__ = None
    error.__cause__ = None


def _monotonic_duration(start: float) -> float:
    """Return the duration since the given start time.

    Accounts for buggy platforms where time.monotonic() is not monotonic.
    See PYTHON-4600.
    """
    return max(0.0, time.monotonic() - start)


class MonitorBase:
    def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
        """Base class to do periodic work on a background thread.

        The background thread is signaled to stop when the Topology or
        this instance is freed.
        """

        # We strongly reference the executor and it weakly references us via
        # this closure. When the monitor is freed, stop the executor soon.
        def target() -> bool:
            monitor = self_ref()
            if monitor is None:
                return False  # Stop the executor.
            monitor._run()  # type:ignore[attr-defined]
            return True

        executor = periodic_executor.PeriodicExecutor(
            interval=interval, min_interval=min_interval, target=target, name=name
        )

        self._executor = executor

        def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
            # This prevents GC from waiting 10 seconds for hello to complete
            # See test_cleanup_executors_on_client_del.
            monitor = self_ref()
            if monitor:
                monitor.gc_safe_close()

        # Avoid cycles. When self or topology is freed, stop executor soon.
        self_ref = weakref.ref(self, executor.close)
        self._topology = weakref.proxy(topology, _on_topology_gc)
        _register(self)

    def open(self) -> None:
        """Start monitoring, or restart after a fork.

        Multiple calls have no effect.
        """
        self._executor.open()

    def gc_safe_close(self) -> None:
        """GC safe close."""
        self._executor.close()

    def close(self) -> None:
        """Close and stop monitoring.

        open() restarts the monitor after closing.
        """
        self.gc_safe_close()

    def join(self) -> None:
        """Wait for the monitor to stop."""
        self._executor.join()

    def request_check(self) -> None:
        """If the monitor is sleeping, wake it soon."""
        self._executor.wake()


class Monitor(MonitorBase):
    def __init__(
        self,
        server_description: ServerDescription,
        topology: Topology,
        pool: Pool,
        topology_settings: TopologySettings,
    ):
        """Class to monitor a MongoDB server on a background thread.

        Pass an initial ServerDescription, a Topology, a Pool, and
        TopologySettings.

        The Topology is weakly referenced. The Pool must be exclusive to this
        Monitor.
        """
        super().__init__(
            topology,
            "pymongo_server_monitor_thread",
            topology_settings.heartbeat_frequency,
            common.MIN_HEARTBEAT_INTERVAL,
        )
        self._server_description = server_description
        self._pool = pool
        self._settings = topology_settings
        self._listeners = self._settings._pool_options._event_listeners
        self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
        self._cancel_context: Optional[_CancellationContext] = None
        self._conn_id: Optional[int] = None
        self._rtt_monitor = _RttMonitor(
            topology,
            topology_settings,
            topology._create_pool_for_monitor(server_description.address),
        )
        if topology_settings.server_monitoring_mode == "stream":
            self._stream = True
        elif topology_settings.server_monitoring_mode == "poll":
            self._stream = False
        else:
            self._stream = not _is_faas()
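
        # The branch above is driven by the `serverMonitoringMode` client
        # option (a sketch; "auto" is the default and falls back to polling
        # in FaaS environments such as AWS Lambda):
        #
        #   client = MongoClient(
        #       "mongodb://localhost:27017", serverMonitoringMode="poll"
        #   )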

    def cancel_check(self) -> None:
        """Cancel any concurrent hello check.

        Note: this is called from a weakref.proxy callback and MUST NOT take
        any locks.
        """
        context = self._cancel_context
        if context:
            # Note: we cannot close the socket because doing so may cause
            # concurrent reads/writes to hang until a timeout occurs
            # (depending on the platform).
            context.cancel()

    def _start_rtt_monitor(self) -> None:
        """Start an _RttMonitor that periodically runs ping."""
        # If this monitor is closed directly before (or during) this open()
        # call, the _RttMonitor will not be closed. Checking if this monitor
        # was closed directly after resolves the race.
        self._rtt_monitor.open()
        if self._executor._stopped:
            self._rtt_monitor.close()

    def gc_safe_close(self) -> None:
        self._executor.close()
        self._rtt_monitor.gc_safe_close()
        self.cancel_check()

    def join(self) -> None:
        asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True)  # type: ignore[func-returns-value]

    def close(self) -> None:
        self.gc_safe_close()
        self._rtt_monitor.close()
        # Increment the generation and maybe close the socket. If the executor
        # thread has the socket checked out, it will be closed when checked in.
        self._reset_connection()

    def _reset_connection(self) -> None:
        # Clear our pooled connection.
        self._pool.reset()

    def _run(self) -> None:
        try:
            prev_sd = self._server_description
            try:
                self._server_description = self._check_server()
            except _OperationCancelled as exc:
                _sanitize(exc)
                # Already closed the connection, wait for the next check.
                self._server_description = ServerDescription(
                    self._server_description.address, error=exc
                )
                if prev_sd.is_server_type_known:
                    # Immediately retry since we've already waited 500ms to
                    # discover that we've been cancelled.
                    self._executor.skip_sleep()
                return

            # Update the Topology and clear the server pool on error.
            self._topology.on_change(
                self._server_description,
                reset_pool=self._server_description.error,
                interrupt_connections=isinstance(self._server_description.error, NetworkTimeout),
            )

            if self._stream and (
                self._server_description.is_server_type_known
                and self._server_description.topology_version
            ):
                self._start_rtt_monitor()
                # Immediately check for the next streaming response.
                self._executor.skip_sleep()

            if self._server_description.error and prev_sd.is_server_type_known:
                # Immediately retry on network errors.
                self._executor.skip_sleep()
        except ReferenceError:
            # Topology was garbage-collected.
            self.close()
        finally:
            if self._executor._stopped:
                self._rtt_monitor.close()

    def _check_server(self) -> ServerDescription:
        """Call hello or read the next streaming response.

        Returns a ServerDescription.
        """
        self._conn_id = None
        start = time.monotonic()
        try:
            return self._check_once()
        except ReferenceError:
            raise
        except Exception as error:
            _sanitize(error)
            sd = self._server_description
            address = sd.address
            duration = _monotonic_duration(start)
            awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version)
            if self._publish:
                assert self._listeners is not None
                self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
            if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _SDAM_LOGGER,
                    message=_SDAMStatusMessage.HEARTBEAT_FAIL,
                    topologyId=self._topology._topology_id,
                    serverHost=address[0],
                    serverPort=address[1],
                    awaited=awaited,
                    durationMS=duration * 1000,
                    failure=error,
                    driverConnectionId=self._conn_id,
                )
            self._reset_connection()
            if isinstance(error, _OperationCancelled):
                raise
            self._rtt_monitor.reset()
            # Server type defaults to Unknown.
            return ServerDescription(address, error=error)

    def _check_once(self) -> ServerDescription:
        """A single attempt to call hello.

        Returns a ServerDescription, or raises an exception.
        """
        address = self._server_description.address
        sd = self._server_description

        # XXX: "awaited" could be incorrectly set to True in the rare case
        # the pool checkout closes and recreates a connection.
        awaited = bool(
            self._pool.conns and self._stream and sd.is_server_type_known and sd.topology_version
        )
        if self._publish:
            assert self._listeners is not None
            self._listeners.publish_server_heartbeat_started(address, awaited)

        if self._cancel_context and self._cancel_context.cancelled:
            self._reset_connection()
        with self._pool.checkout() as conn:
            if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _SDAM_LOGGER,
                    message=_SDAMStatusMessage.HEARTBEAT_START,
                    topologyId=self._topology._topology_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=address[0],
                    serverPort=address[1],
                    awaited=awaited,
                )

            self._cancel_context = conn.cancel_context
            # Record the connection id so we can later attach it to the failed log message.
            self._conn_id = conn.id
            response, round_trip_time = self._check_with_socket(conn)
            if not response.awaitable:
                self._rtt_monitor.add_sample(round_trip_time)

            avg_rtt, min_rtt = self._rtt_monitor.get()
            sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
            if self._publish:
                assert self._listeners is not None
                self._listeners.publish_server_heartbeat_succeeded(
                    address, round_trip_time, response, response.awaitable
                )
            if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _SDAM_LOGGER,
                    message=_SDAMStatusMessage.HEARTBEAT_SUCCESS,
                    topologyId=self._topology._topology_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=address[0],
                    serverPort=address[1],
                    awaited=awaited,
                    durationMS=round_trip_time * 1000,
                    reply=response.document,
                )
            return sd

    def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:  # type: ignore[type-arg]
        """Return (Hello, round_trip_time).

        Can raise ConnectionFailure or OperationFailure.
        """
        start = time.monotonic()
        if conn.more_to_come:
            # Read the next streaming hello (MongoDB 4.4+).
            response = Hello(conn._next_reply(), awaitable=True)
        elif (
            self._stream and conn.performed_handshake and self._server_description.topology_version
        ):
            # Initiate streaming hello (MongoDB 4.4+).
            response = conn._hello(
                self._server_description.topology_version,
                self._settings.heartbeat_frequency,
            )
        else:
            # New connection handshake or polling hello (MongoDB <4.4).
            response = conn._hello(None, None)
        duration = _monotonic_duration(start)
        return response, duration


class SrvMonitor(MonitorBase):
    def __init__(self, topology: Topology, topology_settings: TopologySettings):
        """Class to poll SRV records on a background thread.

        Pass a Topology and a TopologySettings.

        The Topology is weakly referenced.
        """
        super().__init__(
            topology,
            "pymongo_srv_polling_thread",
            common.MIN_SRV_RESCAN_INTERVAL,
            topology_settings.heartbeat_frequency,
        )
        self._settings = topology_settings
        self._seedlist = self._settings._seeds
        assert isinstance(self._settings.fqdn, str)
        self._fqdn: str = self._settings.fqdn
        self._startup_time = time.monotonic()

    def _run(self) -> None:
        # Don't poll right after creation; wait 60 seconds first.
        if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
            return
        seedlist = self._get_seedlist()
        if seedlist:
            self._seedlist = seedlist
            try:
                self._topology.on_srv_update(self._seedlist)
            except ReferenceError:
                # Topology was garbage-collected.
                self.close()

    def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
        """Poll SRV records for a seedlist.

        Returns a list of ServerDescriptions.
        """
        try:
            resolver = _SrvResolver(
                self._fqdn,
                self._settings.pool_options.connect_timeout,
                self._settings.srv_service_name,
            )
            seedlist, ttl = resolver.get_hosts_and_min_ttl()
            if len(seedlist) == 0:
                # As per the spec: this should be treated as a failure.
                raise Exception
        except Exception as exc:
            # As per the spec, upon encountering an error:
            # - An error must not be raised
            # - SRV records must be rescanned every heartbeatFrequencyMS
            # - Topology must be left unchanged
            self.request_check()
            _debug_log(_SDAM_LOGGER, message="SRV monitor check failed", failure=repr(exc))
            return None
        else:
            self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
            return seedlist


class _RttMonitor(MonitorBase):
    def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
        """Maintain round trip times for a server.

        The Topology is weakly referenced.
        """
        super().__init__(
            topology,
            "pymongo_server_rtt_thread",
            topology_settings.heartbeat_frequency,
            common.MIN_HEARTBEAT_INTERVAL,
        )

        self._pool = pool
        self._moving_average = MovingAverage()
        self._moving_min = MovingMinimum()
        self._lock = _create_lock()

    def close(self) -> None:
        self.gc_safe_close()
        # Increment the generation and maybe close the socket. If the executor
        # thread has the socket checked out, it will be closed when checked in.
        self._pool.reset()

    def add_sample(self, sample: float) -> None:
        """Add a RTT sample."""
        with self._lock:
            self._moving_average.add_sample(sample)
            self._moving_min.add_sample(sample)

    def get(self) -> tuple[Optional[float], float]:
        """Get the calculated average (or None if no samples yet) and the min."""
        with self._lock:
            return self._moving_average.get(), self._moving_min.get()

    def reset(self) -> None:
        """Reset the average RTT."""
        with self._lock:
            self._moving_average.reset()
            self._moving_min.reset()

    def _run(self) -> None:
        try:
            # NOTE: This thread is only run when using the streaming
            # heartbeat protocol (MongoDB 4.4+).
            # XXX: Skip check if the server is unknown?
            rtt = self._ping()
            self.add_sample(rtt)
        except ReferenceError:
            # Topology was garbage-collected.
            self.close()
        except Exception:
            self._pool.reset()

    def _ping(self) -> float:
        """Run a "hello" command and return the RTT."""
        with self._pool.checkout() as conn:
            if self._executor._stopped:
                raise Exception("_RttMonitor closed")
            start = time.monotonic()
            conn.hello()
            return _monotonic_duration(start)


# Close monitors to cancel any in progress streaming checks before joining
# executor threads. For an explanation of how this works see the comment
# about _EXECUTORS in periodic_executor.py.
_MONITORS = set()


def _register(monitor: MonitorBase) -> None:
    ref = weakref.ref(monitor, _unregister)
    _MONITORS.add(ref)


def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
    _MONITORS.remove(monitor_ref)


def _shutdown_monitors() -> None:
    if _MONITORS is None:
        return

    # Copy the set. Closing monitors removes them.
    monitors = list(_MONITORS)

    # Close all monitors.
    for ref in monitors:
        monitor = ref()
        if monitor:
            monitor.gc_safe_close()

    monitor = None


def _shutdown_resources() -> None:
    # _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
    shutdown = _shutdown_monitors
    if shutdown:  # type:ignore[truthy-function]
        shutdown()
    shutdown = _shutdown_executors
    if shutdown:  # type:ignore[truthy-function]
        shutdown()


if _IS_SYNC:
    atexit.register(_shutdown_resources)
@@ -0,0 +1,298 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal network layer helper methods."""
from __future__ import annotations

import datetime
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Union,
    cast,
)

from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.compression_support import _NO_COMPRESSION
from pymongo.errors import (
    NotPrimaryError,
    OperationFailure,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
    receive_message,
    sendall,
)

if TYPE_CHECKING:
    from bson import CodecOptions
    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
    from pymongo.monitoring import _EventListeners
    from pymongo.read_concern import ReadConcern
    from pymongo.read_preferences import _ServerMode
    from pymongo.synchronous.client_session import ClientSession
    from pymongo.synchronous.mongo_client import MongoClient
    from pymongo.synchronous.pool import Connection
    from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
    from pymongo.write_concern import WriteConcern

_IS_SYNC = True


def command(
    conn: Connection,
    dbname: str,
    spec: MutableMapping[str, Any],
    is_mongos: bool,
    read_preference: Optional[_ServerMode],
    codec_options: CodecOptions[_DocumentType],
    session: Optional[ClientSession],
    client: Optional[MongoClient[Any]],
    check: bool = True,
    allowable_errors: Optional[Sequence[Union[str, int]]] = None,
    address: Optional[_Address] = None,
    listeners: Optional[_EventListeners] = None,
    max_bson_size: Optional[int] = None,
    read_concern: Optional[ReadConcern] = None,
    parse_write_concern_error: bool = False,
    collation: Optional[_CollationIn] = None,
    compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
    use_op_msg: bool = False,
    unacknowledged: bool = False,
    user_fields: Optional[Mapping[str, Any]] = None,
    exhaust_allowed: bool = False,
    write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
    """Execute a command over the socket, or raise socket.error.

    :param conn: a Connection instance
    :param dbname: name of the database on which to run the command
    :param spec: a command document as an ordered dict type, e.g. SON.
    :param is_mongos: are we connected to a mongos?
    :param read_preference: a read preference
    :param codec_options: a CodecOptions instance
    :param session: optional ClientSession instance.
    :param client: optional MongoClient instance for updating $clusterTime.
    :param check: raise OperationFailure if there are errors
    :param allowable_errors: errors to ignore if `check` is True
    :param address: the (host, port) of `conn`
    :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
    :param max_bson_size: The maximum encoded bson size for this server
    :param read_concern: The read concern for this command.
    :param parse_write_concern_error: Whether to parse the ``writeConcernError``
        field in the command response.
    :param collation: The collation for this command.
    :param compression_ctx: optional compression Context.
    :param use_op_msg: True if we should use OP_MSG.
    :param unacknowledged: True if this is an unacknowledged command.
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
    """
    name = next(iter(spec))
    ns = dbname + ".$cmd"
    speculative_hello = False

    # Publish the original command document, perhaps with lsid and $clusterTime.
    orig = spec
    if is_mongos and not use_op_msg:
        assert read_preference is not None
        spec = message._maybe_add_read_preference(spec, read_preference)
    if read_concern and not (session and session.in_transaction):
        if read_concern.level:
            spec["readConcern"] = read_concern.document
        if session:
            session._update_read_concern(spec, conn)
    if collation is not None:
        spec["collation"] = collation

    publish = listeners is not None and listeners.enabled_for_commands
    start = datetime.datetime.now()
    if publish:
        speculative_hello = _is_speculative_authenticate(name, spec)

    if compression_ctx and name.lower() in _NO_COMPRESSION:
        compression_ctx = None

    if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
        spec = orig = client._encrypter.encrypt(dbname, spec, codec_options)

    # Support CSOT
    if client:
        conn.apply_timeout(client, spec)
    _csot.apply_write_concern(spec, write_concern)

    if use_op_msg:
        flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
        flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
        request_id, msg, size, max_doc_size = message._op_msg(
            flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
        )
        # If this is an unacknowledged write then make sure the encoded doc(s)
        # are small enough, otherwise rely on the server to return an error.
        if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
            message._raise_document_too_large(name, size, max_bson_size)
    else:
        request_id, msg, size = message._query(
            0, ns, 0, -1, spec, None, codec_options, compression_ctx
        )

        if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
            message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.STARTED,
                clientId=client._topology_settings._topology_id,
                command=spec,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_start(
            orig,
            dbname,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
        )

    try:
        sendall(conn.conn.get_conn, msg)
        if use_op_msg and unacknowledged:
            # Unacknowledged, fake a successful command response.
            reply = None
            response_doc: _DocumentOut = {"ok": 1}
        else:
            reply = receive_message(conn, request_id)
            conn.more_to_come = reply.more_to_come
            unpacked_docs = reply.unpack_response(
                codec_options=codec_options, user_fields=user_fields
            )

            response_doc = unpacked_docs[0]
            if not conn.ready:
                cluster_time = response_doc.get("$clusterTime")
                if cluster_time:
                    conn._cluster_time = cluster_time
            if client:
                client._process_response(response_doc, session)
            if check:
                helpers_shared._check_command_response(
                    response_doc,
                    conn.max_wire_version,
                    allowable_errors,
                    parse_write_concern_error=parse_write_concern_error,
                )
    except Exception as exc:
        duration = datetime.datetime.now() - start
        if isinstance(exc, (NotPrimaryError, OperationFailure)):
            failure: _DocumentOut = exc.details  # type: ignore[assignment]
        else:
            failure = message._convert_exception(exc)
        if client is not None:
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.FAILED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(spec)),
                    databaseName=dbname,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
        if publish:
            assert listeners is not None
            assert address is not None
            listeners.publish_command_failure(
                duration,
                failure,
                name,
                request_id,
                address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbname,
            )
        raise
    duration = datetime.datetime.now() - start
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.SUCCEEDED,
                clientId=client._topology_settings._topology_id,
                durationMS=duration,
                reply=response_doc,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
                speculative_authenticate="speculativeAuthenticate" in orig,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_success(
            duration,
            response_doc,
            name,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
            speculative_hello=speculative_hello,
            database_name=dbname,
        )

    if client and client._encrypter and reply:
        decrypted = client._encrypter.decrypt(reply.raw_command_response())
        response_doc = cast(
            "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
        )

    return response_doc  # type: ignore[return-value]
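
# The DEBUG messages emitted above can be enabled from an application with the
# standard logging module (a sketch; "pymongo.command" is assumed to be the
# logger that _COMMAND_LOGGER wraps):
#
#   import logging
#
#   logging.basicConfig()
#   logging.getLogger("pymongo.command").setLevel(logging.DEBUG)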
File diff suppressed because it is too large
@@ -0,0 +1,383 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Communicate with one MongoDB server in a topology."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
from pymongo.helpers_shared import _check_command_response
|
||||
from pymongo.logger import (
|
||||
_COMMAND_LOGGER,
|
||||
_SDAM_LOGGER,
|
||||
_CommandStatusMessage,
|
||||
_debug_log,
|
||||
_SDAMStatusMessage,
|
||||
)
|
||||
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
|
||||
from pymongo.response import PinnedResponse, Response
|
||||
from pymongo.synchronous.helpers import _handle_reauth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from queue import Queue
|
||||
from weakref import ReferenceType
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler
|
||||
from pymongo.synchronous.monitor import Monitor
|
||||
from pymongo.synchronous.pool import Connection, Pool
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(
|
||||
self,
|
||||
server_description: ServerDescription,
|
||||
pool: Pool,
|
||||
monitor: Monitor,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
listeners: Optional[_EventListeners] = None,
|
||||
events: Optional[ReferenceType[Queue[Any]]] = None,
|
||||
) -> None:
|
||||
"""Represent one MongoDB server."""
|
||||
self._description = server_description
|
||||
self._pool = pool
|
||||
self._monitor = monitor
|
||||
self._topology_id = topology_id
|
||||
self._publish = listeners is not None and listeners.enabled_for_server
|
||||
self._listener = listeners
|
||||
self._events = None
|
||||
if self._publish:
|
||||
self._events = events() # type: ignore[misc]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
        if not self._pool.opts.load_balanced:
            self._monitor.open()

    def reset(self, service_id: Optional[ObjectId] = None) -> None:
        """Clear the connection pool."""
        self.pool.reset(service_id)

    def close(self) -> None:
        """Clear the connection pool and stop the monitor.

        Reconnect with open().
        """
        if self._publish:
            assert self._listener is not None
            assert self._events is not None
            self._events.put(
                (
                    self._listener.publish_server_closed,
                    (self._description.address, self._topology_id),
                )
            )
        if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _SDAM_LOGGER,
                message=_SDAMStatusMessage.STOP_SERVER,
                topologyId=self._topology_id,
                serverHost=self._description.address[0],
                serverPort=self._description.address[1],
            )

        self._monitor.close()
        self._pool.close()

    def request_check(self) -> None:
        """Check the server's state soon."""
        self._monitor.request_check()

    def operation_to_command(
        self, operation: Union[_Query, _GetMore], conn: Connection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        cmd, db = operation.as_command(conn, apply_timeout)
        # Support auto encryption: when client-side field level encryption is
        # enabled and not bypassed, encrypt the outgoing command before it is
        # sent.
        if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
            cmd = operation.client._encrypter.encrypt(  # type: ignore[misc, assignment]
                operation.db, cmd, operation.codec_options
            )
            operation.update_command(cmd)

        return cmd, db

    @_handle_reauth
    def run_operation(
        self,
        conn: Connection,
        operation: Union[_Query, _GetMore],
        read_preference: _ServerMode,
        listeners: Optional[_EventListeners],
        unpack_res: Callable[..., list[_DocumentOut]],
        client: MongoClient[Any],
    ) -> Response:
        """Run a _Query or _GetMore operation and return a Response object.

        This method is used only to run _Query/_GetMore operations from
        cursors.
        Can raise ConnectionFailure, OperationFailure, etc.

        :param conn: A Connection instance.
        :param operation: A _Query or _GetMore object.
        :param read_preference: The read preference to use.
        :param listeners: Instance of _EventListeners or None.
        :param unpack_res: A callable that decodes the wire protocol response.
        :param client: A MongoClient instance.
        """
        assert listeners is not None
        publish = listeners.enabled_for_commands
        start = datetime.now()

        use_cmd = operation.use_command(conn)
        more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
        cmd, dbn = self.operation_to_command(operation, conn, use_cmd)
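        # An exhaust cursor's server keeps streaming replies without further
        # requests, so there is no new request_id; 0 is used as a placeholder.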
        if more_to_come:
            request_id = 0
        else:
            message = operation.get_message(read_preference, conn, use_cmd)
            request_id, data, max_doc_size = self._split_message(message)

        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.STARTED,
                clientId=client._topology_settings._topology_id,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=dbn,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )

        if publish:
            if "$db" not in cmd:
                cmd["$db"] = dbn
            assert listeners is not None
            listeners.publish_command_start(
                cmd,
                dbn,
                request_id,
                conn.address,
                conn.server_connection_id,
                service_id=conn.service_id,
            )

        try:
            if more_to_come:
                reply = conn.receive_message(None)
            else:
                conn.send_message(data, max_doc_size)
                reply = conn.receive_message(request_id)

            # Unpack and check for command errors.
            if use_cmd:
                user_fields = _CURSOR_DOC_FIELDS
                legacy_response = False
            else:
                user_fields = None
                legacy_response = True
            docs = unpack_res(
                reply,
                operation.cursor_id,
                operation.codec_options,
                legacy_response=legacy_response,
                user_fields=user_fields,
            )
            if use_cmd:
                first = docs[0]
                operation.client._process_response(first, operation.session)  # type: ignore[misc, arg-type]
                _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts)  # type:ignore[has-type]
        except Exception as exc:
            duration = datetime.now() - start
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    message=_CommandStatusMessage.FAILED,
                    clientId=client._topology_settings._topology_id,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=dbn,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if publish:
                assert listeners is not None
                listeners.publish_command_failure(
                    duration,
                    failure,
                    operation.name,
                    request_id,
                    conn.address,
                    conn.server_connection_id,
                    service_id=conn.service_id,
                    database_name=dbn,
                )
            raise
        duration = datetime.now() - start
        # Must publish in find / getMore / explain command response
        # format.
        if use_cmd:
            res = docs[0]
        elif operation.name == "explain":
            res = docs[0] if docs else {}
        else:
            res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}  # type: ignore[union-attr]
            if operation.name == "find":
                res["cursor"]["firstBatch"] = docs
            else:
                res["cursor"]["nextBatch"] = docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                message=_CommandStatusMessage.SUCCEEDED,
                clientId=client._topology_settings._topology_id,
                durationMS=duration,
                reply=res,
                commandName=next(iter(cmd)),
                databaseName=dbn,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
        if publish:
            assert listeners is not None
            listeners.publish_command_success(
                duration,
                res,
                operation.name,
                request_id,
                conn.address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbn,
            )

        # Decrypt response.
        client = operation.client  # type: ignore[assignment]
        if client and client._encrypter:
            if use_cmd:
                decrypted = client._encrypter.decrypt(reply.raw_command_response())
                docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)

        response: Response
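
        # The connection is pinned so that subsequent getMores for this cursor
        # reuse it, e.g. for exhaust cursors or in load-balanced mode.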
        if client._should_pin_cursor(operation.session) or operation.exhaust:  # type: ignore[arg-type]
            conn.pin_cursor()
            if isinstance(reply, _OpMsg):
                # In OP_MSG, the server keeps sending only if the
                # more_to_come flag is set.
                more_to_come = reply.more_to_come
            else:
                # In OP_REPLY, the server keeps sending until cursor_id is 0.
                more_to_come = bool(operation.exhaust and reply.cursor_id)
            if operation.conn_mgr:
                operation.conn_mgr.update_exhaust(more_to_come)
            response = PinnedResponse(
                data=reply,
                address=self._description.address,
                conn=conn,
                duration=duration,
                request_id=request_id,
                from_command=use_cmd,
                docs=docs,
                more_to_come=more_to_come,
            )
        else:
            response = Response(
                data=reply,
                address=self._description.address,
                duration=duration,
                request_id=request_id,
                from_command=use_cmd,
                docs=docs,
            )

        return response

    def checkout(
        self, handler: Optional[_MongoClientErrorHandler] = None
    ) -> ContextManager[Connection]:
        return self.pool.checkout(handler)

    @property
    def description(self) -> ServerDescription:
        return self._description

    @description.setter
    def description(self, server_description: ServerDescription) -> None:
        assert server_description.address == self._description.address
        self._description = server_description

    @property
    def pool(self) -> Pool:
        return self._pool

    def _split_message(
        self, message: Union[tuple[int, Any], tuple[int, Any, int]]
    ) -> tuple[int, Any, int]:
        """Return request_id, data, max_doc_size.

        :param message: (request_id, data, max_doc_size) or (request_id, data)
        """
        if len(message) == 3:
            return message  # type: ignore[return-value]
        else:
            # get_more and kill_cursors messages don't include BSON documents.
            request_id, data = message  # type: ignore[misc]
            return request_id, data, 0

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._description!r}>"
@@ -0,0 +1,175 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Represent MongoClient's configuration."""
from __future__ import annotations

import threading
import traceback
from typing import Any, Collection, Optional, Type, Union

from bson.objectid import ObjectId
from pymongo import common
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.pool_options import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.synchronous import monitor, pool
from pymongo.synchronous.pool import Pool
from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector

_IS_SYNC = True


class TopologySettings:
    def __init__(
        self,
        seeds: Optional[Collection[tuple[str, int]]] = None,
        replica_set_name: Optional[str] = None,
        pool_class: Optional[Type[Pool]] = None,
        pool_options: Optional[PoolOptions] = None,
        monitor_class: Optional[Type[monitor.Monitor]] = None,
        condition_class: Optional[Type[threading.Condition]] = None,
        local_threshold_ms: int = LOCAL_THRESHOLD_MS,
        server_selection_timeout: int = SERVER_SELECTION_TIMEOUT,
        heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY,
        server_selector: Optional[_ServerSelector] = None,
        fqdn: Optional[str] = None,
        direct_connection: Optional[bool] = False,
        load_balanced: Optional[bool] = None,
        srv_service_name: str = common.SRV_SERVICE_NAME,
        srv_max_hosts: int = 0,
        server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
        topology_id: Optional[ObjectId] = None,
    ):
        """Represent MongoClient's configuration.

        Take a list of (host, port) pairs and optional replica set name.
        """
        if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
            raise ConfigurationError(
                "heartbeatFrequencyMS cannot be less than %d"
                % (common.MIN_HEARTBEAT_INTERVAL * 1000,)
            )

        self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)]
        self._replica_set_name = replica_set_name
        self._pool_class: Type[Pool] = pool_class or pool.Pool
        self._pool_options: PoolOptions = pool_options or PoolOptions()
        self._monitor_class: Type[monitor.Monitor] = monitor_class or monitor.Monitor
        self._condition_class: Type[threading.Condition] = condition_class or threading.Condition
        self._local_threshold_ms = local_threshold_ms
        self._server_selection_timeout = server_selection_timeout
        self._server_selector = server_selector
        self._fqdn = fqdn
        self._heartbeat_frequency = heartbeat_frequency
        self._direct = direct_connection
        self._load_balanced = load_balanced
        self._srv_service_name = srv_service_name
        self._srv_max_hosts = srv_max_hosts or 0
        self._server_monitoring_mode = server_monitoring_mode
        if topology_id is not None:
            self._topology_id = topology_id
        else:
            self._topology_id = ObjectId()
        # Store the allocation traceback to catch unclosed clients in the
        # test suite.
        self._stack = "".join(traceback.format_stack()[:-2])

    @property
    def seeds(self) -> Collection[tuple[str, int]]:
        """List of server addresses."""
        return self._seeds

    @property
    def replica_set_name(self) -> Optional[str]:
        return self._replica_set_name

    @property
    def pool_class(self) -> Type[Pool]:
        return self._pool_class

    @property
    def pool_options(self) -> PoolOptions:
        return self._pool_options

    @property
    def monitor_class(self) -> Type[monitor.Monitor]:
        return self._monitor_class

    @property
    def condition_class(self) -> Type[threading.Condition]:
        return self._condition_class

    @property
    def local_threshold_ms(self) -> int:
        return self._local_threshold_ms

    @property
    def server_selection_timeout(self) -> int:
        return self._server_selection_timeout

    @property
    def server_selector(self) -> Optional[_ServerSelector]:
        return self._server_selector

    @property
    def heartbeat_frequency(self) -> int:
        return self._heartbeat_frequency

    @property
    def fqdn(self) -> Optional[str]:
        return self._fqdn

    @property
    def direct(self) -> Optional[bool]:
        """Connect directly to a single server, or use a set of servers?

        True if there is one seed and no replica_set_name.
        """
        return self._direct

    @property
    def load_balanced(self) -> Optional[bool]:
        """True if the client was configured to connect to a load balancer."""
        return self._load_balanced

    @property
    def srv_service_name(self) -> str:
        """The srvServiceName."""
        return self._srv_service_name

    @property
    def srv_max_hosts(self) -> int:
        """The srvMaxHosts."""
        return self._srv_max_hosts

    @property
    def server_monitoring_mode(self) -> str:
        """The serverMonitoringMode."""
        return self._server_monitoring_mode

    def get_topology_type(self) -> int:
        if self.load_balanced:
            return TOPOLOGY_TYPE.LoadBalanced
        elif self.direct:
            return TOPOLOGY_TYPE.Single
        elif self.replica_set_name is not None:
            return TOPOLOGY_TYPE.ReplicaSetNoPrimary
        else:
            return TOPOLOGY_TYPE.Unknown
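
    # For example, a client created with replicaSet="rs0" (and no
    # directConnection) starts in ReplicaSetNoPrimary; the first hello
    # responses from its monitors then refine the topology type.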

    def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
        """Initial dict of (address, ServerDescription) for all seeds."""
        return {address: ServerDescription(address) for address in self.seeds}
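
    # Every seed starts out as an Unknown ServerDescription; monitor checks
    # later replace these with descriptions of the real servers.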
@@ -0,0 +1,155 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations

import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError

if TYPE_CHECKING:
    from dns import resolver

_IS_SYNC = True


def _have_dnspython() -> bool:
    try:
        import dns  # noqa: F401

        return True
    except ImportError:
        return False


# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
    if isinstance(text, bytes):
        return text.decode()
    return text


# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
    if _IS_SYNC:
        from dns import resolver

        return resolver.resolve(*args, **kwargs)
    else:
        from dns import asyncresolver

        return asyncresolver.resolve(*args, **kwargs)  # type:ignore[return-value]


_INVALID_HOST_MSG = (
    "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
    "Did you mean to use 'mongodb://'?"
)


class _SrvResolver:
    def __init__(
        self,
        fqdn: str,
        connect_timeout: Optional[float],
        srv_service_name: str,
        srv_max_hosts: int = 0,
    ):
        self.__fqdn = fqdn
        self.__srv = srv_service_name
        self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
        self.__srv_max_hosts = srv_max_hosts or 0
        # Validate the fully qualified domain name.
        try:
            ipaddress.ip_address(fqdn)
            raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
        except ValueError:
            pass
        try:
            split_fqdn = self.__fqdn.split(".")
            self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
        except Exception:
            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
        self.__slen = len(self.__plist)
        self.nparts = len(split_fqdn)
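        # e.g. for an illustrative fqdn "cluster0.example.com", the parent
        # domain ["example", "com"] is kept to validate returned SRV targets.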

    def get_options(self) -> Optional[str]:
        from dns import resolver

        try:
            results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
        except (resolver.NoAnswer, resolver.NXDOMAIN):
            # No TXT records
            return None
        except Exception as exc:
            raise ConfigurationError(str(exc)) from exc
        if len(results) > 1:
            raise ConfigurationError("Only one TXT record is supported")
        return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")  # type: ignore[attr-defined]

    def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
        try:
            results = _resolve(
                "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
            )
        except Exception as exc:
            if not encapsulate_errors:
                # Raise the original error.
                raise
            # Else, raise all errors as ConfigurationError.
            raise ConfigurationError(str(exc)) from exc
        return results

    def _get_srv_response_and_hosts(
        self, encapsulate_errors: bool
    ) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
        results = self._resolve_uri(encapsulate_errors)

        # Construct address tuples
        nodes = [
            (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)  # type: ignore[attr-defined]
            for res in results
        ]

        # Validate hosts
        for node in nodes:
            srv_host = node[0].lower()
            if self.__fqdn == srv_host and self.nparts < 3:
                raise ConfigurationError(
                    "Invalid SRV host: return address is identical to SRV hostname"
                )
            try:
                nlist = srv_host.split(".")[1:][-self.__slen :]
            except Exception as exc:
                raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
            if self.__plist != nlist:
                raise ConfigurationError(f"Invalid SRV host: {node[0]}")
        if self.__srv_max_hosts:
            nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
        return results, nodes

    def get_hosts(self) -> list[tuple[str, Any]]:
        _, nodes = self._get_srv_response_and_hosts(True)
        return nodes

    def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
        results, nodes = self._get_srv_response_and_hosts(False)
        rrset = results.rrset
        ttl = rrset.ttl if rrset else 0
        return nodes, ttl
File diff suppressed because it is too large
@@ -0,0 +1,193 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


"""Tools to parse and validate a MongoDB URI."""
from __future__ import annotations

from typing import Any, Optional
from urllib.parse import unquote_plus

from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.synchronous.srv_resolver import _SrvResolver
from pymongo.uri_parser_shared import (
    _ALLOWED_TXT_OPTS,
    DEFAULT_PORT,
    SCHEME,
    SCHEME_LEN,
    SRV_SCHEME_LEN,
    _check_options,
    _make_options_case_sensitive,
    _validate_uri,
    split_hosts,
    split_options,
)

_IS_SYNC = True


def parse_uri(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Parse and validate a MongoDB URI.

    Returns a dict of the form::

        {
            'nodelist': <list of (host, port) tuples>,
            'username': <username> or None,
            'password': <password> or None,
            'database': <database name> or None,
            'collection': <collection name> or None,
            'options': <dict of MongoDB URI options>,
            'fqdn': <fqdn of the MongoDB+SRV URI> or None
        }

    If the URI scheme is "mongodb+srv://", DNS SRV and TXT lookups will be
    done to build nodelist and options.

    :param uri: The MongoDB URI to parse.
    :param default_port: The port number to use when one wasn't specified
        for a host in the URI.
    :param validate: If ``True`` (the default), validate and
        normalize all options. Default: ``True``.
    :param warn: When validating, if ``True`` then will warn
        the user and ignore any invalid options or values. If ``False``,
        validation will error when options are unsupported or values are
        invalid. Default: ``False``.
    :param normalize: If ``True``, convert names of URI options
        to their internally-used names. Default: ``True``.
    :param connect_timeout: The maximum time in milliseconds to
        wait for a response from the DNS server.
    :param srv_service_name: A custom SRV service name.

    .. versionchanged:: 4.14
       ``options`` is now type ``dict`` as opposed to a ``_CaseInsensitiveDictionary``.

    .. versionchanged:: 4.6
       The delimiting slash (``/``) between hosts and connection options is now optional.
       For example, "mongodb://example.com?tls=true" is now a valid URI.

    .. versionchanged:: 4.0
       To better follow RFC 3986, unquoted percent signs ("%") are no longer
       supported.

    .. versionchanged:: 3.9
       Added the ``normalize`` parameter.

    .. versionchanged:: 3.6
       Added support for mongodb+srv:// URIs.

    .. versionchanged:: 3.5
       Return the original value of the ``readPreference`` MongoDB URI option
       instead of the validated read preference mode.

    .. versionchanged:: 3.1
       ``warn`` added so invalid options can be ignored.
    """
    result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
    result.update(
        _parse_srv(
            uri,
            default_port,
            validate,
            warn,
            normalize,
            connect_timeout,
            srv_service_name,
            srv_max_hosts,
        )
    )
    result["options"] = _make_options_case_sensitive(result["options"])
    return result
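
# Illustrative usage (placeholder host; SRV URIs additionally require
# dnspython):
#
#   parse_uri("mongodb://user:pass@example.com:27017/db?replicaSet=rs0")
#   # -> {'nodelist': [('example.com', 27017)], 'username': 'user', ...}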


def _parse_srv(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    if uri.startswith(SCHEME):
        is_srv = False
        scheme_free = uri[SCHEME_LEN:]
    else:
        is_srv = True
        scheme_free = uri[SRV_SCHEME_LEN:]

    options = _CaseInsensitiveDictionary()

    host_plus_db_part, _, opts = scheme_free.partition("?")
    if "/" in host_plus_db_part:
        host_part, _, _ = host_plus_db_part.partition("/")
    else:
        host_part = host_plus_db_part

    if opts:
        options.update(split_options(opts, validate, warn, normalize))
    if srv_service_name is None:
        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
    if "@" in host_part:
        _, _, hosts = host_part.rpartition("@")
    else:
        hosts = host_part

    hosts = unquote_plus(hosts)
    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
    if is_srv:
        nodes = split_hosts(hosts, default_port=None)
        fqdn, port = nodes[0]

        # Use the connection timeout. connectTimeoutMS passed as a keyword
        # argument overrides the same option passed in the connection string.
        connect_timeout = connect_timeout or options.get("connectTimeoutMS")
        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
        nodes = dns_resolver.get_hosts()
        dns_options = dns_resolver.get_options()
        if dns_options:
            parsed_dns_options = split_options(dns_options, validate, warn, normalize)
            if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
                raise ConfigurationError(
                    "Only authSource, replicaSet, and loadBalanced are supported from DNS"
                )
            for opt, val in parsed_dns_options.items():
                if opt not in options:
                    options[opt] = val
        if options.get("loadBalanced") and srv_max_hosts:
            raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
        if options.get("replicaSet") and srv_max_hosts:
            raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
        if "tls" not in options and "ssl" not in options:
            options["tls"] = True if validate else "true"
    else:
        nodes = split_hosts(hosts, default_port=default_port)

    _check_options(nodes, options)

    return {
        "nodelist": nodes,
        "options": options,
    }