# File: rfcp/backend/app/services/parallel_coverage_service.py
# (Python, 251 lines, 8.2 KiB)
"""
Parallel coverage calculation using ProcessPoolExecutor.
Workers receive pre-loaded terrain cache, buildings, and OSM data
via a shared pickle file. Each worker initializes module-level
service singletons with the cached data, then processes point chunks.
Usage:
from app.services.parallel_coverage_service import calculate_coverage_parallel
"""
import os
import sys
import time
import pickle
import tempfile
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from typing import List, Dict, Tuple, Any, Optional, Callable
import numpy as np
# ── Module-level worker state (set once per process by _init_worker) ──
# _worker_data holds the unpickled shared payload (buildings, streets,
# water/vegetation geometry, spatial index, site/settings dicts, site
# elevation) that _process_chunk reads on every call.
_worker_data: Dict[str, Any] = {}
# Guard so _init_worker's body runs at most once per worker process.
_worker_initialized: bool = False
def _init_worker(shared_data_path: str):
    """Initialize a worker process with shared data from temp file.

    Injects the terrain cache into the module-level terrain_service
    singleton so that every service holding a reference to it (LOS,
    dominant path, etc.) automatically sees the cached tiles.
    """
    global _worker_data, _worker_initialized
    if _worker_initialized:
        return

    start = time.time()
    worker_pid = os.getpid()

    # Load the shared payload written once by the parent process.
    with open(shared_data_path, 'rb') as fh:
        payload = pickle.load(fh)

    # Inject terrain cache into the global singleton — this automatically
    # fixes los_service, dominant_path_service, etc. because they hold
    # references to the same terrain_service object.
    from app.services.terrain_service import terrain_service
    terrain_service._tile_cache = payload['terrain_cache']

    # Build a spatial index over the building footprints (if any).
    from app.services.spatial_index import SpatialIndex
    index = SpatialIndex()
    if payload['buildings']:
        index.build(payload['buildings'])

    _worker_data = {
        'buildings': payload['buildings'],
        'streets': payload['streets'],
        'water_bodies': payload['water_bodies'],
        'vegetation_areas': payload['vegetation_areas'],
        'spatial_idx': index,
        'site_dict': payload['site_dict'],
        'settings_dict': payload['settings_dict'],
        'site_elevation': payload['site_elevation'],
    }
    _worker_initialized = True

    elapsed = time.time() - start
    print(f"[WORKER {worker_pid}] Initialized in {elapsed:.1f}s — "
          f"{len(payload['terrain_cache'])} tiles, "
          f"{len(payload['buildings'])} buildings, "
          f"{len(payload.get('vegetation_areas', []))} vegetation",
          flush=True)
def _process_chunk(chunk: List[Tuple[float, float, float]]) -> List[Dict]:
    """Process a chunk of (lat, lon, point_elevation) tuples.

    Returns list of CoveragePoint dicts for points above min_signal.
    """
    from app.services.coverage_service import CoverageService, SiteParams, CoverageSettings

    shared = _worker_data
    site_params = SiteParams(**shared['site_dict'])
    cov_settings = CoverageSettings(**shared['settings_dict'])
    service = CoverageService()

    # Per-stage timing accumulator, passed by reference to the point solver.
    timing = {stage: 0.0 for stage in (
        "los", "buildings", "antenna", "dominant_path",
        "street_canyon", "reflection", "vegetation",
    )}

    kept: List[Dict] = []
    for lat, lon, elev in chunk:
        pt = service._calculate_point_sync(
            site_params, lat, lon, cov_settings,
            shared['buildings'], shared['streets'],
            shared['spatial_idx'], shared['water_bodies'],
            shared['vegetation_areas'],
            shared['site_elevation'], elev, timing,
        )
        # Keep only points at or above the configured signal floor.
        if pt.rsrp >= cov_settings.min_signal:
            kept.append(pt.model_dump())
    return kept
# ── Public API ──
def get_cpu_count() -> int:
    """Get number of usable CPU cores, capped at 14."""
    try:
        cores = mp.cpu_count() or 4
    except Exception:
        # Platforms where the core count cannot be determined.
        return 4
    return cores if cores < 14 else 14
