""" WebSocket handler for real-time coverage calculation with progress. Uses the same coverage_service pipeline as the HTTP endpoint but sends progress updates during computation phases. """ import time import asyncio import logging import threading from typing import Optional from fastapi import WebSocket, WebSocketDisconnect from app.services.coverage_service import ( coverage_service, SiteParams, CoverageSettings, apply_preset, select_propagation_model, ) from app.services.parallel_coverage_service import CancellationToken logger = logging.getLogger(__name__) class ConnectionManager: """Track cancellation tokens per calculation.""" def __init__(self): self._cancel_tokens: dict[str, CancellationToken] = {} async def send_progress( self, ws: WebSocket, calc_id: str, phase: str, progress: float, eta: Optional[float] = None, ): try: await ws.send_json({ "type": "progress", "calculation_id": calc_id, "phase": phase, "progress": min(progress, 1.0), "eta_seconds": eta, }) except Exception as e: logger.debug(f"[WS] send_progress failed: {e}") async def send_result(self, ws: WebSocket, calc_id: str, result: dict): try: await ws.send_json({ "type": "result", "calculation_id": calc_id, "data": result, }) except Exception as e: logger.debug(f"[WS] send_result failed: {e}") async def send_error(self, ws: WebSocket, calc_id: str, error: str): try: await ws.send_json({ "type": "error", "calculation_id": calc_id, "message": error, }) except Exception as e: logger.debug(f"[WS] send_error failed: {e}") ws_manager = ConnectionManager() async def _run_calculation(ws: WebSocket, calc_id: str, data: dict): """Run coverage calculation with progress updates via WebSocket.""" cancel_token = CancellationToken() ws_manager._cancel_tokens[calc_id] = cancel_token # Shared progress state — written by worker threads, polled by event loop. # Python GIL makes dict value assignment atomic for simple types. _progress = {"phase": "Initializing", "pct": 0.05, "seq": 0} _done = False def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None): """Thread-safe progress callback — just updates a shared dict.""" _progress["phase"] = phase _progress["pct"] = pct _progress["seq"] += 1 try: sites_data = data.get("sites", []) settings_data = data.get("settings", {}) if not sites_data: await ws_manager.send_error(ws, calc_id, "At least one site required") return if len(sites_data) > 10: await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request") return # Parse sites and settings (same format as HTTP endpoint) sites = [SiteParams(**s) for s in sites_data] settings = CoverageSettings(**settings_data) if settings.radius > 50000: await ws_manager.send_error(ws, calc_id, "Maximum radius 50km") return if settings.resolution < 50: await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m") return effective_settings = apply_preset(settings.model_copy()) # Determine models used from app.api.routes.coverage import _get_active_models models_used = _get_active_models(effective_settings) env = getattr(effective_settings, 'environment', 'urban') primary_model = select_propagation_model(sites[0].frequency, env) if primary_model.name not in models_used: models_used.insert(0, primary_model.name) await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05) # ── Progress poller: reads shared dict and sends WS updates ── async def progress_poller(): last_sent_seq = 0 last_sent_pct = 0.0 while not _done: await asyncio.sleep(0.3) seq = _progress["seq"] pct = _progress["pct"] phase = _progress["phase"] if seq != last_sent_seq and (pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"): await ws_manager.send_progress(ws, calc_id, phase, pct) last_sent_seq = seq last_sent_pct = pct poller_task = asyncio.create_task(progress_poller()) # Run calculation with timeout start_time = time.time() try: if len(sites) == 1: points = await asyncio.wait_for( coverage_service.calculate_coverage( sites[0], settings, cancel_token, progress_fn=sync_progress_fn, ), timeout=300.0, ) else: points = await asyncio.wait_for( coverage_service.calculate_multi_site_coverage( sites, settings, cancel_token, ), timeout=300.0, ) except asyncio.TimeoutError: cancel_token.cancel() _done = True await poller_task from app.services.parallel_coverage_service import _kill_worker_processes _kill_worker_processes() await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)") return except asyncio.CancelledError: cancel_token.cancel() _done = True await poller_task await ws_manager.send_error(ws, calc_id, "Calculation cancelled") return # Stop poller and send final progress _done = True await poller_task await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98) computation_time = time.time() - start_time # Build response (identical format to HTTP endpoint) rsrp_values = [p.rsrp for p in points] los_count = sum(1 for p in points if p.has_los) stats = { "min_rsrp": min(rsrp_values) if rsrp_values else 0, "max_rsrp": max(rsrp_values) if rsrp_values else 0, "avg_rsrp": sum(rsrp_values) / len(rsrp_values) if rsrp_values else 0, "los_percentage": (los_count / len(points) * 100) if points else 0, "points_with_buildings": sum(1 for p in points if p.building_loss > 0), "points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0), "points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0), "points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0), "points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0), "points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0), "points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0), } result = { "points": [p.model_dump() for p in points], "count": len(points), "settings": effective_settings.model_dump(), "stats": stats, "computation_time": round(computation_time, 2), "models_used": models_used, } await ws_manager.send_result(ws, calc_id, result) except Exception as e: logger.error(f"[WS] Calculation error: {e}", exc_info=True) _done = True try: await poller_task except Exception: pass await ws_manager.send_error(ws, calc_id, str(e)) finally: ws_manager._cancel_tokens.pop(calc_id, None) async def websocket_endpoint(websocket: WebSocket): """WebSocket endpoint for coverage calculations with progress.""" await websocket.accept() try: while True: data = await websocket.receive_json() msg_type = data.get("type") if msg_type == "calculate": calc_id = data.get("id", "") asyncio.create_task(_run_calculation(websocket, calc_id, data)) elif msg_type == "cancel": calc_id = data.get("id") token = ws_manager._cancel_tokens.get(calc_id) if token: token.cancel() elif msg_type == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: for token in ws_manager._cancel_tokens.values(): token.cancel() except Exception: for token in ws_manager._cancel_tokens.values(): token.cancel()