@mytec: iter3.4.0 start
This commit is contained in:
@@ -8,7 +8,6 @@ progress updates during computation phases.
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
@@ -51,7 +50,7 @@ class ConnectionManager:
|
||||
"data": result,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_result failed: {e}")
|
||||
logger.warning(f"[WS] send_result failed: {e}")
|
||||
|
||||
async def send_error(self, ws: WebSocket, calc_id: str, error: str):
|
||||
try:
|
||||
@@ -61,7 +60,7 @@ class ConnectionManager:
|
||||
"message": error,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_error failed: {e}")
|
||||
logger.warning(f"[WS] send_error failed: {e}")
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
@@ -74,14 +73,32 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
|
||||
# Shared progress state — written by worker threads, polled by event loop.
|
||||
# Python GIL makes dict value assignment atomic for simple types.
|
||||
_progress = {"phase": "Initializing", "pct": 0.05, "seq": 0}
|
||||
_progress = {"phase": "Initializing", "pct": 0.0, "seq": 0}
|
||||
_done = False
|
||||
|
||||
# Get event loop for cross-thread scheduling of WS sends.
|
||||
loop = asyncio.get_running_loop()
|
||||
_last_direct_pct = 0.0
|
||||
_last_direct_phase = ""
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback — just updates a shared dict."""
|
||||
"""Thread-safe progress callback — updates dict AND schedules direct WS send."""
|
||||
nonlocal _last_direct_pct, _last_direct_phase
|
||||
_progress["phase"] = phase
|
||||
_progress["pct"] = pct
|
||||
_progress["seq"] += 1
|
||||
# Schedule direct WS send via event loop (works from any thread).
|
||||
# Throttle: only send on phase change or >=2% progress.
|
||||
if phase != _last_direct_phase or pct - _last_direct_pct >= 0.02:
|
||||
_last_direct_pct = pct
|
||||
_last_direct_phase = phase
|
||||
try:
|
||||
loop.call_soon_threadsafe(
|
||||
asyncio.ensure_future,
|
||||
ws_manager.send_progress(ws, calc_id, phase, pct),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass # Event loop closed
|
||||
|
||||
try:
|
||||
sites_data = data.get("sites", [])
|
||||
@@ -116,21 +133,27 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
if primary_model.name not in models_used:
|
||||
models_used.insert(0, primary_model.name)
|
||||
|
||||
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
|
||||
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.02)
|
||||
|
||||
# ── Progress poller: reads shared dict and sends WS updates ──
|
||||
# ── Backup progress poller: catches anything call_soon_threadsafe missed ──
|
||||
async def progress_poller():
|
||||
last_sent_seq = 0
|
||||
last_sent_pct = 0.0
|
||||
last_sent_phase = "Initializing"
|
||||
while not _done:
|
||||
await asyncio.sleep(0.3)
|
||||
await asyncio.sleep(0.5)
|
||||
seq = _progress["seq"]
|
||||
pct = _progress["pct"]
|
||||
phase = _progress["phase"]
|
||||
if seq != last_sent_seq and (pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"):
|
||||
# Send on any phase change OR >=3% progress (primary sends handle fine-grained)
|
||||
if seq != last_sent_seq and (
|
||||
phase != last_sent_phase
|
||||
or pct - last_sent_pct >= 0.03
|
||||
):
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_sent_seq = seq
|
||||
last_sent_pct = pct
|
||||
last_sent_phase = phase
|
||||
|
||||
poller_task = asyncio.create_task(progress_poller())
|
||||
|
||||
@@ -149,6 +172,7 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
points = await asyncio.wait_for(
|
||||
coverage_service.calculate_multi_site_coverage(
|
||||
sites, settings, cancel_token,
|
||||
progress_fn=sync_progress_fn,
|
||||
),
|
||||
timeout=300.0,
|
||||
)
|
||||
@@ -170,7 +194,6 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
# Stop poller and send final progress
|
||||
_done = True
|
||||
await poller_task
|
||||
await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98)
|
||||
|
||||
computation_time = time.time() - start_time
|
||||
|
||||
@@ -201,7 +224,10 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
"models_used": models_used,
|
||||
}
|
||||
|
||||
# Send "Complete" before result so frontend shows 100%
|
||||
await ws_manager.send_progress(ws, calc_id, "Complete", 1.0)
|
||||
await ws_manager.send_result(ws, calc_id, result)
|
||||
logger.info(f"[WS] calc={calc_id} done: {len(points)} pts, {computation_time:.1f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[WS] Calculation error: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user