297 lines
11 KiB
Python
297 lines
11 KiB
Python
"""
|
|
WebSocket handler for real-time coverage calculation with progress.
|
|
|
|
Uses the same coverage_service pipeline as the HTTP endpoint but sends
|
|
progress updates during computation phases.
|
|
"""
|
|
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from app.services.coverage_service import (
|
|
coverage_service, SiteParams, CoverageSettings, apply_preset,
|
|
select_propagation_model,
|
|
)
|
|
from app.services.parallel_coverage_service import CancellationToken
|
|
|
|
# Module logger; used below for WebSocket send failures and calculation errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionManager:
    """Track cancellation tokens per calculation and send typed WS messages.

    All ``send_*`` methods are best-effort: if building or sending a message
    fails (for example because the client has already disconnected), the
    error is logged and swallowed, so a dead socket never aborts the
    calculation that is reporting to it.
    """

    def __init__(self):
        # calculation_id -> token used to request cancellation of that run.
        self._cancel_tokens: dict[str, CancellationToken] = {}

    async def _send(self, ws: WebSocket, what: str, level: int, build) -> None:
        """Send ``build()`` as JSON on *ws*; on failure log at *level* and return.

        *build* is a zero-argument callable returning the payload dict. It is
        called inside the ``try`` so that errors while building the payload
        are treated as best-effort failures, exactly like send errors.
        """
        try:
            await ws.send_json(build())
        except Exception as e:
            logger.log(level, "[WS] %s failed: %s", what, e)

    async def send_progress(
        self, ws: WebSocket, calc_id: str,
        phase: str, progress: float, eta: Optional[float] = None,
    ):
        """Send a ``progress`` message; *progress* is clamped to [0.0, 1.0]."""
        await self._send(ws, "send_progress", logging.DEBUG, lambda: {
            "type": "progress",
            "calculation_id": calc_id,
            "phase": phase,
            "progress": max(0.0, min(progress, 1.0)),
            "eta_seconds": eta,
        })

    async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
        """Send the final ``result`` message carrying *result* as ``data``."""
        await self._send(ws, "send_result", logging.WARNING, lambda: {
            "type": "result",
            "calculation_id": calc_id,
            "data": result,
        })

    async def send_error(self, ws: WebSocket, calc_id: str, error: str):
        """Send an ``error`` message with the human-readable *error* text."""
        await self._send(ws, "send_error", logging.WARNING, lambda: {
            "type": "error",
            "calculation_id": calc_id,
            "message": error,
        })

    async def send_partial_results(
        self, ws: WebSocket, calc_id: str,
        points: list, tile_idx: int, total_tiles: int,
    ):
        """Send per-tile partial results for progressive rendering.

        Each item of *points* must provide ``model_dump()``. ``progress`` is
        the fraction of tiles done, capped at 1.0; a non-positive
        *total_tiles* reports 1.0 instead of failing with ZeroDivisionError
        (which previously dropped the message silently).
        """
        def build():
            done = (tile_idx + 1) / total_tiles if total_tiles > 0 else 1.0
            return {
                "type": "partial_results",
                "calculation_id": calc_id,
                "points": [p.model_dump() for p in points],
                "tile": tile_idx,
                "total_tiles": total_tiles,
                "progress": min(done, 1.0),
            }

        await self._send(ws, "send_partial_results", logging.DEBUG, build)
|
|
|
|
|
|
# Module-level singleton used by every connection handled in this module;
# its cancellation-token map is therefore shared across all connections.
ws_manager = ConnectionManager()
|
|
|
|
|
|
def _coverage_stats(points: list) -> dict:
    """Summarize coverage points into the stats dict (same format as the HTTP endpoint).

    Each point must expose ``rsrp``, ``has_los`` and the loss/gain attributes
    read below. An empty list yields zeros instead of raising.
    """
    count = len(points)
    rsrp_values = [p.rsrp for p in points]
    los_count = sum(1 for p in points if p.has_los)
    return {
        "min_rsrp": min(rsrp_values) if rsrp_values else 0,
        "max_rsrp": max(rsrp_values) if rsrp_values else 0,
        "avg_rsrp": sum(rsrp_values) / count if count else 0,
        "los_percentage": (los_count / count * 100) if count else 0,
        "points_with_buildings": sum(1 for p in points if p.building_loss > 0),
        "points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0),
        "points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0),
        "points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0),
        "points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0),
        "points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0),
        "points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0),
    }


