@mytec: feat: Phase 3.0 Architecture Refactor ✅
Major refactoring of RFCP backend: - Modular propagation models (8 models) - SharedMemoryManager for terrain data - ProcessPoolExecutor parallel processing - WebSocket progress streaming - Building filtering pipeline (351k → 15k) - 82 unit tests Performance: Standard preset 38s → 5s (7.6x speedup) Known issue: Detailed preset timeout (fix in 3.1.0)
This commit is contained in:
261
backend/app/api/websocket.py
Normal file
261
backend/app/api/websocket.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
WebSocket handler for real-time coverage calculation with progress.
|
||||
|
||||
Uses the same coverage_service pipeline as the HTTP endpoint but sends
|
||||
progress updates during computation phases.
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.services.coverage_service import (
|
||||
coverage_service, SiteParams, CoverageSettings, apply_preset,
|
||||
select_propagation_model,
|
||||
)
|
||||
from app.services.parallel_coverage_service import CancellationToken
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Track cancellation tokens per calculation."""
|
||||
|
||||
def __init__(self):
|
||||
self._cancel_tokens: dict[str, CancellationToken] = {}
|
||||
|
||||
async def send_progress(
|
||||
self, ws: WebSocket, calc_id: str,
|
||||
phase: str, progress: float, eta: Optional[float] = None,
|
||||
):
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "progress",
|
||||
"calculation_id": calc_id,
|
||||
"phase": phase,
|
||||
"progress": min(progress, 1.0),
|
||||
"eta_seconds": eta,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "result",
|
||||
"calculation_id": calc_id,
|
||||
"data": result,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_error(self, ws: WebSocket, calc_id: str, error: str):
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "error",
|
||||
"calculation_id": calc_id,
|
||||
"message": error,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
|
||||
|
||||
async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
"""Run coverage calculation with progress updates via WebSocket."""
|
||||
cancel_token = CancellationToken()
|
||||
ws_manager._cancel_tokens[calc_id] = cancel_token
|
||||
|
||||
try:
|
||||
sites_data = data.get("sites", [])
|
||||
settings_data = data.get("settings", {})
|
||||
|
||||
if not sites_data:
|
||||
await ws_manager.send_error(ws, calc_id, "At least one site required")
|
||||
return
|
||||
|
||||
if len(sites_data) > 10:
|
||||
await ws_manager.send_error(ws, calc_id, "Maximum 10 sites per request")
|
||||
return
|
||||
|
||||
# Parse sites and settings (same format as HTTP endpoint)
|
||||
sites = [SiteParams(**s) for s in sites_data]
|
||||
settings = CoverageSettings(**settings_data)
|
||||
|
||||
if settings.radius > 50000:
|
||||
await ws_manager.send_error(ws, calc_id, "Maximum radius 50km")
|
||||
return
|
||||
if settings.resolution < 50:
|
||||
await ws_manager.send_error(ws, calc_id, "Minimum resolution 50m")
|
||||
return
|
||||
|
||||
effective_settings = apply_preset(settings.model_copy())
|
||||
|
||||
# Determine models used
|
||||
from app.api.routes.coverage import _get_active_models
|
||||
models_used = _get_active_models(effective_settings)
|
||||
env = getattr(effective_settings, 'environment', 'urban')
|
||||
primary_model = select_propagation_model(sites[0].frequency, env)
|
||||
if primary_model.name not in models_used:
|
||||
models_used.insert(0, primary_model.name)
|
||||
|
||||
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
|
||||
|
||||
# ── Bridge sync progress_fn → async WS sends ──
|
||||
# progress_fn is called from two contexts:
|
||||
# 1. Event loop thread (phases 1-2.5, directly in calculate_coverage)
|
||||
# 2. Worker threads (phase 3, from ProcessPool/sequential executors)
|
||||
# We detect which thread we're on and use the appropriate method.
|
||||
loop = asyncio.get_running_loop()
|
||||
event_loop_thread_id = threading.current_thread().ident
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback for coverage_service."""
|
||||
if threading.current_thread().ident == event_loop_thread_id:
|
||||
# From event loop thread: put directly to queue
|
||||
progress_queue.put_nowait((phase, pct))
|
||||
else:
|
||||
# From worker thread: use thread-safe bridge to wake event loop
|
||||
loop.call_soon_threadsafe(progress_queue.put_nowait, (phase, pct))
|
||||
|
||||
# Background task: drain queue and send WS progress messages
|
||||
_sender_done = False
|
||||
|
||||
async def progress_sender():
|
||||
nonlocal _sender_done
|
||||
last_pct = 0.0
|
||||
while not _sender_done:
|
||||
try:
|
||||
phase, pct = await asyncio.wait_for(progress_queue.get(), timeout=0.5)
|
||||
if pct >= 1.0:
|
||||
break
|
||||
# Throttle: only send if progress changed meaningfully
|
||||
if pct - last_pct >= 0.02 or phase != "Calculating coverage":
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_pct = pct
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
break
|
||||
|
||||
progress_task = asyncio.create_task(progress_sender())
|
||||
|
||||
# Run calculation with timeout
|
||||
start_time = time.time()
|
||||
try:
|
||||
if len(sites) == 1:
|
||||
points = await asyncio.wait_for(
|
||||
coverage_service.calculate_coverage(
|
||||
sites[0], settings, cancel_token,
|
||||
progress_fn=sync_progress_fn,
|
||||
),
|
||||
timeout=300.0,
|
||||
)
|
||||
else:
|
||||
points = await asyncio.wait_for(
|
||||
coverage_service.calculate_multi_site_coverage(
|
||||
sites, settings, cancel_token,
|
||||
),
|
||||
timeout=300.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
from app.services.parallel_coverage_service import _kill_worker_processes
|
||||
_kill_worker_processes()
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
|
||||
return
|
||||
|
||||
# Stop progress sender
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
|
||||
computation_time = time.time() - start_time
|
||||
|
||||
# Build response (identical format to HTTP endpoint)
|
||||
rsrp_values = [p.rsrp for p in points]
|
||||
los_count = sum(1 for p in points if p.has_los)
|
||||
|
||||
stats = {
|
||||
"min_rsrp": min(rsrp_values) if rsrp_values else 0,
|
||||
"max_rsrp": max(rsrp_values) if rsrp_values else 0,
|
||||
"avg_rsrp": sum(rsrp_values) / len(rsrp_values) if rsrp_values else 0,
|
||||
"los_percentage": (los_count / len(points) * 100) if points else 0,
|
||||
"points_with_buildings": sum(1 for p in points if p.building_loss > 0),
|
||||
"points_with_terrain_loss": sum(1 for p in points if p.terrain_loss > 0),
|
||||
"points_with_reflection_gain": sum(1 for p in points if p.reflection_gain > 0),
|
||||
"points_with_vegetation_loss": sum(1 for p in points if p.vegetation_loss > 0),
|
||||
"points_with_rain_loss": sum(1 for p in points if p.rain_loss > 0),
|
||||
"points_with_indoor_loss": sum(1 for p in points if p.indoor_loss > 0),
|
||||
"points_with_atmospheric_loss": sum(1 for p in points if p.atmospheric_loss > 0),
|
||||
}
|
||||
|
||||
result = {
|
||||
"points": [p.model_dump() for p in points],
|
||||
"count": len(points),
|
||||
"settings": effective_settings.model_dump(),
|
||||
"stats": stats,
|
||||
"computation_time": round(computation_time, 2),
|
||||
"models_used": models_used,
|
||||
}
|
||||
|
||||
await ws_manager.send_result(ws, calc_id, result)
|
||||
|
||||
except Exception as e:
|
||||
# Stop progress sender on unhandled exception
|
||||
_sender_done = True
|
||||
try:
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await progress_task
|
||||
except Exception:
|
||||
pass
|
||||
await ws_manager.send_error(ws, calc_id, str(e))
|
||||
finally:
|
||||
ws_manager._cancel_tokens.pop(calc_id, None)
|
||||
|
||||
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""WebSocket endpoint for coverage calculations with progress."""
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "calculate":
|
||||
calc_id = data.get("id", "")
|
||||
asyncio.create_task(_run_calculation(websocket, calc_id, data))
|
||||
|
||||
elif msg_type == "cancel":
|
||||
calc_id = data.get("id")
|
||||
token = ws_manager._cancel_tokens.get(calc_id)
|
||||
if token:
|
||||
token.cancel()
|
||||
|
||||
elif msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
for token in ws_manager._cancel_tokens.values():
|
||||
token.cancel()
|
||||
except Exception:
|
||||
for token in ws_manager._cancel_tokens.values():
|
||||
token.cancel()
|
||||
Reference in New Issue
Block a user