@mytec: iter3.4.0 start

This commit is contained in:
2026-02-02 21:30:00 +02:00
parent 7f0b4d2269
commit 867ee3d0f4
29 changed files with 1386 additions and 324 deletions

View File

@@ -8,7 +8,6 @@ progress updates during computation phases.
import time
import asyncio
import logging
import threading
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
@@ -51,7 +50,7 @@ class ConnectionManager:
"data": result,
})
except Exception as e:
logger.debug(f"[WS] send_result failed: {e}")
logger.warning(f"[WS] send_result failed: {e}")
async def send_error(self, ws: WebSocket, calc_id: str, error: str):
try:
@@ -61,7 +60,7 @@ class ConnectionManager:
"message": error,
})
except Exception as e:
logger.debug(f"[WS] send_error failed: {e}")
logger.warning(f"[WS] send_error failed: {e}")
ws_manager = ConnectionManager()
@@ -74,14 +73,32 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
# Shared progress state — written by worker threads, polled by event loop.
# Python GIL makes dict value assignment atomic for simple types.
_progress = {"phase": "Initializing", "pct": 0.05, "seq": 0}
_progress = {"phase": "Initializing", "pct": 0.0, "seq": 0}
_done = False
# Get event loop for cross-thread scheduling of WS sends.
loop = asyncio.get_running_loop()
_last_direct_pct = 0.0
_last_direct_phase = ""
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
"""Thread-safe progress callback — just updates a shared dict."""
"""Thread-safe progress callback — updates dict AND schedules direct WS send."""
nonlocal _last_direct_pct, _last_direct_phase
_progress["phase"] = phase
_progress["pct"] = pct
_progress["seq"] += 1
# Schedule direct WS send via event loop (works from any thread).
# Throttle: only send on phase change or >=2% progress.
if phase != _last_direct_phase or pct - _last_direct_pct >= 0.02:
_last_direct_pct = pct
_last_direct_phase = phase
try:
loop.call_soon_threadsafe(
asyncio.ensure_future,
ws_manager.send_progress(ws, calc_id, phase, pct),
)
except RuntimeError:
pass # Event loop closed
try:
sites_data = data.get("sites", [])
@@ -116,21 +133,27 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
if primary_model.name not in models_used:
models_used.insert(0, primary_model.name)
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.02)
# ── Progress poller: reads shared dict and sends WS updates ──
# ── Backup progress poller: catches anything call_soon_threadsafe missed ──
async def progress_poller():
last_sent_seq = 0
last_sent_pct = 0.0
last_sent_phase = "Initializing"
while not _done:
await asyncio.sleep(0.3)
await asyncio.sleep(0.5)
seq = _progress["seq"]
pct = _progress["pct"]
phase = _progress["phase"]
if seq != last_sent_seq and (pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"):
# Send on any phase change OR >=3% progress (primary sends handle fine-grained)
if seq != last_sent_seq and (
phase != last_sent_phase
or pct - last_sent_pct >= 0.03
):
await ws_manager.send_progress(ws, calc_id, phase, pct)
last_sent_seq = seq
last_sent_pct = pct
last_sent_phase = phase
poller_task = asyncio.create_task(progress_poller())
@@ -149,6 +172,7 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
points = await asyncio.wait_for(
coverage_service.calculate_multi_site_coverage(
sites, settings, cancel_token,
progress_fn=sync_progress_fn,
),
timeout=300.0,
)
@@ -170,7 +194,6 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
# Stop poller and send final progress
_done = True
await poller_task
await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98)
computation_time = time.time() - start_time
@@ -201,7 +224,10 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
"models_used": models_used,
}
# Send "Complete" before result so frontend shows 100%
await ws_manager.send_progress(ws, calc_id, "Complete", 1.0)
await ws_manager.send_result(ws, calc_id, result)
logger.info(f"[WS] calc={calc_id} done: {len(points)} pts, {computation_time:.1f}s")
except Exception as e:
logger.error(f"[WS] Calculation error: {e}", exc_info=True)