@mytec: pushed back before 1.1
This commit is contained in:
@@ -0,0 +1,786 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal network layer helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import errno
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import _csot, ssl_support
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import decompress
|
||||
from pymongo.errors import ProtocolError, _OperationCancelled
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
try:
|
||||
from ssl import SSLError, SSLSocket
|
||||
|
||||
_HAVE_SSL = True
|
||||
except ImportError:
|
||||
_HAVE_SSL = False
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_HAVE_PYOPENSSL = False
|
||||
_sslConn = SSLSocket # type: ignore[assignment, misc]
|
||||
|
||||
from pymongo.ssl_support import (
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
|
||||
# the MongoDB wire protocol
|
||||
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
timeout = sock.gettimeout()
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
|
||||
else:
|
||||
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
|
||||
except asyncio.TimeoutError as exc:
|
||||
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
|
||||
raise socket.timeout("timed out") from exc
|
||||
finally:
|
||||
sock.settimeout(timeout)
|
||||
|
||||
|
||||
if sys.platform != "win32":
|
||||
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
sent = 0
|
||||
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while sent < len(buf):
|
||||
try:
|
||||
sent += sock.send(view[sent:])
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = sock.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
total_read += read
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = conn.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
return mv
|
||||
|
||||
else:
|
||||
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
|
||||
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
|
||||
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
total_length = len(buf)
|
||||
total_sent = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_sent < total_length:
|
||||
try:
|
||||
sent = sock.send(view[total_sent:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(backoff)
|
||||
sent = 0
|
||||
if sent > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_sent += sent
|
||||
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(backoff)
|
||||
read = 0
|
||||
if read > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_read += read
|
||||
return mv
|
||||
|
||||
|
||||
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
sock.sendall(buf)
|
||||
|
||||
|
||||
async def _poll_cancellation(conn: AsyncConnection) -> None:
|
||||
while True:
|
||||
if conn.cancel_context.cancelled:
|
||||
return
|
||||
|
||||
await asyncio.sleep(_POLL_TIMEOUT)
|
||||
|
||||
|
||||
async def async_receive_data_socket(
|
||||
sock: Union[socket.socket, _sslConn], length: int
|
||||
) -> memoryview:
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
return await asyncio.wait_for(
|
||||
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(
|
||||
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise socket.timeout("timed out") from err
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_socket_receive(
|
||||
conn: socket.socket, length: int, loop: AbstractEventLoop
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
bytes_read += chunk_length
|
||||
return mv
|
||||
|
||||
|
||||
_PYPY = "PyPy" in sys.version
|
||||
_WINDOWS = sys.platform == "win32"
|
||||
|
||||
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn.sock
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
if hasattr(sock, "pending") and sock.pending() > 0:
|
||||
readable = True
|
||||
else:
|
||||
# Wait up to 500ms for the socket to become readable and then
|
||||
# check for cancellation.
|
||||
if deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
if remaining <= 0:
|
||||
timed_out = True
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
if readable:
|
||||
return
|
||||
if timed_out:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
|
||||
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
# To support cancelling a network read, we shorten the socket timeout and
|
||||
# check for the cancellation signal after each timeout. Alternatively we
|
||||
# could close the socket but that does not reliably cancel recv() calls
|
||||
# on all OSes.
|
||||
# When the timeout has expired we perform one final non-blocking recv.
|
||||
# This helps avoid spurious timeouts when the response is actually already
|
||||
# buffered on the client.
|
||||
orig_timeout = conn.conn.gettimeout()
|
||||
try:
|
||||
while bytes_read < length:
|
||||
try:
|
||||
# Use the legacy wait_for_read cancellation approach on PyPy due to PYTHON-5011.
|
||||
# also use it on Windows due to PYTHON-5405
|
||||
if _PYPY or _WINDOWS:
|
||||
wait_for_read(conn, deadline)
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
else:
|
||||
if deadline is not None:
|
||||
short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
|
||||
else:
|
||||
short_timeout = _POLL_TIMEOUT
|
||||
conn.set_conn_timeout(short_timeout)
|
||||
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
# We reached the true deadline.
|
||||
raise socket.timeout("timed out") from None
|
||||
except socket.timeout:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
if (
|
||||
_PYPY
|
||||
or _WINDOWS
|
||||
or not conn.is_sdam
|
||||
and deadline is not None
|
||||
and deadline - time.monotonic() < 0
|
||||
):
|
||||
# We reached the true deadline.
|
||||
raise
|
||||
continue
|
||||
except OSError as exc:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
|
||||
bytes_read += chunk_length
|
||||
finally:
|
||||
conn.set_conn_timeout(orig_timeout)
|
||||
|
||||
return mv
|
||||
|
||||
|
||||
class NetworkingInterfaceBase:
|
||||
def __init__(self, conn: Any):
|
||||
self.conn = conn
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def get_conn(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def sock(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncNetworkingInterface(NetworkingInterfaceBase):
|
||||
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
|
||||
super().__init__(conn)
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.conn[1].gettimeout
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self.conn[1].settimeout(timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.conn[1].close()
|
||||
await self.conn[1].wait_closed()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self.conn[0].is_closing()
|
||||
|
||||
@property
|
||||
def get_conn(self) -> PyMongoProtocol:
|
||||
return self.conn[1]
|
||||
|
||||
@property
|
||||
def sock(self) -> socket.socket:
|
||||
return self.conn[0].get_extra_info("socket")
|
||||
|
||||
|
||||
class NetworkingInterface(NetworkingInterfaceBase):
|
||||
def __init__(self, conn: Union[socket.socket, _sslConn]):
|
||||
super().__init__(conn)
|
||||
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.conn.gettimeout()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self.conn.settimeout(timeout)
|
||||
|
||||
def close(self) -> None:
|
||||
self.conn.close()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self.conn.is_closing()
|
||||
|
||||
@property
|
||||
def get_conn(self) -> Union[socket.socket, _sslConn]:
|
||||
return self.conn
|
||||
|
||||
@property
|
||||
def sock(self) -> Union[socket.socket, _sslConn]:
|
||||
return self.conn
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.conn.fileno()
|
||||
|
||||
def recv_into(self, buffer: bytes | memoryview) -> int:
|
||||
return self.conn.recv_into(buffer)
|
||||
|
||||
|
||||
class PyMongoProtocol(BufferedProtocol):
|
||||
def __init__(self, timeout: Optional[float] = None):
|
||||
self.transport: Transport = None # type: ignore[assignment]
|
||||
# Each message is reader in 2-3 parts: header, compression header, and message body
|
||||
# The message buffer is allocated after the header is read.
|
||||
self._header = memoryview(bytearray(16))
|
||||
self._header_index = 0
|
||||
self._compression_header = memoryview(bytearray(9))
|
||||
self._compression_index = 0
|
||||
self._message: Optional[memoryview] = None
|
||||
self._message_index = 0
|
||||
# State. TODO: replace booleans with an enum?
|
||||
self._expecting_header = True
|
||||
self._expecting_compression = False
|
||||
self._message_size = 0
|
||||
self._op_code = 0
|
||||
self._connection_lost = False
|
||||
self._read_waiter: Optional[Future[Any]] = None
|
||||
self._timeout = timeout
|
||||
self._is_compressed = False
|
||||
self._compressor_id: Optional[int] = None
|
||||
self._max_message_size = MAX_MESSAGE_SIZE
|
||||
self._response_to: Optional[int] = None
|
||||
self._closed = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
self._done_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self._timeout = timeout
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
"""The configured timeout for the socket that underlies our protocol pair."""
|
||||
return self._timeout
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called exactly once when a connection is made.
|
||||
The transport argument is the transport representing the write side of the connection.
|
||||
"""
|
||||
self.transport = transport # type: ignore[assignment]
|
||||
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
|
||||
|
||||
async def write(self, message: bytes) -> None:
|
||||
"""Write a message to this connection's transport."""
|
||||
if self.transport.is_closing():
|
||||
raise OSError("Connection is closed")
|
||||
self.transport.write(message)
|
||||
self.transport.resume_reading()
|
||||
|
||||
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
|
||||
"""Read a single MongoDB Wire Protocol message from this connection."""
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.resume_reading()
|
||||
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
|
||||
except AttributeError:
|
||||
raise OSError("connection is already closed") from None
|
||||
self._max_message_size = max_message_size
|
||||
if self._done_messages:
|
||||
message = await self._done_messages.popleft()
|
||||
else:
|
||||
if self.transport and self.transport.is_closing():
|
||||
raise OSError("connection is already closed")
|
||||
read_waiter = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages.append(read_waiter)
|
||||
try:
|
||||
message = await read_waiter
|
||||
finally:
|
||||
if read_waiter in self._done_messages:
|
||||
self._done_messages.remove(read_waiter)
|
||||
if message:
|
||||
op_code, compressor_id, response_to, data = message
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(
|
||||
f"Got response id {response_to!r} but expected {request_id!r}"
|
||||
)
|
||||
if compressor_id is not None:
|
||||
data = decompress(data, compressor_id)
|
||||
return data, op_code
|
||||
raise OSError("connection closed")
|
||||
|
||||
def get_buffer(self, sizehint: int) -> memoryview:
|
||||
"""Called to allocate a new receive buffer.
|
||||
The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
|
||||
If any data does not fit into the returned buffer, this method will be called again until
|
||||
either no data remains or an empty buffer is returned.
|
||||
"""
|
||||
# Due to a bug, Python <=3.11 will call get_buffer() even after we raise
|
||||
# ProtocolError in buffer_updated() and call connection_lost(). We allocate
|
||||
# a temp buffer to drain the waiting data.
|
||||
if self._connection_lost:
|
||||
if not self._message:
|
||||
self._message = memoryview(bytearray(2**14))
|
||||
return self._message
|
||||
# TODO: optimize this by caching pointers to the buffers.
|
||||
# return self._buffer[self._index:]
|
||||
if self._expecting_header:
|
||||
return self._header[self._header_index :]
|
||||
if self._expecting_compression:
|
||||
return self._compression_header[self._compression_index :]
|
||||
return self._message[self._message_index :] # type: ignore[index]
|
||||
|
||||
def buffer_updated(self, nbytes: int) -> None:
|
||||
"""Called when the buffer was updated with the received data"""
|
||||
# Wrote 0 bytes into a non-empty buffer, signal connection closed
|
||||
if nbytes == 0:
|
||||
self.close(OSError("connection closed"))
|
||||
return
|
||||
if self._connection_lost:
|
||||
return
|
||||
if self._expecting_header:
|
||||
self._header_index += nbytes
|
||||
if self._header_index >= 16:
|
||||
self._expecting_header = False
|
||||
try:
|
||||
(
|
||||
self._message_size,
|
||||
self._op_code,
|
||||
self._response_to,
|
||||
self._expecting_compression,
|
||||
) = self.process_header()
|
||||
except ProtocolError as exc:
|
||||
self.close(exc)
|
||||
return
|
||||
self._message = memoryview(bytearray(self._message_size))
|
||||
return
|
||||
if self._expecting_compression:
|
||||
self._compression_index += nbytes
|
||||
if self._compression_index >= 9:
|
||||
self._expecting_compression = False
|
||||
self._op_code, self._compressor_id = self.process_compression_header()
|
||||
return
|
||||
|
||||
self._message_index += nbytes
|
||||
if self._message_index >= self._message_size:
|
||||
self._expecting_header = True
|
||||
# Pause reading to avoid storing an arbitrary number of messages in memory.
|
||||
self.transport.pause_reading()
|
||||
if self._pending_messages:
|
||||
result = self._pending_messages.popleft()
|
||||
else:
|
||||
result = asyncio.get_running_loop().create_future()
|
||||
# Future has been cancelled, close this connection
|
||||
if result.done():
|
||||
self.close(None)
|
||||
return
|
||||
# Necessary values to reconstruct and verify message
|
||||
result.set_result(
|
||||
(self._op_code, self._compressor_id, self._response_to, self._message)
|
||||
)
|
||||
self._done_messages.append(result)
|
||||
# Reset internal state to expect a new message
|
||||
self._header_index = 0
|
||||
self._compression_index = 0
|
||||
self._message_index = 0
|
||||
self._message_size = 0
|
||||
self._message = None
|
||||
self._op_code = 0
|
||||
self._compressor_id = None
|
||||
self._response_to = None
|
||||
|
||||
def process_header(self) -> tuple[int, int, int, bool]:
|
||||
"""Unpack a MongoDB Wire Protocol header."""
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
|
||||
expecting_compression = False
|
||||
if op_code == 2012: # OP_COMPRESSED
|
||||
if length <= 25:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
|
||||
)
|
||||
expecting_compression = True
|
||||
length -= 9
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > self._max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({self._max_message_size!r})"
|
||||
)
|
||||
|
||||
return length - 16, op_code, response_to, expecting_compression
|
||||
|
||||
def process_compression_header(self) -> tuple[int, int]:
|
||||
"""Unpack a MongoDB Wire Protocol compression header."""
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
|
||||
return op_code, compressor_id
|
||||
|
||||
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
|
||||
pending = list(self._pending_messages)
|
||||
for msg in pending:
|
||||
if not msg.done():
|
||||
if exc is None:
|
||||
msg.set_result(None)
|
||||
else:
|
||||
msg.set_exception(exc)
|
||||
self._done_messages.append(msg)
|
||||
|
||||
def close(self, exc: Optional[Exception] = None) -> None:
|
||||
self.transport.abort()
|
||||
self._resolve_pending_messages(exc)
|
||||
self._connection_lost = True
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||
self._resolve_pending_messages(exc)
|
||||
if not self._closed.done():
|
||||
self._closed.set_result(None)
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed
|
||||
|
||||
|
||||
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
|
||||
except asyncio.TimeoutError as exc:
|
||||
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
|
||||
raise socket.timeout("timed out") from exc
|
||||
|
||||
|
||||
async def async_receive_message(
|
||||
conn: AsyncConnection,
|
||||
request_id: Optional[int],
|
||||
max_message_size: int = MAX_MESSAGE_SIZE,
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
timeout: Optional[Union[float, int]]
|
||||
timeout = conn.conn.gettimeout
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
if deadline:
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
timeout = max(deadline - time.monotonic(), 0)
|
||||
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
|
||||
tasks = [read_task, cancellation_task]
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
data, op_code = read_task.result()
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.wait(tasks)
|
||||
raise
|
||||
|
||||
|
||||
def receive_message(
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = conn.conn.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
data: memoryview | bytes
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
|
||||
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
Reference in New Issue
Block a user