# rfcp/backend/app/services/terrain_service.py
import os
import struct
import asyncio
import gzip
import zipfile
import io
import numpy as np
import httpx
from pathlib import Path
from typing import List, Optional, Tuple
class TerrainService:
    """
    SRTM elevation data service with local caching.

    - Stores downloaded tiles as raw .hgt files in RFCP_DATA_PATH/terrain/
    - Keeps loaded tiles in an in-memory cache (max 20 tiles)
    - Auto-downloads missing tiles from the mirrors in SRTM_SOURCES
    - Supports both SRTM1 (3601x3601, ~30 m) and SRTM3 (1201x1201, ~90 m) grids
    """

    # Download sources, tried in order — highest resolution first.
    # "{tile_name}" is e.g. "N48E035"; "{lat_dir}" is the 3-char latitude
    # band prefix (e.g. "N48") used by the AWS "skadi" directory layout.
    SRTM_SOURCES = [
        # Our tile server — SRTM1 (30m) preferred, uncompressed
        {
            "url": "https://terra.eliah.one/srtm1/{tile_name}.hgt",
            "compressed": False,
            "resolution": "srtm1",
        },
        # Our tile server — SRTM3 (90m) fallback
        {
            "url": "https://terra.eliah.one/srtm3/{tile_name}.hgt",
            "compressed": False,
            "resolution": "srtm3",
        },
        # Public AWS mirror — SRTM1, gzip compressed
        {
            "url": "https://elevation-tiles-prod.s3.amazonaws.com/skadi/{lat_dir}/{tile_name}.hgt.gz",
            "compressed": True,
            "resolution": "srtm1",
        },
    ]

    def __init__(self):
        # Root data directory; overridable via the RFCP_DATA_PATH env var.
        self.data_path = Path(os.environ.get('RFCP_DATA_PATH', './data'))
        self.terrain_path = self.data_path / 'terrain'
        self.terrain_path.mkdir(parents=True, exist_ok=True)
        # In-memory cache of loaded tiles: tile name -> 2-D int16 array/memmap.
        self._tile_cache: dict[str, np.ndarray] = {}
        # NOTE(review): the ~500MB figure assumes fully-loaded SRTM1 tiles;
        # memmap-backed tiles (the normal path) cost far less resident RAM.
        self._max_cache_tiles = 20  # ~500MB max
def get_tile_name(self, lat: float, lon: float) -> str:
    """Name of the 1°x1° SRTM tile containing (lat, lon), e.g. "N48E035"."""
    # Integer degrees of the tile corner: truncate toward zero, then step
    # one degree south/west for any negative coordinate (this matches the
    # project's existing tile-boundary convention, including exact
    # negative integers like -5.0 mapping to S06).
    lat_deg = int(lat) - (1 if lat < 0 else 0)
    lon_deg = int(lon) - (1 if lon < 0 else 0)
    ns = 'S' if lat_deg < 0 else 'N'
    ew = 'W' if lon_deg < 0 else 'E'
    return f"{ns}{abs(lat_deg):02d}{ew}{abs(lon_deg):03d}"
def get_tile_path(self, tile_name: str) -> Path:
"""Get local path for tile"""
return self.terrain_path / f"{tile_name}.hgt"
async def download_tile(self, tile_name: str) -> bool:
    """Download SRTM tile from configured sources, preferring highest resolution.

    Tries each SRTM_SOURCES entry in order and stops at the first response
    that yields a valid SRTM1/SRTM3 payload. Returns True if the tile is
    already on disk or was downloaded, False if every source failed.
    """
    tile_path = self.get_tile_path(tile_name)
    # Already cached on disk — nothing to do.
    if tile_path.exists():
        return True
    lat_dir = tile_name[:3]  # e.g., "N48" — latitude band used by the AWS layout
    async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
        for source in self.SRTM_SOURCES:
            url = source["url"].format(lat_dir=lat_dir, tile_name=tile_name)
            try:
                response = await client.get(url)
                if response.status_code == 200:
                    data = response.content
                    # Skip empty responses
                    # (some mirrors return tiny error bodies with HTTP 200).
                    if len(data) < 1000:
                        continue
                    if source["compressed"]:
                        if url.endswith('.gz'):
                            data = gzip.decompress(data)
                        elif url.endswith('.zip'):
                            # Take the first .hgt member found in the archive.
                            with zipfile.ZipFile(io.BytesIO(data)) as zf:
                                for name in zf.namelist():
                                    if name.endswith('.hgt'):
                                        data = zf.read(name)
                                        break
                    # Validate tile size (SRTM1: 25,934,402 bytes, SRTM3: 2,884,802 bytes)
                    # — anything else is a truncated or bogus payload.
                    if len(data) not in (3601 * 3601 * 2, 1201 * 1201 * 2):
                        print(f"[Terrain] Invalid tile size {len(data)} from {url}")
                        continue
                    tile_path.write_bytes(data)
                    res = source["resolution"]
                    size_mb = len(data) / 1048576
                    print(f"[Terrain] Downloaded {tile_name} ({res}, {size_mb:.1f} MB)")
                    return True
            except Exception as e:
                # Best-effort: log and fall through to the next source.
                print(f"[Terrain] Failed from {url}: {e}")
                continue
    print(f"[Terrain] Could not download {tile_name} from any source")
    return False
