"""
WebSocket handler for real-time coverage calculation with progress.

Uses the same coverage_service pipeline as the HTTP endpoint but sends
progress updates during computation phases.
"""

import time
import asyncio
import logging
from typing import Optional

from fastapi import WebSocket, WebSocketDisconnect

from app.services.coverage_service import (
    coverage_service,
    SiteParams,
    CoverageSettings,
    apply_preset,
    select_propagation_model,
)
from app.services.parallel_coverage_service import CancellationToken

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Track cancellation tokens per calculation and send typed WS messages.

    ``_cancel_tokens`` maps a calculation id to the token of the calculation
    currently registered under that id. The module-level ``ws_manager``
    instance is shared by every WebSocket connection, so ids in it are not
    scoped to a single client.

    Every ``send_*`` method swallows (and logs) send failures, so a closed
    socket never aborts a running calculation.
    """

    def __init__(self):
        self._cancel_tokens: dict[str, CancellationToken] = {}

    async def send_progress(
        self,
        ws: WebSocket,
        calc_id: str,
        phase: str,
        progress: float,
        eta: Optional[float] = None,
    ):
        """Send a ``progress`` message; ``progress`` is capped at 1.0."""
        try:
            await ws.send_json({
                "type": "progress",
                "calculation_id": calc_id,
                "phase": phase,
                "progress": min(progress, 1.0),
                "eta_seconds": eta,
            })
        except Exception as e:
            logger.debug(f"[WS] send_progress failed: {e}")

    async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
        """Send the final ``result`` message carrying *result* as ``data``."""
        try:
            await ws.send_json({
                "type": "result",
                "calculation_id": calc_id,
                "data": result,
            })
        except Exception as e:
            logger.warning(f"[WS] send_result failed: {e}")

    async def send_error(self, ws: WebSocket, calc_id: str, error: str):
        """Send an ``error`` message with the human-readable *error* text."""
        try:
            await ws.send_json({
                "type": "error",
                "calculation_id": calc_id,
                "message": error,
            })
        except Exception as e:
            logger.warning(f"[WS] send_error failed: {e}")

    async def send_partial_results(
        self,
        ws: WebSocket,
        calc_id: str,
        points: list,
        tile_idx: int,
        total_tiles: int,
    ):
        """Send per-tile partial results for progressive rendering.

        Each element of *points* must provide ``model_dump()``.
        ``tile_idx`` is zero-based; ``progress`` is ``(tile_idx + 1) / total_tiles``
        (1.0 when ``total_tiles`` is 0, instead of dividing by zero).
        """
        try:
            await ws.send_json({
                "type": "partial_results",
                "calculation_id": calc_id,
                "points": [p.model_dump() for p in points],
                "tile": tile_idx,
                "total_tiles": total_tiles,
                "progress": (tile_idx + 1) / total_tiles if total_tiles else 1.0,
            })
        except Exception as e:
            logger.debug(f"[WS] send_partial_results failed: {e}")


# Module-level manager shared by all WebSocket connections.
ws_manager = ConnectionManager()


def _build_stats(points: list) -> dict:
    """Summarize RSRP and per-loss-component counts over *points*.

    Every aggregate falls back to 0 when *points* is empty.
    """
    rsrp_values = [p.rsrp for p in points]
    los_count = sum(1 for p in points if p.has_los)
    return {
        "min_rsrp": min(rsrp_values) if rsrp_values else 0,
        "max_rsrp": max(rsrp_values) if rsrp_values else 0,
        "avg_rsrp": sum(rsrp_values) / len(rsrp_values) if rsrp_values else 0,
        "los_percentage": (los_count / len(points) * 100) if points else 0,
        "points_with_buildings": sum(1 for p in points if p.building_loss > 0),
        "points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0),
        "points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0),
        "points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0),
        "points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0),
        "points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0),
        "points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0),
    }


async def _run_calculation(
    ws: WebSocket,
    calc_id: str,
    data: dict,
    cancel_token: Optional[CancellationToken] = None,
):
    """Run coverage calculation with progress updates via WebSocket.

    Validation failures, timeouts, cancellation and unexpected exceptions are
    reported with an ``error`` message. On success, a ``Complete`` progress
    message is followed by a ``result`` message whose payload has the same
    shape as the HTTP endpoint's response.

    Args:
        ws: Accepted WebSocket to report on.
        calc_id: Client-chosen id echoed in every message and used as the key
            in ``ws_manager._cancel_tokens``.
        data: The ``calculate`` message. ``sites`` is a list of ``SiteParams``
            kwargs (1-10 entries); ``settings`` is ``CoverageSettings`` kwargs.
        cancel_token: Token for this run, so the caller can keep its own
            handle for cancellation. A fresh token is created when omitted.
    """
    if cancel_token is None:
        cancel_token = CancellationToken()
    ws_manager._cancel_tokens[calc_id] = cancel_token

    # Shared progress state — written by worker threads, polled by event loop.
    # Plain assignments are atomic under the GIL; ``seq += 1`` is not, but a
    # lost increment is harmless because the poller only checks that it changed.
    _progress = {"phase": "Initializing", "pct": 0.0, "seq": 0}
    # Set once the calculation has finished (any outcome): stops the poller and
    # suppresses late progress from worker threads after result/error was sent.
    _done = False
    poller_task: Optional[asyncio.Task] = None

    # Get event loop for cross-thread scheduling of WS sends.
    loop = asyncio.get_running_loop()
    _last_direct_pct = 0.0
    _last_direct_phase = ""

    def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
        """Thread-safe progress callback — updates dict AND schedules direct WS send."""
        nonlocal _last_direct_pct, _last_direct_phase
        if _done:
            return
        _progress["phase"] = phase
        _progress["pct"] = pct
        _progress["seq"] += 1
        # Schedule direct WS send via event loop (works from any thread).
        # Throttle: only send on phase change or >=2% progress.
        if phase != _last_direct_phase or pct - _last_direct_pct >= 0.02:
            _last_direct_pct = pct
            _last_direct_phase = phase
            coro = ws_manager.send_progress(ws, calc_id, phase, pct)
            try:
                loop.call_soon_threadsafe(asyncio.ensure_future, coro)
            except RuntimeError:
                # Event loop closed: discard the coroutine so Python does not
                # warn that it was never awaited.
                coro.close()

    try:
        sites_data = data.get("sites", [])
        settings_data = data.get("settings", {})

        if not sites_data:
            await ws_manager.send_error(ws, calc_id, "At least one site required")
            return
        if len(sites_data) > 10:
            await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request")
            return

        # Parse sites and settings (same format as HTTP endpoint)
        sites = [SiteParams(**s) for s in sites_data]
        settings = CoverageSettings(**settings_data)

        if settings.radius > 50000:
            await ws_manager.send_error(ws, calc_id, "Maximum radius 50km")
            return
        if settings.resolution < 50:
            await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m")
            return

        effective_settings = apply_preset(settings.model_copy())

        # Determine models used. Imported lazily — presumably to avoid a
        # circular import with the routes module (TODO confirm).
        from app.api.routes.coverage import _get_active_models

        models_used = _get_active_models(effective_settings)
        env = getattr(effective_settings, 'environment', 'urban')
        primary_model = select_propagation_model(sites[0].frequency, env)
        if primary_model.name not in models_used:
            models_used.insert(0, primary_model.name)

        await ws_manager.send_progress(ws, calc_id, "Initializing", 0.02)

        # ── Tile callback for progressive results (large radius) ──
        async def _tile_callback(tile_points, tile_idx, total_tiles):
            await ws_manager.send_partial_results(
                ws, calc_id, tile_points, tile_idx, total_tiles,
            )

        # ── Backup progress poller: catches anything call_soon_threadsafe missed ──
        async def progress_poller():
            last_sent_seq = 0
            last_sent_pct = 0.0
            last_sent_phase = "Initializing"
            while not _done:
                await asyncio.sleep(0.5)
                seq = _progress["seq"]
                pct = _progress["pct"]
                phase = _progress["phase"]
                # Send on any phase change OR >=3% progress (primary sends handle fine-grained)
                if seq != last_sent_seq and (
                    phase != last_sent_phase or pct - last_sent_pct >= 0.03
                ):
                    await ws_manager.send_progress(ws, calc_id, phase, pct)
                    last_sent_seq = seq
                    last_sent_pct = pct
                    last_sent_phase = phase

        poller_task = asyncio.create_task(progress_poller())

        # Run calculation with timeout
        start_time = time.time()
        try:
            if len(sites) == 1:
                calculation = coverage_service.calculate_coverage(
                    sites[0], settings, cancel_token,
                    progress_fn=sync_progress_fn,
                    tile_callback=_tile_callback,
                )
            else:
                calculation = coverage_service.calculate_multi_site_coverage(
                    sites, settings, cancel_token,
                    progress_fn=sync_progress_fn,
                    tile_callback=_tile_callback,
                )
            points = await asyncio.wait_for(calculation, timeout=300.0)
        except asyncio.TimeoutError:
            cancel_token.cancel()
            _done = True
            await poller_task
            from app.services.parallel_coverage_service import _kill_worker_processes
            _kill_worker_processes()
            await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
            return
        except asyncio.CancelledError:
            cancel_token.cancel()
            _done = True
            await poller_task
            await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
            return

        # Stop poller and send final progress
        _done = True
        await poller_task
        computation_time = time.time() - start_time

        # Build response (identical format to HTTP endpoint)
        stats = _build_stats(points)
        result = {
            "points": [p.model_dump() for p in points],
            "count": len(points),
            "settings": effective_settings.model_dump(),
            "stats": stats,
            "computation_time": round(computation_time, 2),
            "models_used": models_used,
        }

        # Send "Complete" before result so frontend shows 100%
        await ws_manager.send_progress(ws, calc_id, "Complete", 1.0)
        await ws_manager.send_result(ws, calc_id, result)
        logger.info(f"[WS] calc={calc_id} done: {len(points)} pts, {computation_time:.1f}s")

    except Exception as e:
        logger.error(f"[WS] Calculation error: {e}", exc_info=True)
        _done = True
        # The poller only exists if the failure happened after it was started.
        if poller_task is not None:
            try:
                await poller_task
            except Exception:
                pass
        await ws_manager.send_error(ws, calc_id, str(e))
    finally:
        _done = True
        # Safety net for exits that skipped the awaits above (e.g. outer cancellation).
        if poller_task is not None and not poller_task.done():
            poller_task.cancel()
        # Unregister only our own token: a later calculation may have reused
        # this calc_id and registered a different one.
        if ws_manager._cancel_tokens.get(calc_id) is cancel_token:
            del ws_manager._cancel_tokens[calc_id]


# Strong references to in-flight calculation tasks. The event loop keeps only
# weak references to tasks, so an otherwise unreferenced task could be
# garbage-collected mid-execution.
_background_tasks: set[asyncio.Task] = set()


async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for coverage calculations with progress.

    Client messages (JSON objects, dispatched on ``type``):
      - ``calculate``: start a calculation in a background task; ``id`` is the
        calculation id (defaults to ``""``).
      - ``cancel``: cancel the calculation with that ``id`` started on THIS
        connection.
      - ``ping``: answered with ``{"type": "pong"}``.

    Non-object messages are ignored. When the connection ends for any reason,
    only the calculations started on this connection are cancelled.
    """
    await websocket.accept()
    # Latest token per calc_id on this connection (target of "cancel").
    own_tokens: dict[str, CancellationToken] = {}
    # Every in-flight calculation on this connection, even if its id was reused.
    running: dict[asyncio.Task, CancellationToken] = {}

    def _on_calc_done(task: asyncio.Task, calc_id: str, token: CancellationToken) -> None:
        """Drop bookkeeping for a finished calculation task."""
        _background_tasks.discard(task)
        running.pop(task, None)
        if own_tokens.get(calc_id) is token:
            del own_tokens[calc_id]

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                logger.debug(f"[WS] ignoring non-object message: {type(data).__name__}")
                continue
            msg_type = data.get("type")

            if msg_type == "calculate":
                calc_id = data.get("id", "")
                token = CancellationToken()
                task = asyncio.create_task(
                    _run_calculation(websocket, calc_id, data, token)
                )
                own_tokens[calc_id] = token
                running[task] = token
                _background_tasks.add(task)
                task.add_done_callback(
                    lambda t, cid=calc_id, tok=token: _on_calc_done(t, cid, tok)
                )
            elif msg_type == "cancel":
                token = own_tokens.get(data.get("id"))
                if token is not None:
                    token.cancel()
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.warning(f"[WS] connection error: {e}", exc_info=True)
    finally:
        for token in list(running.values()):
            token.cancel()