Files
rfcp/backend/app/api/websocket.py

297 lines
11 KiB
Python

"""
WebSocket handler for real-time coverage calculation with progress.
Uses the same coverage_service pipeline as the HTTP endpoint but sends
progress updates during computation phases.
"""
import time
import asyncio
import logging
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
from app.services.coverage_service import (
coverage_service, SiteParams, CoverageSettings, apply_preset,
select_propagation_model,
)
from app.services.parallel_coverage_service import CancellationToken
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Best-effort JSON message sender plus a registry of cancellation tokens.

    ``_cancel_tokens`` maps a calculation id to the ``CancellationToken`` of
    the calculation currently registered under that id. ``_run_calculation``
    adds/removes entries and ``websocket_endpoint`` looks them up to cancel
    work.

    Each ``send_*`` coroutine catches any exception raised while building or
    sending its message and only logs it, so a broken connection never
    propagates an error into the calculation that is reporting to it.
    """

    def __init__(self):
        self._cancel_tokens: dict[str, CancellationToken] = {}

    async def send_progress(
        self, ws: WebSocket, calc_id: str,
        phase: str, progress: float, eta: Optional[float] = None,
    ):
        """Send a ``progress`` message; *progress* is capped at 1.0.

        Send failures are logged at DEBUG level.
        """
        try:
            await ws.send_json({
                "type": "progress",
                "calculation_id": calc_id,
                "phase": phase,
                "progress": min(progress, 1.0),
                "eta_seconds": eta,
            })
        except Exception as e:
            logger.debug("[WS] send_progress failed: %s", e)

    async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
        """Send the final ``result`` message; failures are logged at WARNING."""
        try:
            await ws.send_json({
                "type": "result",
                "calculation_id": calc_id,
                "data": result,
            })
        except Exception as e:
            logger.warning("[WS] send_result failed: %s", e)

    async def send_error(self, ws: WebSocket, calc_id: str, error: str):
        """Send an ``error`` message with text *error*; failures logged at WARNING."""
        try:
            await ws.send_json({
                "type": "error",
                "calculation_id": calc_id,
                "message": error,
            })
        except Exception as e:
            logger.warning("[WS] send_error failed: %s", e)

    async def send_partial_results(
        self, ws: WebSocket, calc_id: str,
        points: list, tile_idx: int, total_tiles: int,
    ):
        """Send one tile's points as a ``partial_results`` message.

        Used for progressive rendering while a tiled calculation runs.

        Args:
            ws: Connection to send on.
            calc_id: Calculation id echoed in the message.
            points: Objects serialized with ``model_dump()`` (presumably
                coverage points — confirm against callers).
            tile_idx: Index of this tile; ``tile_idx + 1`` tiles are done.
            total_tiles: Number of tiles in the calculation.
        """
        try:
            # Guard the division: with a non-positive tile count the old code
            # raised ZeroDivisionError here and the tile's points were lost.
            # Cap at 1.0 to match send_progress.
            if total_tiles > 0:
                tile_progress = min((tile_idx + 1) / total_tiles, 1.0)
            else:
                tile_progress = 1.0
            await ws.send_json({
                "type": "partial_results",
                "calculation_id": calc_id,
                "points": [p.model_dump() for p in points],
                "tile": tile_idx,
                "total_tiles": total_tiles,
                "progress": tile_progress,
            })
        except Exception as e:
            logger.debug("[WS] send_partial_results failed: %s", e)


# Process-wide instance shared by the calculation runner and the endpoint.
ws_manager = ConnectionManager()
def _coverage_stats(points: list) -> dict:
    """Aggregate coverage points into the ``stats`` dict of a result message.

    Every aggregate falls back to 0 when *points* is empty. Each point is
    read for ``rsrp``, ``has_los`` and the loss/gain attributes below.
    """
    rsrp_values = [p.rsrp for p in points]
    los_count = sum(1 for p in points if p.has_los)
    return {
        "min_rsrp": min(rsrp_values) if rsrp_values else 0,
        "max_rsrp": max(rsrp_values) if rsrp_values else 0,
        "avg_rsrp": sum(rsrp_values) / len(rsrp_values) if rsrp_values else 0,
        "los_percentage": (los_count / len(points) * 100) if points else 0,
        "points_with_buildings": sum(1 for p in points if p.building_loss > 0),
        "points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0),
        "points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0),
        "points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0),
        "points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0),
        "points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0),
        "points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0),
    }


async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
    """Run one coverage calculation and stream its progress over *ws*.

    Messages sent (all tagged with *calc_id*):

    * ``progress`` updates while computing.
    * ``partial_results`` per tile, when the service reports tiles.
    * On success, a final ``progress`` ("Complete", 1.0) followed by a
      ``result`` message.
    * On failure, an ``error`` message instead.

    Args:
        ws: Accepted WebSocket the messages are sent on.
        calc_id: Client-chosen id; also the key of this calculation's
            cancellation token in ``ws_manager._cancel_tokens``.
        data: The "calculate" request. ``data["sites"]`` is a list of
            ``SiteParams`` kwargs (1-10 entries) and ``data["settings"]`` the
            ``CoverageSettings`` kwargs, in the same format as the HTTP
            endpoint.

    Raises:
        asyncio.CancelledError: re-raised after cleanup so that cancelling
            this task behaves like normal asyncio cancellation.
    """
    cancel_token = CancellationToken()
    ws_manager._cancel_tokens[calc_id] = cancel_token

    # Shared progress state: written by worker threads (sync_progress_fn),
    # polled on the event loop (progress_poller). Single-key assignment is
    # atomic under the GIL. The ``seq += 1`` read-modify-write is not, but a
    # lost increment only delays one backup send.
    _progress = {"phase": "Initializing", "pct": 0.0, "seq": 0}
    _done = False
    poller_task: Optional[asyncio.Task] = None

    # Captured here so worker threads can schedule sends onto this loop.
    loop = asyncio.get_running_loop()
    _last_direct_pct = 0.0
    _last_direct_phase = ""

    def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
        """Record progress and schedule a direct WS send (callable from any thread)."""
        nonlocal _last_direct_pct, _last_direct_phase
        _progress["phase"] = phase
        _progress["pct"] = pct
        _progress["seq"] += 1
        # Throttle: only send on phase change or >=2% progress.
        if phase != _last_direct_phase or pct - _last_direct_pct >= 0.02:
            _last_direct_pct = pct
            _last_direct_phase = phase
            coro = ws_manager.send_progress(ws, calc_id, phase, pct)
            try:
                asyncio.run_coroutine_threadsafe(coro, loop)
            except RuntimeError:
                # Event loop already closed: drop this update, and close the
                # coroutine so it is not reported as "never awaited".
                coro.close()

    async def progress_poller():
        """Backup sender: resend progress that the direct sends may have missed."""
        last_sent_seq = 0
        last_sent_pct = 0.0
        last_sent_phase = "Initializing"
        while not _done:
            await asyncio.sleep(0.5)
            seq = _progress["seq"]
            pct = _progress["pct"]
            phase = _progress["phase"]
            # Send on any phase change OR >=3% progress (primary sends handle fine-grained)
            if seq != last_sent_seq and (
                phase != last_sent_phase
                or pct - last_sent_pct >= 0.03
            ):
                await ws_manager.send_progress(ws, calc_id, phase, pct)
                last_sent_seq = seq
                last_sent_pct = pct
                last_sent_phase = phase

    async def stop_poller():
        """Signal the poller to exit and wait for it; no-op if it never started."""
        nonlocal _done
        _done = True
        if poller_task is not None:
            await poller_task

    try:
        sites_data = data.get("sites", [])
        settings_data = data.get("settings", {})
        if not sites_data:
            await ws_manager.send_error(ws, calc_id, "At least one site required")
            return
        if len(sites_data) > 10:
            await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request")
            return

        # Parse sites and settings (same format as HTTP endpoint)
        sites = [SiteParams(**s) for s in sites_data]
        settings = CoverageSettings(**settings_data)
        if settings.radius > 50000:
            await ws_manager.send_error(ws, calc_id, "Maximum radius 50km")
            return
        if settings.resolution < 50:
            await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m")
            return

        effective_settings = apply_preset(settings.model_copy())

        # Determine models used (the import is kept local, as in the original).
        from app.api.routes.coverage import _get_active_models
        models_used = _get_active_models(effective_settings)
        env = getattr(effective_settings, 'environment', 'urban')
        primary_model = select_propagation_model(sites[0].frequency, env)
        if primary_model.name not in models_used:
            models_used.insert(0, primary_model.name)

        await ws_manager.send_progress(ws, calc_id, "Initializing", 0.02)

        async def _tile_callback(tile_points, tile_idx, total_tiles):
            """Forward each finished tile to the client for progressive rendering."""
            await ws_manager.send_partial_results(
                ws, calc_id, tile_points, tile_idx, total_tiles,
            )

        poller_task = asyncio.create_task(progress_poller())

        start_time = time.time()
        if len(sites) == 1:
            calculation = coverage_service.calculate_coverage(
                sites[0], settings, cancel_token,
                progress_fn=sync_progress_fn,
                tile_callback=_tile_callback,
            )
        else:
            calculation = coverage_service.calculate_multi_site_coverage(
                sites, settings, cancel_token,
                progress_fn=sync_progress_fn,
                tile_callback=_tile_callback,
            )

        try:
            points = await asyncio.wait_for(calculation, timeout=300.0)
        except asyncio.TimeoutError:
            cancel_token.cancel()
            await stop_poller()
            from app.services.parallel_coverage_service import _kill_worker_processes
            _kill_worker_processes()
            await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
            return
        except asyncio.CancelledError:
            cancel_token.cancel()
            await stop_poller()
            await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
            # Propagate cancellation instead of swallowing it.
            raise

        await stop_poller()
        computation_time = time.time() - start_time

        # Build response (identical format to HTTP endpoint)
        result = {
            "points": [p.model_dump() for p in points],
            "count": len(points),
            "settings": effective_settings.model_dump(),
            "stats": _coverage_stats(points),
            "computation_time": round(computation_time, 2),
            "models_used": models_used,
        }

        # Send "Complete" before result so frontend shows 100%
        await ws_manager.send_progress(ws, calc_id, "Complete", 1.0)
        await ws_manager.send_result(ws, calc_id, result)
        logger.info(
            "[WS] calc=%s done: %d pts, %.1fs", calc_id, len(points), computation_time,
        )
    except Exception as e:
        logger.error("[WS] Calculation error: %s", e, exc_info=True)
        # Safe even if validation failed before the poller was created.
        await stop_poller()
        await ws_manager.send_error(ws, calc_id, str(e))
    finally:
        # Only unregister our own token: a newer calculation may have reused
        # this id and registered its own token in the meantime.
        if ws_manager._cancel_tokens.get(calc_id) is cancel_token:
            del ws_manager._cancel_tokens[calc_id]
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection: dispatch calculate / cancel / ping messages.

    Accepted messages are JSON objects with a ``type`` key:

    * ``calculate`` starts ``_run_calculation`` in a background task, keyed
      by the message's ``id``.
    * ``cancel`` cancels the calculation registered under ``id``.
    * ``ping`` is answered with ``{"type": "pong"}``.

    Non-object messages are ignored. When the connection ends (disconnect,
    unexpected error, or cancellation), only the calculations started on
    *this* connection that are still running are cancelled. Other clients'
    calculations are left alone.
    """
    await websocket.accept()
    # Strong references to this connection's background tasks (value = calc
    # id). asyncio keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes.
    active: dict[asyncio.Task, str] = {}
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                # A malformed message should not tear down the connection.
                continue
            msg_type = data.get("type")
            if msg_type == "calculate":
                calc_id = data.get("id", "")
                task = asyncio.create_task(_run_calculation(websocket, calc_id, data))
                active[task] = calc_id
                task.add_done_callback(lambda t: active.pop(t, None))
            elif msg_type == "cancel":
                token = ws_manager._cancel_tokens.get(data.get("id"))
                if token:
                    token.cancel()
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("[WS] connection handler error")
    finally:
        # Snapshot: done-callbacks may mutate ``active`` later on the loop.
        for calc_id in list(active.values()):
            token = ws_manager._cancel_tokens.get(calc_id)
            if token:
                token.cancel()