async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
    """Run one coverage calculation, streaming progress to *ws*.

    Messages go through ``ws_manager``. While the calculation runs, it sends
    ``progress`` updates (plus ``partial_results`` per tile). It then sends
    either a final ``"Complete"`` progress followed by one ``result`` message,
    or a single ``error`` message. Errors cover validation failure, the 300 s
    timeout, cancellation and any other exception. Every send is best-effort.

    Args:
        ws: The client connection.
        calc_id: Client-supplied id, echoed in every message and used as the
            key of this run's token in ``ws_manager._cancel_tokens``.
        data: The ``calculate`` message. Reads ``sites`` (1-10 dicts for
            ``SiteParams``) and ``settings`` (a dict for ``CoverageSettings``).

    Raises:
        asyncio.CancelledError: Re-raised after cleanup and after an error
            message is sent, so task-cancellation semantics are preserved.
    """
    cancel_token = CancellationToken()
    ws_manager._cancel_tokens[calc_id] = cancel_token

    # Loop used to hand progress callbacks from worker threads to the event loop.
    loop = asyncio.get_running_loop()

    # Latest reported progress: written by worker threads, read by the poller.
    # `seq += 1` is not atomic across threads, but it is only used to detect
    # that something changed, so a lost increment is harmless.
    _progress = {"phase": "Initializing", "pct": 0.0, "seq": 0}
    # Set once the calculation is over (any outcome). It wakes the poller
    # immediately and blocks further progress sends, so none can arrive
    # after the final result/error message.
    finished = asyncio.Event()
    # Strong references to in-flight progress sends (asyncio keeps only weak ones).
    pending_sends: set = set()
    poller_task = None

    _last_direct_pct = 0.0
    _last_direct_phase = ""

    def _start_progress_send(phase: str, pct: float):
        """Begin a progress send; always runs on the event-loop thread."""
        if finished.is_set():
            return
        task = asyncio.ensure_future(ws_manager.send_progress(ws, calc_id, phase, pct))
        pending_sends.add(task)
        task.add_done_callback(pending_sends.discard)

    def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
        """Progress callback for the coverage service; callable from any thread."""
        nonlocal _last_direct_pct, _last_direct_phase
        _progress["phase"] = phase
        _progress["pct"] = pct
        _progress["seq"] += 1
        # Throttle direct sends: only on phase change or >=2% progress.
        if phase != _last_direct_phase or pct - _last_direct_pct >= 0.02:
            _last_direct_pct = pct
            _last_direct_phase = phase
            try:
                # Schedule a plain callback, not a coroutine object, so no
                # coroutine is left un-awaited if the loop has already closed.
                loop.call_soon_threadsafe(_start_progress_send, phase, pct)
            except RuntimeError:
                pass  # Event loop closed

    async def progress_poller():
        """Backup sender: forwards progress the direct path throttled or missed."""
        last_sent_seq = 0
        last_sent_pct = 0.0
        last_sent_phase = "Initializing"
        while not finished.is_set():
            try:
                await asyncio.wait_for(finished.wait(), timeout=0.5)
                return  # Calculation finished while we were waiting.
            except asyncio.TimeoutError:
                pass
            seq = _progress["seq"]
            pct = _progress["pct"]
            phase = _progress["phase"]
            # Send on any phase change OR >=3% progress (direct sends handle fine-grained)
            if seq != last_sent_seq and (
                phase != last_sent_phase or pct - last_sent_pct >= 0.03
            ):
                await ws_manager.send_progress(ws, calc_id, phase, pct)
                last_sent_seq = seq
                last_sent_pct = pct
                last_sent_phase = phase

    async def stop_progress():
        """Stop all progress reporting and wait for in-flight sends (idempotent)."""
        finished.set()
        if poller_task is not None:
            await poller_task
        if pending_sends:
            await asyncio.gather(*pending_sends, return_exceptions=True)

    try:
        sites_data = data.get("sites", [])
        settings_data = data.get("settings", {})

        if not sites_data:
            await ws_manager.send_error(ws, calc_id, "At least one site required")
            return
        if len(sites_data) > 10:
            await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request")
            return

        # Parse sites and settings (same format as HTTP endpoint)
        sites = [SiteParams(**s) for s in sites_data]
        settings = CoverageSettings(**settings_data)

        if settings.radius > 50000:
            await ws_manager.send_error(ws, calc_id, "Maximum radius 50km")
            return
        if settings.resolution < 50:
            await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m")
            return

        effective_settings = apply_preset(settings.model_copy())

        # Determine models used. Function-level import kept as in the original
        # (presumably to avoid a circular import with the routes module).
        from app.api.routes.coverage import _get_active_models

        models_used = _get_active_models(effective_settings)
        env = getattr(effective_settings, 'environment', 'urban')
        primary_model = select_propagation_model(sites[0].frequency, env)
        if primary_model.name not in models_used:
            models_used.insert(0, primary_model.name)

        await ws_manager.send_progress(ws, calc_id, "Initializing", 0.02)

        async def _tile_callback(tile_points, tile_idx, total_tiles):
            """Forward one tile of results for progressive rendering."""
            await ws_manager.send_partial_results(
                ws, calc_id, tile_points, tile_idx, total_tiles,
            )

        poller_task = asyncio.create_task(progress_poller())

        # NOTE(review): the raw `settings` (not `effective_settings`) is passed
        # to the service, as before. Confirm that the service applies the preset.
        start_time = time.time()
        try:
            if len(sites) == 1:
                points = await asyncio.wait_for(
                    coverage_service.calculate_coverage(
                        sites[0], settings, cancel_token,
                        progress_fn=sync_progress_fn,
                        tile_callback=_tile_callback,
                    ),
                    timeout=300.0,
                )
            else:
                points = await asyncio.wait_for(
                    coverage_service.calculate_multi_site_coverage(
                        sites, settings, cancel_token,
                        progress_fn=sync_progress_fn,
                        tile_callback=_tile_callback,
                    ),
                    timeout=300.0,
                )
        except asyncio.TimeoutError:
            cancel_token.cancel()
            await stop_progress()
            from app.services.parallel_coverage_service import _kill_worker_processes
            _kill_worker_processes()
            await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
            return
        except asyncio.CancelledError:
            cancel_token.cancel()
            await stop_progress()
            await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
            raise

        # Measure before stopping progress reporting so the shutdown isn't counted.
        computation_time = time.time() - start_time
        await stop_progress()

        # Build response (identical format to HTTP endpoint)
        result = {
            "points": [p.model_dump() for p in points],
            "count": len(points),
            "settings": effective_settings.model_dump(),
            "stats": _coverage_stats(points),
            "computation_time": round(computation_time, 2),
            "models_used": models_used,
        }

        # Send "Complete" before result so frontend shows 100%
        await ws_manager.send_progress(ws, calc_id, "Complete", 1.0)
        await ws_manager.send_result(ws, calc_id, result)
        logger.info(
            "[WS] calc=%s done: %d pts, %.1fs", calc_id, len(points), computation_time,
        )

    except asyncio.CancelledError:
        # Already reported above (or nothing to report); let cancellation propagate.
        raise
    except Exception as e:
        logger.error("[WS] Calculation error: %s", e, exc_info=True)
        await stop_progress()
        await ws_manager.send_error(ws, calc_id, str(e))
    finally:
        # Only remove our own token: a newer calculation may have reused this id.
        if ws_manager._cancel_tokens.get(calc_id) is cancel_token:
            del ws_manager._cancel_tokens[calc_id]
        await stop_progress()
