"""
|
|
Parallel coverage calculation using ProcessPoolExecutor.
|
|
|
|
Workers receive pre-loaded terrain cache, buildings, and OSM data
|
|
via a shared pickle file. Each worker initializes module-level
|
|
service singletons with the cached data, then processes point chunks.
|
|
|
|
Usage:
|
|
from app.services.parallel_coverage_service import calculate_coverage_parallel
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import pickle
|
|
import tempfile
|
|
import multiprocessing as mp
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from typing import List, Dict, Tuple, Any, Optional, Callable
|
|
import numpy as np
|
|
|
|
|
|
# ── Module-level worker state (set once per process by _init_worker) ──

# Shared payload for this worker process (buildings, streets, spatial index,
# site/settings dicts, site elevation). Populated once by _init_worker().
_worker_data: Dict[str, Any] = {}
# Guard flag so _init_worker() is a no-op if called again in the same process.
_worker_initialized = False
|
|
|
|
|
|
def _init_worker(shared_data_path: str):
    """Initialize a worker process with shared data from a temp file.

    Loads the pickled payload written by calculate_coverage_parallel(),
    injects the terrain cache into the module-level terrain_service
    singleton so that all other services (LOS, dominant path, etc.)
    automatically see the cached tiles, builds the building spatial
    index, and stores everything in the module-level _worker_data dict.

    Idempotent: a repeat call in an already-initialized process returns
    immediately.

    Args:
        shared_data_path: Path to the pickle file with the shared payload.
    """
    global _worker_data, _worker_initialized

    if _worker_initialized:
        return

    t0 = time.time()
    pid = os.getpid()

    # Load shared data. NOTE: trusted pickle written by the parent process —
    # never point this at untrusted input.
    with open(shared_data_path, 'rb') as f:
        data = pickle.load(f)

    # Inject terrain cache into the global singleton —
    # this automatically fixes los_service, dominant_path_service, etc.
    # because they hold references to the same terrain_service object.
    from app.services.terrain_service import terrain_service
    terrain_service._tile_cache = data['terrain_cache']

    # Read optional OSM layers defensively (default to empty lists) so the
    # lookups are consistent with the defensive .get() used in the log line
    # below — a payload missing a layer no longer raises KeyError.
    buildings = data.get('buildings', [])
    streets = data.get('streets', [])
    water_bodies = data.get('water_bodies', [])
    vegetation_areas = data.get('vegetation_areas', [])

    # Build spatial index from buildings
    from app.services.spatial_index import SpatialIndex
    spatial_idx = SpatialIndex()
    if buildings:
        spatial_idx.build(buildings)

    _worker_data = {
        'buildings': buildings,
        'streets': streets,
        'water_bodies': water_bodies,
        'vegetation_areas': vegetation_areas,
        'spatial_idx': spatial_idx,
        'site_dict': data['site_dict'],
        'settings_dict': data['settings_dict'],
        'site_elevation': data['site_elevation'],
    }

    _worker_initialized = True
    dt = time.time() - t0
    print(f"[WORKER {pid}] Initialized in {dt:.1f}s — "
          f"{len(data['terrain_cache'])} tiles, "
          f"{len(buildings)} buildings, "
          f"{len(vegetation_areas)} vegetation",
          flush=True)
|
|
|
|
|
|
def _process_chunk(chunk: List[Tuple[float, float, float]]) -> List[Dict]:
    """Process a chunk of (lat, lon, point_elevation) tuples.

    Returns list of CoveragePoint dicts for points above min_signal.
    """
    from app.services.coverage_service import CoverageService, SiteParams, CoverageSettings

    shared = _worker_data
    site = SiteParams(**shared['site_dict'])
    settings = CoverageSettings(**shared['settings_dict'])
    svc = CoverageService()

    # Per-phase timing accumulator; mutated in place by the point calculator.
    timing = {
        "los": 0.0, "buildings": 0.0, "antenna": 0.0,
        "dominant_path": 0.0, "street_canyon": 0.0,
        "reflection": 0.0, "vegetation": 0.0,
    }

    kept: List[Dict] = []
    for lat, lon, point_elev in chunk:
        point = svc._calculate_point_sync(
            site, lat, lon, settings,
            shared['buildings'], shared['streets'],
            shared['spatial_idx'], shared['water_bodies'],
            shared['vegetation_areas'],
            shared['site_elevation'], point_elev, timing,
        )
        # Keep only points at or above the configured signal floor.
        if point.rsrp >= settings.min_signal:
            kept.append(point.model_dump())

    return kept
|
|
|
|
|
|
# ── Public API ──
|
|
|
|
|
|
def get_cpu_count() -> int:
    """Get number of usable CPU cores, capped at 14."""
    try:
        detected = mp.cpu_count() or 4
    except Exception:
        # cpu_count() can fail on exotic platforms — fall back to a sane default.
        return 4
    return min(detected, 14)
|
|
|
|
|
|
def calculate_coverage_parallel(
    grid: List[Tuple[float, float]],
    point_elevations: Dict[Tuple[float, float], float],
    site_dict: Dict,
    settings_dict: Dict,
    terrain_cache: Dict[str, np.ndarray],
    buildings: List,
    streets: List,
    water_bodies: List,
    vegetation_areas: List,
    site_elevation: float,
    num_workers: Optional[int] = None,
    log_fn: Optional[Callable[[str], None]] = None,
) -> Tuple[List[Dict], Dict[str, float]]:
    """Calculate coverage points in parallel using ProcessPoolExecutor.

    Shared data (terrain tiles, OSM layers, site/settings) is pickled once
    to a temp file; each worker loads it in its initializer rather than
    receiving a copy per task.

    Args:
        grid: List of (lat, lon) tuples.
        point_elevations: Pre-computed {(lat, lon): elevation} dict.
        site_dict: SiteParams as a dict (for pickling).
        settings_dict: CoverageSettings as a dict (for pickling).
        terrain_cache: {tile_name: np.ndarray} — pre-loaded SRTM tiles.
        buildings, streets, water_bodies, vegetation_areas: OSM data.
        site_elevation: Elevation at site location (meters).
        num_workers: Override worker count (default: auto-detect).
        log_fn: Logging function (receives string messages).

    Returns:
        (results, timing) where results is list of CoveragePoint dicts.
        Points from a failed chunk are dropped (logged, not raised).
    """
    if log_fn is None:
        def log_fn(msg: str) -> None:
            print(f"[PARALLEL] {msg}", flush=True)

    if num_workers is None:
        num_workers = get_cpu_count()

    total_points = len(grid)
    log_fn(f"Parallel mode: {total_points} points, {num_workers} workers")

    # Prepare items with pre-computed elevations (0.0 when a grid point is
    # missing from the elevation dict).
    items = [
        (lat, lon, point_elevations.get((lat, lon), 0.0))
        for lat, lon in grid
    ]

    # Split into chunks — ~4 chunks per worker for granular progress
    chunks_per_worker = 4
    chunk_size = max(1, len(items) // (num_workers * chunks_per_worker))
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    log_fn(f"Split into {len(chunks)} chunks of ~{chunk_size} points")

    # ── Serialize shared data to temp file (once, not per-worker) ──
    t_serial = time.time()
    shared_data = {
        'terrain_cache': terrain_cache,
        'buildings': buildings,
        'streets': streets,
        'water_bodies': water_bodies,
        'vegetation_areas': vegetation_areas,
        'site_dict': site_dict,
        'settings_dict': settings_dict,
        'site_elevation': site_elevation,
    }

    tmpfile = tempfile.NamedTemporaryFile(delete=False, suffix='.pkl')
    shared_data_path = tmpfile.name
    try:
        try:
            pickle.dump(shared_data, tmpfile, protocol=pickle.HIGHEST_PROTOCOL)
        finally:
            tmpfile.close()
    except Exception:
        # Serialization failed — remove the (delete=False) temp file before
        # propagating, otherwise a potentially huge file is leaked.
        try:
            os.unlink(shared_data_path)
        except OSError:
            pass
        raise

    file_size_mb = os.path.getsize(shared_data_path) / (1024 * 1024)
    serial_time = time.time() - t_serial
    log_fn(f"Serialized shared data: {file_size_mb:.1f}MB in {serial_time:.1f}s")

    # Free main-process memory for the duplicate
    del shared_data

    # ── Run in process pool ──
    t_calc = time.time()
    all_results: List[Dict] = []
    completed_points = 0

    try:
        with ProcessPoolExecutor(
            max_workers=num_workers,
            initializer=_init_worker,
            initargs=(shared_data_path,),
        ) as executor:
            futures = [executor.submit(_process_chunk, chunk) for chunk in chunks]

            for i, future in enumerate(futures):
                try:
                    chunk_results = future.result(timeout=600)  # 10 min max per chunk
                    all_results.extend(chunk_results)
                except Exception as e:
                    # Best-effort: a failed chunk loses its points but the
                    # rest of the grid still completes.
                    log_fn(f"Chunk {i} failed: {e}")

                # Count failed chunks as "completed" so progress/ETA advance.
                completed_points += len(chunks[i])
                pct = min(100, completed_points * 100 // total_points)
                elapsed = time.time() - t_calc
                rate = completed_points / elapsed if elapsed > 0 else 0

                # Log every ~10% or on last chunk
                if (i + 1) % max(1, len(chunks) // 10) == 0 or i == len(chunks) - 1:
                    eta = (total_points - completed_points) / rate if rate > 0 else 0
                    log_fn(f"Progress: {completed_points}/{total_points} ({pct}%) — "
                           f"{rate:.0f} pts/s, ETA {eta:.0f}s")

    finally:
        # Clean up temp file
        try:
            os.unlink(shared_data_path)
        except OSError:
            pass

    calc_time = time.time() - t_calc
    log_fn(f"Parallel done: {calc_time:.1f}s, {len(all_results)} results "
           f"({calc_time / max(1, total_points) * 1000:.1f}ms/point)")

    timing = {
        "parallel_total": calc_time,
        "serialize": serial_time,
        "workers": num_workers,
    }
    return all_results, timing
|