mytec: after methods
This commit is contained in:
@@ -132,6 +132,59 @@ async def calculate_coverage(request: CoverageRequest) -> CoverageResponse:
|
||||
)
|
||||
|
||||
|
||||
@router.post("/preview")
async def calculate_preview(request: CoverageRequest) -> CoverageResponse:
    """Fast radial preview using terrain-only along 360 spokes.

    Returns coverage points much faster than full calculation
    by skipping building/OSM data and using radial spokes instead of grid.
    """
    # Guard clause: an empty site list cannot produce a preview.
    if not request.sites:
        raise HTTPException(400, "At least one site required")

    primary_site = request.sites[0]
    # Work on a copy so the request payload itself is never mutated.
    effective = apply_preset(request.settings.model_copy())

    # Model selection is reported to the client; the service derives its
    # own model the same way internally.
    environment = getattr(effective, 'environment', 'urban')
    model = select_propagation_model(primary_site.frequency, environment)

    t0 = time.time()
    try:
        preview_points = await asyncio.wait_for(
            coverage_service.calculate_radial_preview(
                primary_site, request.settings,
            ),
            timeout=30.0,
        )
    except asyncio.TimeoutError:
        raise HTTPException(408, "Preview timeout (30s)")
    elapsed = time.time() - t0

    # Summary statistics over the returned points; all-zero defaults when
    # nothing cleared the signal threshold.
    levels = [pt.rsrp for pt in preview_points]
    if levels:
        visible = sum(1 for pt in preview_points if pt.has_los)
        stats = {
            "min_rsrp": min(levels),
            "max_rsrp": max(levels),
            "avg_rsrp": sum(levels) / len(levels),
            "los_percentage": visible / len(preview_points) * 100,
            "mode": "radial_preview",
        }
    else:
        stats = {
            "min_rsrp": 0,
            "max_rsrp": 0,
            "avg_rsrp": 0,
            "los_percentage": 0,
            "mode": "radial_preview",
        }

    return CoverageResponse(
        points=preview_points,
        count=len(preview_points),
        settings=effective,
        stats=stats,
        computation_time=round(elapsed, 2),
        models_used=["terrain_los", model.name],
    )
|
||||
|
||||
|
||||
@router.get("/presets")
|
||||
async def get_presets():
|
||||
"""Get available propagation model presets"""
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import multiprocessing as mp
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Router for the system endpoints defined below (/info, /diagnostics, ...).
router = APIRouter()