|
|
|
|
|
|
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for coverage calculations with progress.

    Accepts the connection, then handles incoming JSON messages:

    * ``{"type": "calculate", "id": ..., ...}`` starts ``_run_calculation``
      as a background task;
    * ``{"type": "cancel", "id": ...}`` cancels a calculation started on
      this same connection;
    * ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.

    Messages that are not JSON objects, or that have an unknown type, are
    ignored. When the client disconnects (or the handler fails), only the
    calculations started by this connection are cancelled. The token map
    in ``ws_manager`` is shared by every connection, so cancelling all of
    it would abort other clients' work.
    """
    await websocket.accept()

    # Running calculation tasks of this connection -> their calculation id.
    # Holding the Task objects also keeps them from being garbage-collected
    # while the connection is open (the event loop keeps only weak references).
    running: dict[asyncio.Task, str] = {}

    def _cancel(calc_id) -> None:
        """Signal the cancellation token registered for *calc_id*, if any."""
        token = ws_manager._cancel_tokens.get(calc_id)
        if token:
            token.cancel()

    def _cancel_own_calculations() -> None:
        """Cancel every calculation this connection started that is still running."""
        for calc_id in set(running.values()):
            _cancel(calc_id)

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                logger.debug("[WS] ignoring non-object message: %r", data)
                continue
            msg_type = data.get("type")

            if msg_type == "calculate":
                calc_id = data.get("id", "")
                task = asyncio.create_task(_run_calculation(websocket, calc_id, data))
                running[task] = calc_id
                task.add_done_callback(lambda t: running.pop(t, None))

            elif msg_type == "cancel":
                calc_id = data.get("id")
                # Only honour cancels for this connection's own calculations,
                # so one client cannot cancel another client's work by id.
                if calc_id in running.values():
                    _cancel(calc_id)

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        _cancel_own_calculations()
    except Exception:
        logger.exception("[WS] connection handler error")
        _cancel_own_calculations()
|