def _load_tile(self, tile_name: str) -> Optional[np.ndarray]:
"""Load tile from disk into memory cache using memory-mapped I/O.
Uses np.memmap so the OS pages data from disk on demand — near-zero
upfront RAM cost per tile (~25 MB savings each vs full load).
Falls back to np.frombuffer if memmap fails.
"""
# Check memory cache first
if tile_name in self._tile_cache:
return self._tile_cache[tile_name]
tile_path = self.get_tile_path(tile_name)
if not tile_path.exists():
return None
try:
file_size = tile_path.stat().st_size
# SRTM HGT format: big-endian signed 16-bit integers
if file_size == 3601 * 3601 * 2:
size = 3601 # SRTM1 (30m)
elif file_size == 1201 * 1201 * 2:
size = 1201 # SRTM3 (90m)
else:
print(f"[Terrain] Unknown tile size: {file_size} bytes for {tile_name}")
return None
# Memory-mapped loading — OS pages from disk, near-zero RAM
try:
tile = np.memmap(
tile_path, dtype='>i2', mode='r', shape=(size, size),
)
except Exception:
# Fallback: full load into RAM
data = tile_path.read_bytes()
tile = np.frombuffer(data, dtype='>i2').reshape((size, size))
# Manage memory cache with LRU eviction
if len(self._tile_cache) >= self._max_cache_tiles:
oldest = next(iter(self._tile_cache))
del self._tile_cache[oldest]
self._tile_cache[tile_name] = tile
return tile
except Exception as e:
print(f"[Terrain] Failed to load {tile_name}: {e}")
return None
async def load_tile(self, tile_name: str) -> Optional[np.ndarray]:
"""Load tile into memory, downloading if needed"""
# Check memory cache
if tile_name in self._tile_cache:
return self._tile_cache[tile_name]
# Download if missing
if not self.get_tile_path(tile_name).exists():
success = await self.download_tile(tile_name)
if not success:
return None
return self._load_tile(tile_name)
def _bilinear_sample(self, tile: np.ndarray, lat: float, lon: float) -> float:
"""Sample elevation with bilinear interpolation for sub-meter accuracy.
SRTM1 at 30m means nearest-neighbor can have 15m positional error.
Bilinear interpolation reduces this to sub-meter accuracy.
"""
size = tile.shape[0]
# Tile southwest corner
lat_int = int(lat) if lat >= 0 else int(lat) - 1
lon_int = int(lon) if lon >= 0 else int(lon) - 1
# Fractional position within tile (0.0 to 1.0)
lat_frac = lat - lat_int # 0 = south edge, 1 = north edge
lon_frac = lon - lon_int # 0 = west edge, 1 = east edge
# Convert to row/col (note: rows go north to south!)
row_exact = (1.0 - lat_frac) * (size - 1) # 0 = north, size-1 = south
col_exact = lon_frac * (size - 1) # 0 = west, size-1 = east
# Four surrounding grid points
r0 = int(row_exact)
c0 = int(col_exact)
r1 = min(r0 + 1, size - 1)
c1 = min(c0 + 1, size - 1)
# Fractional position between grid points
dr = row_exact - r0
dc = col_exact - c0
# Get four corner values
z00 = tile[r0, c0]
z01 = tile[r0, c1]
z10 = tile[r1, c0]
z11 = tile[r1, c1]
# Handle void (-32768) values - fall back to nearest valid
void_val = -32768
corners = [(z00, r0, c0), (z01, r0, c1), (z10, r1, c0), (z11, r1, c1)]
if z00 == void_val or z01 == void_val or z10 == void_val or z11 == void_val:
valid = [(z, r, c) for z, r, c in corners if z != void_val]
if not valid:
return 0.0
# Return nearest valid value
return float(valid[0][0])
# Bilinear interpolation
elevation = (z00 * (1 - dr) * (1 - dc) +
z01 * (1 - dr) * dc +
z10 * dr * (1 - dc) +
z11 * dr * dc)
return float(elevation)
async def get_elevation(self, lat: float, lon: float) -> float:
"""Get elevation at specific coordinate with bilinear interpolation."""
tile_name = self.get_tile_name(lat, lon)
tile = await self.load_tile(tile_name)
if tile is None:
return 0.0
return self._bilinear_sample(tile, lat, lon)
def get_elevation_sync(self, lat: float, lon: float) -> float:
"""Sync elevation lookup with bilinear interpolation. Returns 0.0 if tile not loaded."""
tile_name = self.get_tile_name(lat, lon)
tile = self._tile_cache.get(tile_name)
if tile is None:
return 0.0
return self._bilinear_sample(tile, lat, lon)
def get_elevations_batch(self, lats: np.ndarray, lons: np.ndarray) -> np.ndarray:
    """Vectorized elevation lookup with bilinear interpolation.

    Groups points by their containing tile, then processes each tile's
    points with full NumPy vectorization. Tiles must already be in the
    in-memory cache — points whose tile is not cached get 0.0.

    Note: points with any void (-32768) corner also get 0.0 here, unlike
    the scalar _bilinear_sample path which falls back to a valid corner.

    NOTE(review): tile selection here uses np.floor, while get_tile_name
    uses int()-based truncation that differs at exact negative integer
    degrees (e.g. lat == -5.0) — such boundary points may resolve to a
    different (adjacent) tile than the scalar path. Confirm intended.

    Args:
        lats: Array of latitudes
        lons: Array of longitudes
    Returns:
        float32 array of elevations (0.0 for missing tiles or void data)
    """
    elevations = np.zeros(len(lats), dtype=np.float32)
    # Southwest-corner degrees identify the containing tile per point.
    lat_ints = np.floor(lats).astype(int)
    lon_ints = np.floor(lons).astype(int)
    # Group by tile using unique (lat, lon) corner pairs.
    unique_tiles = set(zip(lat_ints, lon_ints))
    for lat_int, lon_int in unique_tiles:
        # Build the tile name (e.g. "N48E035").
        lat_letter = 'N' if lat_int >= 0 else 'S'
        lon_letter = 'E' if lon_int >= 0 else 'W'
        tile_name = f"{lat_letter}{abs(lat_int):02d}{lon_letter}{abs(lon_int):03d}"
        tile = self._tile_cache.get(tile_name)
        if tile is None:
            continue  # tile not cached: those points keep the 0.0 default
        # Select the points that fall inside this tile.
        mask = (lat_ints == lat_int) & (lon_ints == lon_int)
        tile_lats = lats[mask]
        tile_lons = lons[mask]
        size = tile.shape[0]
        # Vectorized bilinear interpolation for all points in this tile.
        lat_frac = tile_lats - lat_int
        lon_frac = tile_lons - lon_int
        # Rows run north->south; row 0 is the tile's north edge.
        row_exact = (1.0 - lat_frac) * (size - 1)
        col_exact = lon_frac * (size - 1)
        # Clamp to size-2 so the +1 neighbor below stays in bounds.
        r0 = np.clip(row_exact.astype(int), 0, size - 2)
        c0 = np.clip(col_exact.astype(int), 0, size - 2)
        r1 = r0 + 1
        c1 = c0 + 1
        dr = row_exact - r0
        dc = col_exact - c0
        # Gather the four corner posts for every point at once.
        z00 = tile[r0, c0].astype(np.float32)
        z01 = tile[r0, c1].astype(np.float32)
        z10 = tile[r1, c0].astype(np.float32)
        z11 = tile[r1, c1].astype(np.float32)
        # Bilinear interpolation (vectorized).
        result = (z00 * (1 - dr) * (1 - dc) +
                  z01 * (1 - dr) * dc +
                  z10 * dr * (1 - dc) +
                  z11 * dr * dc)
        # Any void corner (-32768) poisons the blend — zero those points.
        void_mask = (z00 == -32768) | (z01 == -32768) | (z10 == -32768) | (z11 == -32768)
        result[void_mask] = 0.0
        elevations[mask] = result
    return elevations