def calculate_coverage_parallel(
    grid: List[Tuple[float, float]],
    point_elevations: Dict[Tuple[float, float], float],
    site_dict: Dict,
    settings_dict: Dict,
    terrain_cache: Dict[str, np.ndarray],
    buildings: List,
    streets: List,
    water_bodies: List,
    vegetation_areas: List,
    site_elevation: float,
    num_workers: Optional[int] = None,
    log_fn: Optional[Callable[[str], None]] = None,
) -> Tuple[List[Dict], Dict[str, float]]:
    """Calculate coverage points in parallel using ProcessPoolExecutor.

    Args:
        grid: List of (lat, lon) tuples.
        point_elevations: Pre-computed {(lat, lon): elevation} dict.
        site_dict: SiteParams as a dict (for pickling).
        settings_dict: CoverageSettings as a dict (for pickling).
        terrain_cache: {tile_name: np.ndarray} — pre-loaded SRTM tiles.
        buildings, streets, water_bodies, vegetation_areas: OSM data.
        site_elevation: Elevation at site location (meters).
        num_workers: Override worker count (default: auto-detect).
        log_fn: Logging function (receives string messages).

    Returns:
        (results, timing) where results is list of CoveragePoint dicts.
        Failed chunks are logged and skipped (best-effort), so results
        may cover fewer points than the full grid.
    """
    if log_fn is None:
        # Default logger: prefixed stdout (def, not lambda-assignment — E731).
        def log_fn(msg: str) -> None:
            print(f"[PARALLEL] {msg}", flush=True)
    if num_workers is None:
        num_workers = get_cpu_count()

    total_points = len(grid)
    log_fn(f"Parallel mode: {total_points} points, {num_workers} workers")

    # Pair each grid point with its pre-computed elevation (0.0 if absent).
    items = [
        (lat, lon, point_elevations.get((lat, lon), 0.0))
        for lat, lon in grid
    ]

    # Split into chunks — ~4 chunks per worker for granular progress.
    chunks_per_worker = 4
    chunk_size = max(1, len(items) // (num_workers * chunks_per_worker))
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    log_fn(f"Split into {len(chunks)} chunks of ~{chunk_size} points")

    # ── Serialize shared data to temp file (once, not per-worker) ──
    t_serial = time.time()
    tmpfile = tempfile.NamedTemporaryFile(delete=False, suffix='.pkl')
    shared_data_path = tmpfile.name

    all_results: List[Dict] = []
    # The outer try/finally guarantees temp-file cleanup even when
    # serialization itself raises (the original only cleaned up around
    # the executor phase, leaking the file on a pickling failure).
    try:
        shared_data = {
            'terrain_cache': terrain_cache,
            'buildings': buildings,
            'streets': streets,
            'water_bodies': water_bodies,
            'vegetation_areas': vegetation_areas,
            'site_dict': site_dict,
            'settings_dict': settings_dict,
            'site_elevation': site_elevation,
        }
        try:
            pickle.dump(shared_data, tmpfile, protocol=pickle.HIGHEST_PROTOCOL)
        finally:
            tmpfile.close()
        file_size_mb = os.path.getsize(shared_data_path) / (1024 * 1024)
        serial_time = time.time() - t_serial
        log_fn(f"Serialized shared data: {file_size_mb:.1f}MB in {serial_time:.1f}s")
        # Drop the duplicate reference; the payload now lives on disk and
        # the caller still owns the original objects.
        del shared_data

        # ── Run in process pool ──
        t_calc = time.time()
        completed_points = 0
        with ProcessPoolExecutor(
            max_workers=num_workers,
            initializer=_init_worker,
            initargs=(shared_data_path,),
        ) as executor:
            futures = [executor.submit(_process_chunk, chunk) for chunk in chunks]
            # Collect in submission order: result ordering stays deterministic
            # and progress advances monotonically through the grid.
            for i, future in enumerate(futures):
                try:
                    chunk_results = future.result(timeout=600)  # 10 min max per chunk
                    all_results.extend(chunk_results)
                except Exception as e:
                    # Best-effort: log the failed chunk and keep going.
                    log_fn(f"Chunk {i} failed: {e}")
                completed_points += len(chunks[i])
                pct = min(100, completed_points * 100 // total_points)
                elapsed = time.time() - t_calc
                rate = completed_points / elapsed if elapsed > 0 else 0
                # Log every ~10% or on last chunk
                if (i + 1) % max(1, len(chunks) // 10) == 0 or i == len(chunks) - 1:
                    eta = (total_points - completed_points) / rate if rate > 0 else 0
                    log_fn(f"Progress: {completed_points}/{total_points} ({pct}%) — "
                           f"{rate:.0f} pts/s, ETA {eta:.0f}s")
    finally:
        # Clean up temp file; only ignore filesystem errors, not everything.
        try:
            os.unlink(shared_data_path)
        except OSError:
            pass

    calc_time = time.time() - t_calc
    log_fn(f"Parallel done: {calc_time:.1f}s, {len(all_results)} results "
           f"({calc_time / max(1, total_points) * 1000:.1f}ms/point)")
    timing = {
        "parallel_total": calc_time,
        "serialize": serial_time,
        "workers": num_workers,
    }
    return all_results, timing