# Valid SRTM tile sizes (bytes). A .hgt tile is a square grid of 16-bit
# elevation samples, so size = side * side * 2; any other byte count means
# a truncated or corrupt download.
_SRTM1_SIZE = 3601 * 3601 * 2 # SRTM1 (1 arc-second grid): 25,934,402
_SRTM3_SIZE = 1201 * 1201 * 2 # SRTM3 (3 arc-second grid): 2,884,802
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_system_info():
|
||||
@@ -72,3 +78,108 @@ async def shutdown():
|
||||
loop.call_later(3.0, lambda: os._exit(0))
|
||||
|
||||
return {"status": "shutting down", "workers_killed": killed}
|
||||
|
||||
|
||||
@router.get("/diagnostics")
async def get_diagnostics():
    """Validate terrain tiles and OSM cache files.

    Checks:
    - Terrain .hgt files: must be exactly SRTM1 or SRTM3 size
    - OSM cache .json files: must be valid JSON with expected structure
    - Cache manager stats (memory + disk)

    Returns:
        Dict with ``terrain``, ``osm_cache`` and ``cache_manager`` sections;
        each scan section lists valid entries and per-file errors.
    """
    data_path = Path(os.environ.get('RFCP_DATA_PATH', './data'))
    terrain_path = data_path / 'terrain'
    osm_dirs = [
        data_path / 'osm' / 'buildings',
        data_path / 'osm' / 'streets',
        data_path / 'osm' / 'vegetation',
        data_path / 'osm' / 'water',
    ]

    # --- Terrain tiles ---
    terrain_tiles = []
    terrain_errors = []
    total_terrain_bytes = 0
    # Snapshot existence once; reused in the response instead of a second
    # filesystem call that could disagree with the scan.
    terrain_exists = terrain_path.exists()

    if terrain_exists:
        for hgt in sorted(terrain_path.glob("*.hgt")):
            size = hgt.stat().st_size
            total_terrain_bytes += size
            # A tile is valid only at exactly one of the two SRTM sizes.
            if size == _SRTM1_SIZE:
                terrain_tiles.append({"name": hgt.name, "type": "SRTM1", "size": size})
            elif size == _SRTM3_SIZE:
                terrain_tiles.append({"name": hgt.name, "type": "SRTM3", "size": size})
            else:
                terrain_errors.append({
                    "name": hgt.name,
                    "size": size,
                    "error": f"Invalid size (expected {_SRTM1_SIZE} or {_SRTM3_SIZE})",
                })

    # --- OSM cache ---
    osm_files = []
    osm_errors = []
    total_osm_bytes = 0

    for osm_dir in osm_dirs:
        if not osm_dir.exists():
            continue
        category = osm_dir.name
        for jf in sorted(osm_dir.glob("*.json")):
            fsize = jf.stat().st_size
            total_osm_bytes += fsize
            try:
                data = json.loads(jf.read_text())
            except json.JSONDecodeError as e:
                osm_errors.append({
                    "name": jf.name,
                    "category": category,
                    "size": fsize,
                    "error": f"Invalid JSON: {e}",
                })
                continue
            # Bug fix: a file holding *valid* JSON whose top level is not an
            # object (e.g. a bare number, bool or null) previously raised
            # TypeError on the `in` membership checks below and 500'd the
            # endpoint. Report it as a diagnostic error instead.
            if not isinstance(data, dict):
                osm_errors.append({
                    "name": jf.name,
                    "category": category,
                    "size": fsize,
                    "error": "Top-level JSON is not an object",
                })
                continue
            # Accept either the long or the short key scheme.
            has_timestamp = '_cached_at' in data or '_ts' in data
            has_data = 'data' in data or 'v' in data
            if has_timestamp and has_data:
                osm_files.append({
                    "name": jf.name,
                    "category": category,
                    "size": fsize,
                    "valid": True,
                })
            else:
                osm_errors.append({
                    "name": jf.name,
                    "category": category,
                    "size": fsize,
                    "error": "Missing expected keys (_cached_at/data or _ts/v)",
                })

    # --- Cache manager stats ---
    # Best-effort: diagnostics remain useful even if the cache service
    # cannot be imported or queried.
    try:
        from app.services.cache import cache_manager
        cache_stats = cache_manager.stats()
    except Exception:
        cache_stats = None

    return {
        "data_path": str(data_path),
        "terrain": {
            "path": str(terrain_path),
            "exists": terrain_exists,
            "tile_count": len(terrain_tiles),
            "error_count": len(terrain_errors),
            "total_mb": round(total_terrain_bytes / (1024 * 1024), 1),
            "tiles": terrain_tiles,
            "errors": terrain_errors,
        },
        "osm_cache": {
            "valid_count": len(osm_files),
            "error_count": len(osm_errors),
            "total_mb": round(total_osm_bytes / (1024 * 1024), 1),
            "files": osm_files,
            "errors": osm_errors,
        },
        "cache_manager": cache_stats,
    }
