Files
rfcp/backend/app/api/websocket.py
2026-02-02 01:55:09 +02:00

246 lines
8.7 KiB
Python

"""
WebSocket handler for real-time coverage calculation with progress.
Uses the same coverage_service pipeline as the HTTP endpoint but sends
progress updates during computation phases.
"""
import time
import asyncio
import logging
import threading
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
from app.services.coverage_service import (
coverage_service, SiteParams, CoverageSettings, apply_preset,
select_propagation_model,
)
from app.services.parallel_coverage_service import CancellationToken
# Module logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Hold per-calculation cancellation tokens and send typed JSON messages.

    Every ``send_*`` coroutine is best-effort: any exception raised while
    building or sending its message is logged at DEBUG level and swallowed,
    so callers never need to guard a send.
    """

    def __init__(self):
        # calculation id -> CancellationToken
        self._cancel_tokens: dict[str, CancellationToken] = {}

    @staticmethod
    async def _deliver(ws, label, build_message):
        """Send ``build_message()`` as JSON over ``ws``.

        The message is built inside the ``try`` so that a failure while
        building it is swallowed exactly like a failed send.
        """
        try:
            await ws.send_json(build_message())
        except Exception as exc:
            logger.debug(f"[WS] {label} failed: {exc}")

    async def send_progress(
        self, ws: WebSocket, calc_id: str,
        phase: str, progress: float, eta: Optional[float] = None,
    ):
        """Send a "progress" message; ``progress`` is capped at 1.0."""
        await self._deliver(ws, "send_progress", lambda: dict(
            type="progress",
            calculation_id=calc_id,
            phase=phase,
            progress=min(progress, 1.0),
            eta_seconds=eta,
        ))

    async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
        """Send a "result" message carrying ``result`` under ``data``."""
        await self._deliver(ws, "send_result", lambda: dict(
            type="result",
            calculation_id=calc_id,
            data=result,
        ))

    async def send_error(self, ws: WebSocket, calc_id: str, error: str):
        """Send an "error" message carrying ``error`` under ``message``."""
        await self._deliver(ws, "send_error", lambda: dict(
            type="error",
            calculation_id=calc_id,
            message=error,
        ))
# Single module-level instance; its token registry is keyed only by the
# calculation id, with no per-connection separation.
ws_manager = ConnectionManager()
def _coverage_stats(points: list) -> dict:
    """Aggregate per-point coverage results into the summary ``stats`` dict.

    Reads ``rsrp``, ``has_los`` and the ``*_loss`` / ``reflection_gain``
    attributes of each point. Every statistic falls back to 0 for an empty
    list, so the response shape never changes.
    """
    rsrp_values = [p.rsrp for p in points]
    los_count = sum(1 for p in points if p.has_los)
    return {
        "min_rsrp": min(rsrp_values) if rsrp_values else 0,
        "max_rsrp": max(rsrp_values) if rsrp_values else 0,
        "avg_rsrp": sum(rsrp_values) / len(rsrp_values) if rsrp_values else 0,
        "los_percentage": (los_count / len(points) * 100) if points else 0,
        "points_with_buildings": sum(1 for p in points if p.building_loss > 0),
        "points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0),
        "points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0),
        "points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0),
        "points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0),
        "points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0),
        "points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0),
    }


async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
    """Run one coverage calculation and stream its progress over ``ws``.

    Sends "progress" messages while running. It then sends a single "result"
    on success, or a single "error" on:
    - validation failure,
    - timeout (300 s),
    - task cancellation,
    - any other exception.

    Args:
        ws: Client connection to report to.
        calc_id: Client-supplied calculation id. It is echoed in every
            message and is the registry key of this run's cancellation token.
        data: The "calculate" message. Reads ``sites`` (1-10 site dicts) and
            ``settings`` (keyword arguments for CoverageSettings).

    Raises:
        asyncio.CancelledError: Re-raised after cleanup when this task itself
            is cancelled, so cancellation propagates normally.
    """
    cancel_token = CancellationToken()
    # NOTE(review): an existing token under the same id is overwritten, so the
    # earlier calculation can no longer be cancelled by id.
    ws_manager._cancel_tokens[calc_id] = cancel_token

    # Progress state shared between the worker callback (writer) and the
    # event-loop poller (reader). Single-key assignments are atomic under the
    # GIL. "seq += 1" is not, but seq is only used for change detection, so a
    # lost increment at worst delays one update.
    _progress = {"phase": "Initializing", "pct": 0.05, "seq": 0}
    _done = False
    poller_task: Optional[asyncio.Task] = None

    def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
        """Progress callback for the computation: only records the latest state."""
        _progress["phase"] = phase
        _progress["pct"] = pct
        _progress["seq"] += 1

    async def progress_poller():
        """Every 0.3 s, forward the latest progress if it changed meaningfully."""
        last_sent_seq = 0
        last_sent_pct = 0.0
        while not _done:
            await asyncio.sleep(0.3)
            seq = _progress["seq"]
            pct = _progress["pct"]
            phase = _progress["phase"]
            # Throttle the "Calculating coverage" phase to >= 1% steps; an
            # update in any other phase is forwarded as soon as it is seen.
            if seq != last_sent_seq and (
                pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"
            ):
                await ws_manager.send_progress(ws, calc_id, phase, pct)
                last_sent_seq = seq
                last_sent_pct = pct

    async def stop_poller():
        """Ask the poller to exit and wait for it; a no-op if it never started."""
        nonlocal _done
        _done = True
        if poller_task is None:
            return
        try:
            await poller_task
        except Exception as exc:  # a poller failure must not mask the real outcome
            logger.debug("[WS] progress poller failed: %s", exc)

    try:
        sites_data = data.get("sites", [])
        settings_data = data.get("settings", {})
        if not sites_data:
            await ws_manager.send_error(ws, calc_id, "At least one site required")
            return
        if len(sites_data) > 10:
            await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request")
            return

        # Parse sites and settings (same format as the HTTP endpoint).
        sites = [SiteParams(**s) for s in sites_data]
        settings = CoverageSettings(**settings_data)
        if settings.radius > 50000:
            await ws_manager.send_error(ws, calc_id, "Maximum radius 50km")
            return
        if settings.resolution < 50:
            await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m")
            return

        effective_settings = apply_preset(settings.model_copy())

        # Models reported to the client. The primary model for the first
        # site's frequency/environment is prepended if it is missing.
        # Imported lazily, presumably to avoid a circular import (confirm).
        from app.api.routes.coverage import _get_active_models
        models_used = _get_active_models(effective_settings)
        env = getattr(effective_settings, 'environment', 'urban')
        primary_model = select_propagation_model(sites[0].frequency, env)
        if primary_model.name not in models_used:
            models_used.insert(0, primary_model.name)

        await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
        poller_task = asyncio.create_task(progress_poller())

        start_time = time.monotonic()  # monotonic: immune to wall-clock jumps
        try:
            if len(sites) == 1:
                # NOTE(review): the raw `settings` (not `effective_settings`)
                # is passed, while the response reports `effective_settings`.
                # Presumably the service applies the preset itself; confirm.
                calculation = coverage_service.calculate_coverage(
                    sites[0], settings, cancel_token,
                    progress_fn=sync_progress_fn,
                )
            else:
                # NOTE(review): no progress_fn here, so multi-site runs only
                # report the "Initializing" and "Finalizing" updates.
                calculation = coverage_service.calculate_multi_site_coverage(
                    sites, settings, cancel_token,
                )
            points = await asyncio.wait_for(calculation, timeout=300.0)
        except asyncio.TimeoutError:
            cancel_token.cancel()
            await stop_poller()
            # NOTE(review): the name suggests this kills the shared worker
            # processes, which may also affect other running calculations.
            from app.services.parallel_coverage_service import _kill_worker_processes
            _kill_worker_processes()
            await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
            return
        except asyncio.CancelledError:
            cancel_token.cancel()
            await stop_poller()
            await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
            # Propagate: swallowing CancelledError breaks asyncio cancellation.
            raise

        await stop_poller()
        await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98)
        computation_time = time.monotonic() - start_time

        # Response body, identical in format to the HTTP endpoint.
        result = {
            "points": [p.model_dump() for p in points],
            "count": len(points),
            "settings": effective_settings.model_dump(),
            "stats": _coverage_stats(points),
            "computation_time": round(computation_time, 2),
            "models_used": models_used,
        }
        await ws_manager.send_result(ws, calc_id, result)
    except Exception as e:
        logger.exception("[WS] Calculation error: %s", e)
        await stop_poller()
        await ws_manager.send_error(ws, calc_id, str(e))
    finally:
        # Safety net: the poller must never outlive the calculation.
        _done = True
        if poller_task is not None and not poller_task.done():
            poller_task.cancel()
        # Remove only our own token; a newer run may have reused this id.
        if ws_manager._cancel_tokens.get(calc_id) is cancel_token:
            del ws_manager._cancel_tokens[calc_id]
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection for coverage calculations with progress.

    Handled message types (field ``type``):
      * "calculate": start a background calculation keyed by the message's
        ``id`` (default "").
      * "cancel": cancel a calculation this connection started, by ``id``.
      * "ping": reply ``{"type": "pong"}``.

    When the connection ends, only the calculations started on this
    connection are cancelled; other clients' calculations are untouched.
    This covers disconnects, receive/dispatch errors and cancellation of
    this handler.
    """
    await websocket.accept()
    # Strong references to this connection's calculation tasks. The event
    # loop keeps only weak references, so an unreferenced task may be
    # garbage-collected mid-run.
    tasks: set[asyncio.Task] = set()
    # Ids started on this connection; "cancel" may only target these.
    own_calc_ids: set = set()
    try:
        while True:
            data = await websocket.receive_json()
            msg_type = data.get("type")
            if msg_type == "calculate":
                calc_id = data.get("id", "")
                own_calc_ids.add(calc_id)
                task = asyncio.create_task(_run_calculation(websocket, calc_id, data))
                tasks.add(task)
                task.add_done_callback(tasks.discard)
            elif msg_type == "cancel":
                calc_id = data.get("id")
                if calc_id in own_calc_ids:
                    token = ws_manager._cancel_tokens.get(calc_id)
                    if token:
                        token.cancel()
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("[WS] Connection handler error")
    finally:
        # Cancel this connection's tasks rather than every token in the
        # shared registry. A cancelled calculation cancels its own token.
        pending = list(tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)