mytec: after methods
This commit is contained in:
@@ -7,6 +7,7 @@ progress updates during computation phases.
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
@@ -18,6 +19,8 @@ from app.services.coverage_service import (
|
||||
)
|
||||
from app.services.parallel_coverage_service import CancellationToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Track cancellation tokens per calculation."""
|
||||
@@ -37,8 +40,8 @@ class ConnectionManager:
|
||||
"progress": min(progress, 1.0),
|
||||
"eta_seconds": eta,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_progress failed: {e}")
|
||||
|
||||
async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
|
||||
try:
|
||||
@@ -47,8 +50,8 @@ class ConnectionManager:
|
||||
"calculation_id": calc_id,
|
||||
"data": result,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_result failed: {e}")
|
||||
|
||||
async def send_error(self, ws: WebSocket, calc_id: str, error: str):
|
||||
try:
|
||||
@@ -57,8 +60,8 @@ class ConnectionManager:
|
||||
"calculation_id": calc_id,
|
||||
"message": error,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_error failed: {e}")
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
@@ -69,6 +72,17 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
cancel_token = CancellationToken()
|
||||
ws_manager._cancel_tokens[calc_id] = cancel_token
|
||||
|
||||
# Shared progress state — written by worker threads, polled by event loop.
|
||||
# Python GIL makes dict value assignment atomic for simple types.
|
||||
_progress = {"phase": "Initializing", "pct": 0.05, "seq": 0}
|
||||
_done = False
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback — just updates a shared dict."""
|
||||
_progress["phase"] = phase
|
||||
_progress["pct"] = pct
|
||||
_progress["seq"] += 1
|
||||
|
||||
try:
|
||||
sites_data = data.get("sites", [])
|
||||
settings_data = data.get("settings", {})
|
||||
@@ -104,45 +118,21 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
|
||||
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
|
||||
|
||||
# ── Bridge sync progress_fn → async WS sends ──
|
||||
# progress_fn is called from two contexts:
|
||||
# 1. Event loop thread (phases 1-2.5, directly in calculate_coverage)
|
||||
# 2. Worker threads (phase 3, from ProcessPool/sequential executors)
|
||||
# We detect which thread we're on and use the appropriate method.
|
||||
loop = asyncio.get_running_loop()
|
||||
event_loop_thread_id = threading.current_thread().ident
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
# ── Progress poller: reads shared dict and sends WS updates ──
|
||||
async def progress_poller():
|
||||
last_sent_seq = 0
|
||||
last_sent_pct = 0.0
|
||||
while not _done:
|
||||
await asyncio.sleep(0.3)
|
||||
seq = _progress["seq"]
|
||||
pct = _progress["pct"]
|
||||
phase = _progress["phase"]
|
||||
if seq != last_sent_seq and (pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"):
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_sent_seq = seq
|
||||
last_sent_pct = pct
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback for coverage_service."""
|
||||
if threading.current_thread().ident == event_loop_thread_id:
|
||||
# From event loop thread: put directly to queue
|
||||
progress_queue.put_nowait((phase, pct))
|
||||
else:
|
||||
# From worker thread: use thread-safe bridge to wake event loop
|
||||
loop.call_soon_threadsafe(progress_queue.put_nowait, (phase, pct))
|
||||
|
||||
# Background task: drain queue and send WS progress messages
|
||||
_sender_done = False
|
||||
|
||||
async def progress_sender():
|
||||
nonlocal _sender_done
|
||||
last_pct = 0.0
|
||||
while not _sender_done:
|
||||
try:
|
||||
phase, pct = await asyncio.wait_for(progress_queue.get(), timeout=0.5)
|
||||
if pct >= 1.0:
|
||||
break
|
||||
# Throttle: only send if progress changed meaningfully
|
||||
if pct - last_pct >= 0.02 or phase != "Calculating coverage":
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_pct = pct
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
break
|
||||
|
||||
progress_task = asyncio.create_task(progress_sender())
|
||||
poller_task = asyncio.create_task(progress_poller())
|
||||
|
||||
# Run calculation with timeout
|
||||
start_time = time.time()
|
||||
@@ -164,25 +154,23 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
_done = True
|
||||
await poller_task
|
||||
from app.services.parallel_coverage_service import _kill_worker_processes
|
||||
_kill_worker_processes()
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
_done = True
|
||||
await poller_task
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
|
||||
return
|
||||
|
||||
# Stop progress sender
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
# Stop poller and send final progress
|
||||
_done = True
|
||||
await poller_task
|
||||
await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98)
|
||||
|
||||
computation_time = time.time() - start_time
|
||||
|
||||
@@ -216,14 +204,10 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
await ws_manager.send_result(ws, calc_id, result)
|
||||
|
||||
except Exception as e:
|
||||
# Stop progress sender on unhandled exception
|
||||
_sender_done = True
|
||||
logger.error(f"[WS] Calculation error: {e}", exc_info=True)
|
||||
_done = True
|
||||
try:
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await progress_task
|
||||
await poller_task
|
||||
except Exception:
|
||||
pass
|
||||
await ws_manager.send_error(ws, calc_id, str(e))
|
||||
|
||||
Reference in New Issue
Block a user