def get_required_tiles(self, center_lat: float, center_lon: float, radius_km: float) -> list:
"""Determine which tiles are needed for a coverage calculation."""
# Convert radius to degrees (approximate)
lat_delta = radius_km / 111.0 # ~111 km per degree latitude
lon_delta = radius_km / (111.0 * np.cos(np.radians(center_lat)))
min_lat = center_lat - lat_delta
max_lat = center_lat + lat_delta
min_lon = center_lon - lon_delta
max_lon = center_lon + lon_delta
tiles = []
for lat in range(int(np.floor(min_lat)), int(np.floor(max_lat)) + 1):
for lon in range(int(np.floor(min_lon)), int(np.floor(max_lon)) + 1):
lat_letter = 'N' if lat >= 0 else 'S'
lon_letter = 'E' if lon >= 0 else 'W'
tile_name = f"{lat_letter}{abs(lat):02d}{lon_letter}{abs(lon):03d}"
tiles.append(tile_name)
return tiles
def get_missing_tiles(self, center_lat: float, center_lon: float, radius_km: float) -> list:
"""Check which needed tiles are not available locally."""
required = self.get_required_tiles(center_lat, center_lon, radius_km)
return [t for t in required if not self.get_tile_path(t).exists()]
async def get_elevation_profile(
self,
lat1: float, lon1: float,
lat2: float, lon2: float,
num_points: int = 100
) -> List[dict]:
"""Get elevation profile between two points"""
lats = np.linspace(lat1, lat2, num_points)
lons = np.linspace(lon1, lon2, num_points)
total_distance = self.haversine_distance(lat1, lon1, lat2, lon2)
distances = np.linspace(0, total_distance, num_points)
profile = []
for i, (lat, lon, dist) in enumerate(zip(lats, lons, distances)):
elev = await self.get_elevation(lat, lon)
profile.append({
"lat": float(lat),
"lon": float(lon),
"elevation": elev,
"distance": float(dist)
})
return profile
def get_elevation_profile_sync(
self,
lat1: float, lon1: float,
lat2: float, lon2: float,
num_points: int = 50
) -> List[dict]:
"""Sync elevation profile - tiles must be pre-loaded into memory cache."""
lats = np.linspace(lat1, lat2, num_points)
lons = np.linspace(lon1, lon2, num_points)
total_distance = self.haversine_distance(lat1, lon1, lat2, lon2)
distances = np.linspace(0, total_distance, num_points)
profile = []
for i in range(num_points):
profile.append({
"lat": float(lats[i]),
"lon": float(lons[i]),
"elevation": self.get_elevation_sync(float(lats[i]), float(lons[i])),
"distance": float(distances[i])
})
return profile
async def ensure_tiles_for_bbox(
self,
min_lat: float, min_lon: float,
max_lat: float, max_lon: float
) -> list[str]:
"""Pre-download all tiles needed for a bounding box"""
tiles_needed = []
for lat in range(int(min_lat), int(max_lat) + 1):
for lon in range(int(min_lon), int(max_lon) + 1):
tile_name = self.get_tile_name(lat, lon)
tiles_needed.append(tile_name)
# Download in parallel (batches of 5 to avoid overload)
downloaded = []
batch_size = 5
for i in range(0, len(tiles_needed), batch_size):
batch = tiles_needed[i:i + batch_size]
results = await asyncio.gather(*[
self.download_tile(tile) for tile in batch
])
for tile, ok in zip(batch, results):
if ok:
downloaded.append(tile)
return downloaded
def get_cached_tiles(self) -> list[str]:
"""List all locally cached tile names"""
return [f.stem for f in self.terrain_path.glob("*.hgt")]
def get_cache_size_mb(self) -> float:
"""Get total terrain cache size in MB"""
total = sum(f.stat().st_size for f in self.terrain_path.glob("*.hgt"))
return total / (1024 * 1024)
def evict_disk_cache(self, max_size_mb: float = 2048.0):
"""LRU eviction of .hgt files when disk cache exceeds max_size_mb.
Deletes the oldest-accessed files until total size is under the limit.
"""
hgt_files = list(self.terrain_path.glob("*.hgt"))
if not hgt_files:
return
total = sum(f.stat().st_size for f in hgt_files)
if total / (1024 * 1024) <= max_size_mb:
return
# Sort by access time (oldest first)
hgt_files.sort(key=lambda f: f.stat().st_atime)
evicted = 0
for f in hgt_files:
if total / (1024 * 1024) <= max_size_mb:
break
fsize = f.stat().st_size
# Remove from memory cache if loaded
stem = f.stem
self._tile_cache.pop(stem, None)
f.unlink()
total -= fsize
evicted += 1
if evicted:
print(f"[Terrain] Evicted {evicted} tiles, "
f"cache now {total / (1024 * 1024):.0f} MB")
@staticmethod
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
"""Calculate distance between two points in meters"""
EARTH_RADIUS = 6371000
lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
c = 2 * np.arcsin(np.sqrt(a))
return EARTH_RADIUS * c
# Module-level singleton — import this instance so all callers share one
# tile cache and one on-disk terrain directory.
terrain_service = TerrainService()