|
||||
|
||||
@@ -7,6 +7,7 @@ progress updates during computation phases.
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
@@ -18,6 +19,8 @@ from app.services.coverage_service import (
|
||||
)
|
||||
from app.services.parallel_coverage_service import CancellationToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Track cancellation tokens per calculation."""
|
||||
@@ -37,8 +40,8 @@ class ConnectionManager:
|
||||
"progress": min(progress, 1.0),
|
||||
"eta_seconds": eta,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_progress failed: {e}")
|
||||
|
||||
async def send_result(self, ws: WebSocket, calc_id: str, result: dict):
|
||||
try:
|
||||
@@ -47,8 +50,8 @@ class ConnectionManager:
|
||||
"calculation_id": calc_id,
|
||||
"data": result,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_result failed: {e}")
|
||||
|
||||
async def send_error(self, ws: WebSocket, calc_id: str, error: str):
|
||||
try:
|
||||
@@ -57,8 +60,8 @@ class ConnectionManager:
|
||||
"calculation_id": calc_id,
|
||||
"message": error,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"[WS] send_error failed: {e}")
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
@@ -69,6 +72,17 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
cancel_token = CancellationToken()
|
||||
ws_manager._cancel_tokens[calc_id] = cancel_token
|
||||
|
||||
# Shared progress state — written by worker threads, polled by event loop.
|
||||
# Python GIL makes dict value assignment atomic for simple types.
|
||||
_progress = {"phase": "Initializing", "pct": 0.05, "seq": 0}
|
||||
_done = False
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback — just updates a shared dict."""
|
||||
_progress["phase"] = phase
|
||||
_progress["pct"] = pct
|
||||
_progress["seq"] += 1
|
||||
|
||||
try:
|
||||
sites_data = data.get("sites", [])
|
||||
settings_data = data.get("settings", {})
|
||||
@@ -104,45 +118,21 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
|
||||
await ws_manager.send_progress(ws, calc_id, "Initializing", 0.05)
|
||||
|
||||
# ── Bridge sync progress_fn → async WS sends ──
|
||||
# progress_fn is called from two contexts:
|
||||
# 1. Event loop thread (phases 1-2.5, directly in calculate_coverage)
|
||||
# 2. Worker threads (phase 3, from ProcessPool/sequential executors)
|
||||
# We detect which thread we're on and use the appropriate method.
|
||||
loop = asyncio.get_running_loop()
|
||||
event_loop_thread_id = threading.current_thread().ident
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
# ── Progress poller: reads shared dict and sends WS updates ──
|
||||
async def progress_poller():
|
||||
last_sent_seq = 0
|
||||
last_sent_pct = 0.0
|
||||
while not _done:
|
||||
await asyncio.sleep(0.3)
|
||||
seq = _progress["seq"]
|
||||
pct = _progress["pct"]
|
||||
phase = _progress["phase"]
|
||||
if seq != last_sent_seq and (pct - last_sent_pct >= 0.01 or phase != "Calculating coverage"):
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_sent_seq = seq
|
||||
last_sent_pct = pct
|
||||
|
||||
def sync_progress_fn(phase: str, pct: float, _eta: Optional[float] = None):
|
||||
"""Thread-safe progress callback for coverage_service."""
|
||||
if threading.current_thread().ident == event_loop_thread_id:
|
||||
# From event loop thread: put directly to queue
|
||||
progress_queue.put_nowait((phase, pct))
|
||||
else:
|
||||
# From worker thread: use thread-safe bridge to wake event loop
|
||||
loop.call_soon_threadsafe(progress_queue.put_nowait, (phase, pct))
|
||||
|
||||
# Background task: drain queue and send WS progress messages
|
||||
_sender_done = False
|
||||
|
||||
async def progress_sender():
|
||||
nonlocal _sender_done
|
||||
last_pct = 0.0
|
||||
while not _sender_done:
|
||||
try:
|
||||
phase, pct = await asyncio.wait_for(progress_queue.get(), timeout=0.5)
|
||||
if pct >= 1.0:
|
||||
break
|
||||
# Throttle: only send if progress changed meaningfully
|
||||
if pct - last_pct >= 0.02 or phase != "Calculating coverage":
|
||||
await ws_manager.send_progress(ws, calc_id, phase, pct)
|
||||
last_pct = pct
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
break
|
||||
|
||||
progress_task = asyncio.create_task(progress_sender())
|
||||
poller_task = asyncio.create_task(progress_poller())
|
||||
|
||||
# Run calculation with timeout
|
||||
start_time = time.time()
|
||||
@@ -164,25 +154,23 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
_done = True
|
||||
await poller_task
|
||||
from app.services.parallel_coverage_service import _kill_worker_processes
|
||||
_kill_worker_processes()
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation timeout (5 min)")
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
cancel_token.cancel()
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
_done = True
|
||||
await poller_task
|
||||
await ws_manager.send_error(ws, calc_id, "Calculation cancelled")
|
||||
return
|
||||
|
||||
# Stop progress sender
|
||||
_sender_done = True
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
await progress_task
|
||||
# Stop poller and send final progress
|
||||
_done = True
|
||||
await poller_task
|
||||
await ws_manager.send_progress(ws, calc_id, "Finalizing", 0.98)
|
||||
|
||||
computation_time = time.time() - start_time
|
||||
|
||||
@@ -216,14 +204,10 @@ async def _run_calculation(ws: WebSocket, calc_id: str, data: dict):
|
||||
await ws_manager.send_result(ws, calc_id, result)
|
||||
|
||||
except Exception as e:
|
||||
# Stop progress sender on unhandled exception
|
||||
_sender_done = True
|
||||
logger.error(f"[WS] Calculation error: {e}", exc_info=True)
|
||||
_done = True
|
||||
try:
|
||||
progress_queue.put_nowait(("done", 1.0))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await progress_task
|
||||
await poller_task
|
||||
except Exception:
|
||||
pass
|
||||
await ws_manager.send_error(ws, calc_id, str(e))
|
||||
|
||||
@@ -678,33 +678,65 @@ class CoverageService:
|
||||
|
||||
return list(point_map.values())
|
||||
|
||||
# Adaptive resolution zone boundaries (meters)
|
||||
_ADAPTIVE_ZONES = [
|
||||
(0, 2000), # Inner: full user resolution
|
||||
(2000, 5000), # Middle: at least 300m
|
||||
(5000, float('inf')), # Outer: at least 500m
|
||||
]
|
||||
_ADAPTIVE_MIN_RES = [None, 300, 500] # Minimum resolution per zone
|
||||
|
||||
def _generate_grid(
|
||||
self,
|
||||
center_lat: float, center_lon: float,
|
||||
radius: float, resolution: float
|
||||
) -> List[Tuple[float, float]]:
|
||||
"""Generate coverage grid points"""
|
||||
"""Generate coverage grid with adaptive resolution.
|
||||
|
||||
Close to TX (<2km): user's chosen resolution (details matter).
|
||||
Mid-range (2-5km): at least 300m resolution.
|
||||
Far (>5km): at least 500m resolution.
|
||||
|
||||
For small radii or coarse base resolution, this degenerates to a
|
||||
uniform grid (no zones exceed their minimum).
|
||||
"""
|
||||
cos_lat = np.cos(np.radians(center_lat))
|
||||
seen = set()
|
||||
points = []
|
||||
|
||||
# Convert resolution to degrees
|
||||
lat_step = resolution / 111000
|
||||
lon_step = resolution / (111000 * np.cos(np.radians(center_lat)))
|
||||
for zone_idx, (zone_min_m, zone_max_m) in enumerate(self._ADAPTIVE_ZONES):
|
||||
if zone_min_m >= radius:
|
||||
break # No points in this zone
|
||||
|
||||
# Calculate grid bounds
|
||||
lat_delta = radius / 111000
|
||||
lon_delta = radius / (111000 * np.cos(np.radians(center_lat)))
|
||||
zone_max_m = min(zone_max_m, radius)
|
||||
min_res = self._ADAPTIVE_MIN_RES[zone_idx]
|
||||
zone_res = max(resolution, min_res) if min_res else resolution
|
||||
|
||||
lat = center_lat - lat_delta
|
||||
while lat <= center_lat + lat_delta:
|
||||
lon = center_lon - lon_delta
|
||||
while lon <= center_lon + lon_delta:
|
||||
# Check if within radius (circular, not square)
|
||||
dist = TerrainService.haversine_distance(center_lat, center_lon, lat, lon)
|
||||
if dist <= radius:
|
||||
points.append((lat, lon))
|
||||
lon += lon_step
|
||||
lat += lat_step
|
||||
lat_step = zone_res / 111000
|
||||
lon_step = zone_res / (111000 * cos_lat)
|
||||
|
||||
# Grid bounds for this annular ring (with small overlap at boundaries)
|
||||
lat_delta = zone_max_m / 111000
|
||||
lon_delta = zone_max_m / (111000 * cos_lat)
|
||||
|
||||
lat = center_lat - lat_delta
|
||||
while lat <= center_lat + lat_delta:
|
||||
lon = center_lon - lon_delta
|
||||
while lon <= center_lon + lon_delta:
|
||||
dist = TerrainService.haversine_distance(
|
||||
center_lat, center_lon, lat, lon
|
||||
)
|
||||
if zone_min_m <= dist <= zone_max_m:
|
||||
# Round to avoid floating-point duplicates at zone boundaries
|
||||
key = (round(lat, 7), round(lon, 7))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
points.append(key)
|
||||
lon += lon_step
|
||||
lat += lat_step
|
||||
|
||||
_clog(f"Adaptive grid: {len(points)} points "
|
||||
f"(radius={radius:.0f}m, base_res={resolution:.0f}m)")
|
||||
return points
|
||||
|
||||
def _run_point_loop(
|
||||
@@ -1051,6 +1083,112 @@ class CoverageService:
|
||||
"""Knife-edge diffraction loss using ITU-R P.526 model."""
|
||||
return _DIFFRACTION_MODEL.calculate_clearance_loss(clearance, frequency)
|
||||
|
||||
async def calculate_radial_preview(
    self,
    site: SiteParams,
    settings: CoverageSettings,
    num_spokes: int = 360,
    points_per_spoke: int = 50,
) -> List[CoveragePoint]:
    """Fast radial preview using terrain-only along ``num_spokes`` spokes.

    Much faster than full grid because:
    - No OSM data fetch (no buildings/vegetation/water)
    - Terrain profile reused per spoke
    - Fewer total points at long range

    Args:
        site: Transmitter parameters (lat/lon/height/power/gain/frequency,
            plus optional azimuth/beamwidth for a directional pattern).
        settings: Coverage settings; a preset is applied on entry.
        num_spokes: Number of radial bearings, one per degree by default.
        points_per_spoke: Evenly spaced samples along each spoke out to
            ``settings.radius``.

    Returns:
        Coverage points whose RSRP meets ``settings.min_signal``.
    """
    calc_start = time.time()
    settings = apply_preset(settings)

    # Pre-load terrain tiles for the bounding box of the preview circle.
    lat_delta = settings.radius / 111000
    cos_lat = np.cos(np.radians(site.lat))
    lon_delta = settings.radius / (111000 * cos_lat)
    min_lat = site.lat - lat_delta
    max_lat = site.lat + lat_delta
    min_lon = site.lon - lon_delta
    max_lon = site.lon + lon_delta

    tile_names = await self.terrain.ensure_tiles_for_bbox(
        min_lat, min_lon, max_lat, max_lon
    )
    for tn in tile_names:
        self.terrain._load_tile(tn)

    # The original bound this result to an unused local (site_elevation);
    # the call is kept in case it primes an elevation cache.
    # NOTE(review): confirm get_elevation_sync has no needed side effect
    # before removing the call entirely.
    self.terrain.get_elevation_sync(site.lat, site.lon)

    # Select propagation model for the site's frequency/environment.
    env = getattr(settings, 'environment', 'urban')
    model = select_propagation_model(site.frequency, env)

    # Loop invariant hoisted out of the inner loop: sample spacing.
    spoke_step = settings.radius / points_per_spoke

    points: List[CoveragePoint] = []

    for angle_deg in range(num_spokes):
        angle_rad = math.radians(angle_deg)
        cos_a = math.cos(angle_rad)
        sin_a = math.sin(angle_rad)

        # Antenna pattern loss for this spoke direction: parabolic rolloff,
        # 3 dB at the half-beamwidth edge, capped at 25 dB off-axis.
        antenna_loss = 0.0
        if site.azimuth is not None and site.beamwidth:
            angle_diff = abs(angle_deg - site.azimuth)
            if angle_diff > 180:
                angle_diff = 360 - angle_diff  # wrap to shortest angular distance
            half_bw = site.beamwidth / 2
            if angle_diff <= half_bw:
                antenna_loss = 3 * (angle_diff / half_bw) ** 2
            else:
                antenna_loss = 3 + 12 * ((angle_diff - half_bw) / half_bw) ** 2
            antenna_loss = min(antenna_loss, 25)

        for i in range(1, points_per_spoke + 1):
            distance = i * spoke_step

            # Move point along bearing (equirectangular approximation).
            lat_offset = (distance / 111000) * cos_a
            lon_offset = (distance / (111000 * cos_lat)) * sin_a
            rx_lat = site.lat + lat_offset
            rx_lon = site.lon + lon_offset

            # Path loss from the selected empirical model (RX at 1.5 m AGL).
            prop_input = PropagationInput(
                frequency_mhz=site.frequency,
                distance_m=distance,
                tx_height_m=site.height,
                rx_height_m=1.5,
                environment=env,
            )
            path_loss = model.calculate(prop_input).path_loss_db

            # Terrain LOS check; blocked paths get a diffraction penalty.
            terrain_loss = 0.0
            has_los = True
            if settings.use_terrain:
                los_result = self.los.check_line_of_sight_sync(
                    site.lat, site.lon, site.height,
                    rx_lat, rx_lon, 1.5,
                )
                has_los = los_result['has_los']
                if not has_los:
                    terrain_loss = self._diffraction_loss(
                        los_result['clearance'], site.frequency
                    )

            # Link budget: EIRP minus all loss terms (dB domain).
            rsrp = (site.power + site.gain - path_loss
                    - antenna_loss - terrain_loss)

            # Only keep points above the display threshold.
            if rsrp >= settings.min_signal:
                points.append(CoveragePoint(
                    lat=rx_lat, lon=rx_lon, rsrp=rsrp,
                    distance=distance, has_los=has_los,
                    terrain_loss=terrain_loss, building_loss=0.0,
                ))

    total_time = time.time() - calc_start
    _clog(f"Radial preview: {len(points)} points, {num_spokes} spokes × "
          f"{points_per_spoke} pts/spoke, {total_time:.1f}s")
    return points
|
||||
|
||||
|
||||
# Singleton
|
||||
coverage_service = CoverageService()
|
||||
|
||||
@@ -21,6 +21,7 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
@@ -450,6 +451,9 @@ def _calculate_with_ray(
|
||||
log_fn(f"Ray done: {calc_time:.1f}s, {len(all_results)} results "
|
||||
f"({calc_time / max(1, total_points) * 1000:.1f}ms/point)")
|
||||
|
||||
# Force garbage collection after Ray computation
|
||||
gc.collect()
|
||||
|
||||
timing = {
|
||||
"parallel_total": calc_time,
|
||||
"ray_put": put_time,
|
||||
@@ -744,6 +748,8 @@ def _calculate_with_process_pool(
|
||||
block.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
# Force garbage collection to release memory from workers
|
||||
gc.collect()
|
||||
|
||||
calc_time = time.time() - t_calc
|
||||
log_fn(f"ProcessPool done: {calc_time:.1f}s, {len(all_results)} results "
|
||||
@@ -820,6 +826,9 @@ def _calculate_sequential(
|
||||
log_fn(f"Sequential done: {calc_time:.1f}s, {len(results)} results "
|
||||
f"({calc_time / max(1, total) * 1000:.1f}ms/point)")
|
||||
|
||||
# Force garbage collection after sequential computation
|
||||
gc.collect()
|
||||
|
||||
timing["sequential_total"] = calc_time
|
||||
timing["backend"] = "sequential"
|
||||
return results, timing
|
||||
|
||||
Reference in New Issue
Block a user