# GeoData/geodata_pipeline/swe_lods.py
# (file-browser paste header converted to a comment so the module parses)
from __future__ import annotations
import json
import math
import os
import tomllib
from dataclasses import asdict
from typing import Iterable
import numpy as np
from osgeo import gdal
from .config import Config, SweLodConfig, SweLodLevelConfig
from .gdal_utils import build_vrt, cleanup_aux_files, ensure_dir, ensure_parent, open_dataset, safe_remove
from .pointcloud import find_pointcloud_file, read_pointcloud_file
# Make GDAL raise Python exceptions instead of returning null datasets.
gdal.UseExceptions()
def export_swe_lods(cfg: Config, *, force_vrt: bool = False) -> int:
    """Export height/porosity/building EXR tiles for every configured SWE LOD.

    The height source is warped onto a tile grid anchored at
    (origin_x, origin_y); when not configured, the origin defaults to the
    source bounds snapped down to the coarsest LOD tile size so all LODs share
    one grid. Afterwards the tile index CSV, the JSON manifest, and the
    boundary manifest are written.

    Returns a process exit code: 0 on success or when disabled, 1 when no
    tile could be written.
    """
    swe_cfg = cfg.swe_lod
    if not swe_cfg.enabled:
        print("[swe_lods] SWE LOD export disabled in config.")
        return 0
    ensure_dir(cfg.work.work_dir)
    ensure_dir(swe_cfg.out_dir)
    ensure_parent(swe_cfg.tile_index_path)
    ensure_parent(swe_cfg.manifest_path)
    ensure_parent(swe_cfg.boundary_manifest_path)
    height_ds = _prepare_height_source(cfg, swe_cfg, force_vrt=force_vrt)
    xmin, ymin, xmax, ymax = _dataset_bounds(height_ds)
    source_bounds = (xmin, ymin, xmax, ymax)
    # Snap the auto-computed origin to the coarsest tile size.
    base_tile_size = max(level.tile_size_m for level in swe_cfg.lods)
    origin_x = swe_cfg.origin_x
    origin_y = swe_cfg.origin_y
    if origin_x is None:
        origin_x = math.floor(xmin / base_tile_size) * base_tile_size
    if origin_y is None:
        origin_y = math.floor(ymin / base_tile_size) * base_tile_size
    porosity_ds = _open_optional_raster(swe_cfg.porosity_source, "porosity")
    building_ds = _open_optional_raster(swe_cfg.building_height_source, "building height")
    tile_rows = []
    skipped = 0
    written = 0
    for level_index, level in enumerate(swe_cfg.lods):
        lod_name = level.name or f"lod{level_index}"
        lod_dir = os.path.join(swe_cfg.out_dir, lod_name)
        height_dir = os.path.join(lod_dir, "height")
        porosity_dir = os.path.join(lod_dir, "porosity")
        building_dir = os.path.join(lod_dir, "buildings")
        ensure_dir(height_dir)
        ensure_dir(porosity_dir)
        if swe_cfg.include_buildings:
            ensure_dir(building_dir)
        for tile_id, bounds in _iter_tiles(origin_x, origin_y, xmin, ymin, xmax, ymax, level.tile_size_m):
            # Skip sliver tiles that intersect source bounds by less than one output pixel.
            if not _has_min_coverage(bounds, source_bounds, level.resolution, level.resolution):
                skipped += 1
                continue
            try:
                height = _warp_array(
                    height_ds,
                    bounds,
                    level.resolution,
                    level.resolution,
                    swe_cfg.height_resample,
                    swe_cfg.height_nodata,
                )
            except RuntimeError as exc:
                # A failed warp skips only this tile, not the whole export.
                print(f"[swe_lods] Warp failed for {lod_name} {tile_id}: {exc}")
                skipped += 1
                continue
            height_path = os.path.join(height_dir, f"height_{tile_id}.exr")
            _write_exr(height_path, height, swe_cfg.prefer_float16)
            porosity_path = os.path.join(porosity_dir, f"porosity_{tile_id}.exr")
            porosity = _build_porosity(
                porosity_ds,
                bounds,
                level.resolution,
                swe_cfg,
            )
            _write_exr(porosity_path, porosity, swe_cfg.prefer_float16)
            building_path = ""
            if swe_cfg.include_buildings:
                building_path = os.path.join(building_dir, f"buildings_{tile_id}.exr")
                if building_ds is None:
                    # No building raster configured: emit an all-zero layer.
                    building = np.zeros((level.resolution, level.resolution), dtype=np.float32)
                else:
                    building = _warp_array(
                        building_ds,
                        bounds,
                        level.resolution,
                        level.resolution,
                        swe_cfg.building_resample,
                        swe_cfg.building_nodata,
                    )
                _write_exr(building_path, building, swe_cfg.prefer_float16)
            tile_rows.append(
                {
                    "lod": lod_name,
                    "tile_x": tile_id.split("_")[0],
                    "tile_y": tile_id.split("_")[1],
                    "xmin": bounds[0],
                    "ymin": bounds[1],
                    "xmax": bounds[2],
                    "ymax": bounds[3],
                    "tile_size_m": level.tile_size_m,
                    "resolution": level.resolution,
                    "height_path": height_path,
                    "porosity_path": porosity_path,
                    "building_path": building_path,
                }
            )
            written += 1
    _write_tile_index(swe_cfg.tile_index_path, tile_rows)
    _write_manifest(
        swe_cfg.manifest_path,
        swe_cfg,
        origin_x,
        origin_y,
        xmin,
        ymin,
        xmax,
        ymax,
    )
    _write_boundary_manifest(
        cfg,
        swe_cfg,
        tile_rows,
        force_vrt=force_vrt,
    )
    removed = cleanup_aux_files(_cleanup_patterns(cfg.raw.dgm1_dir))
    print(f"[swe_lods] Summary: wrote {written} tiles; skipped {skipped}.")
    print(f"[swe_lods] Cleanup removed {removed} temporary files/sidecars.")
    if written == 0:
        print("[swe_lods] Error: no tiles were written.")
        return 1
    return 0
def export_swe_porosity(cfg: Config, *, force_vrt: bool = False) -> int:
    """Export porosity (and optional building) tiles only, skipping height.

    Mirrors export_swe_lods but derives the tile grid from the porosity or
    building raster bounds; "lpo" sources are rasterized from point clouds
    first. Tile rows carry an empty height_path.

    Returns 0 on success or when disabled, 1 when no tile could be written.
    """
    swe_cfg = cfg.swe_lod
    if not swe_cfg.enabled:
        print("[swe_porosity] SWE LOD export disabled in config.")
        return 0
    ensure_dir(cfg.work.work_dir)
    ensure_dir(swe_cfg.out_dir)
    ensure_parent(swe_cfg.tile_index_path)
    ensure_parent(swe_cfg.manifest_path)
    ensure_parent(swe_cfg.boundary_manifest_path)
    porosity_source = swe_cfg.porosity_source
    building_source = swe_cfg.building_height_source
    # "lpo" sources are rasterized from point clouds into VRTs on demand.
    if _uses_lpo(porosity_source) or _uses_lpo(building_source):
        porosity_source, building_source = _ensure_lpo_sources(cfg, swe_cfg, force_vrt=force_vrt)
    porosity_ds = _open_optional_raster(porosity_source, "porosity")
    building_ds = _open_optional_raster(building_source, "building height")
    if porosity_ds is None and building_ds is None:
        raise SystemExit("[swe_porosity] No porosity or building height source configured.")
    # Grid bounds come from whichever raster is available.
    bounds_ds = porosity_ds or building_ds
    xmin, ymin, xmax, ymax = _dataset_bounds(bounds_ds)
    source_bounds = (xmin, ymin, xmax, ymax)
    base_tile_size = max(level.tile_size_m for level in swe_cfg.lods)
    origin_x = swe_cfg.origin_x
    origin_y = swe_cfg.origin_y
    if origin_x is None:
        origin_x = math.floor(xmin / base_tile_size) * base_tile_size
    if origin_y is None:
        origin_y = math.floor(ymin / base_tile_size) * base_tile_size
    tile_rows = []
    skipped = 0
    written = 0
    for level_index, level in enumerate(swe_cfg.lods):
        lod_name = level.name or f"lod{level_index}"
        lod_dir = os.path.join(swe_cfg.out_dir, lod_name)
        height_dir = os.path.join(lod_dir, "height")
        porosity_dir = os.path.join(lod_dir, "porosity")
        building_dir = os.path.join(lod_dir, "buildings")
        ensure_dir(height_dir)
        ensure_dir(porosity_dir)
        if swe_cfg.include_buildings:
            ensure_dir(building_dir)
        for tile_id, bounds in _iter_tiles(origin_x, origin_y, xmin, ymin, xmax, ymax, level.tile_size_m):
            if not _has_min_coverage(bounds, source_bounds, level.resolution, level.resolution):
                skipped += 1
                continue
            porosity_path = os.path.join(porosity_dir, f"porosity_{tile_id}.exr")
            porosity = _build_porosity(
                porosity_ds,
                bounds,
                level.resolution,
                swe_cfg,
            )
            _write_exr(porosity_path, porosity, swe_cfg.prefer_float16)
            building_path = ""
            if swe_cfg.include_buildings:
                building_path = os.path.join(building_dir, f"buildings_{tile_id}.exr")
                if building_ds is None:
                    building = np.zeros((level.resolution, level.resolution), dtype=np.float32)
                else:
                    building = _warp_array(
                        building_ds,
                        bounds,
                        level.resolution,
                        level.resolution,
                        swe_cfg.building_resample,
                        swe_cfg.building_nodata,
                    )
                _write_exr(building_path, building, swe_cfg.prefer_float16)
            tile_rows.append(
                {
                    "lod": lod_name,
                    "tile_x": tile_id.split("_")[0],
                    "tile_y": tile_id.split("_")[1],
                    "xmin": bounds[0],
                    "ymin": bounds[1],
                    "xmax": bounds[2],
                    "ymax": bounds[3],
                    "tile_size_m": level.tile_size_m,
                    "resolution": level.resolution,
                    "height_path": "",
                    "porosity_path": porosity_path,
                    "building_path": building_path,
                }
            )
            written += 1
    _write_tile_index(swe_cfg.tile_index_path, tile_rows)
    _write_manifest(
        swe_cfg.manifest_path,
        swe_cfg,
        origin_x,
        origin_y,
        xmin,
        ymin,
        xmax,
        ymax,
    )
    _write_boundary_manifest(
        cfg,
        swe_cfg,
        tile_rows,
        force_vrt=force_vrt,
    )
    removed = cleanup_aux_files(_cleanup_patterns(cfg.raw.dgm1_dir))
    print(f"[swe_porosity] Summary: wrote {written} tiles; skipped {skipped}.")
    print(f"[swe_porosity] Cleanup removed {removed} temporary files/sidecars.")
    if written == 0:
        print("[swe_porosity] Error: no tiles were written.")
        return 1
    return 0
def _uses_lpo(value: str | None) -> bool:
if value is None:
return False
return str(value).strip().lower() == "lpo"
def _ensure_lpo_sources(cfg: Config, swe_cfg: SweLodConfig, *, force_vrt: bool):
    """Rasterize LPO point clouds into porosity/building tiles and mosaic them.

    For each tile listed in the height manifest, the matching point cloud is
    binned into porosity and building-height GeoTIFFs under work_dir/swe_lpo;
    existing tile rasters are reused unless *force_vrt* is set.

    Returns (porosity_vrt_path, building_vrt_path).
    Raises SystemExit when no LPO directory, manifest, or tiles are available.
    """
    lpo_dir = swe_cfg.lpo_dir or cfg.pointcloud.lpo_dir
    if not lpo_dir:
        raise SystemExit("[swe_porosity] LPO directory not configured.")
    manifest = _resolve_height_manifest(cfg, swe_cfg)
    if not os.path.exists(manifest):
        raise SystemExit(f"[swe_porosity] Height manifest missing: {manifest}")
    work_dir = os.path.join(cfg.work.work_dir, "swe_lpo")
    porosity_tiles = os.path.join(work_dir, "porosity_tiles")
    building_tiles = os.path.join(work_dir, "building_tiles")
    ensure_dir(porosity_tiles)
    ensure_dir(building_tiles)
    porosity_paths = []
    building_paths = []
    tiles = _read_manifest_tiles(manifest)
    if not tiles:
        raise SystemExit(f"[swe_porosity] No tiles found in {manifest}")
    for tile_id, bounds in tiles:
        lpo_path = find_pointcloud_file(lpo_dir, tile_id)
        if not lpo_path:
            # Tiles without a matching point cloud are silently skipped.
            continue
        porosity_path = os.path.join(porosity_tiles, f"{tile_id}.tif")
        building_path = os.path.join(building_tiles, f"{tile_id}.tif")
        # Reuse previously rasterized tiles unless a rebuild was requested.
        if not force_vrt and os.path.exists(porosity_path) and os.path.exists(building_path):
            porosity_paths.append(porosity_path)
            building_paths.append(building_path)
            continue
        points = read_pointcloud_file(
            lpo_path,
            bounds=bounds,
            chunk_size=cfg.pointcloud.chunk_size,
        )
        porosity, building = _rasterize_lpo(
            points.x,
            points.y,
            points.z,
            bounds,
            swe_cfg.lpo_base_res,
            swe_cfg.lpo_density_threshold,
            swe_cfg.lpo_height_percentile,
            swe_cfg.lpo_min_height_m,
        )
        gt = _geotransform_from_bounds(bounds, swe_cfg.lpo_base_res)
        _write_geotiff(porosity_path, porosity, gt, "")
        _write_geotiff(building_path, building, gt, "")
        porosity_paths.append(porosity_path)
        building_paths.append(building_path)
    if not porosity_paths and not building_paths:
        raise SystemExit("[swe_porosity] No LPO tiles were processed.")
    porosity_vrt = os.path.join(work_dir, "porosity.vrt")
    building_vrt = os.path.join(work_dir, "buildings.vrt")
    build_vrt(porosity_vrt, porosity_paths, force=True)
    build_vrt(building_vrt, building_paths, force=True)
    return porosity_vrt, building_vrt
def _prepare_height_source(cfg: Config, swe_cfg: SweLodConfig, *, force_vrt: bool) :
source = (swe_cfg.height_source or "dgm1").lower()
if source == "river_erosion":
return _prepare_eroded_height_source(cfg, swe_cfg, force_vrt=force_vrt)
if source == "dgm1":
tif_paths = sorted(_collect_height_sources(cfg.raw.dgm1_dir))
if not tif_paths:
raise SystemExit(f"[swe_lods] No heightmap sources found in {cfg.raw.dgm1_dir}.")
build_vrt(cfg.work.heightmap_vrt, tif_paths, force=force_vrt)
return open_dataset(cfg.work.heightmap_vrt, f"[swe_lods] Could not open {cfg.work.heightmap_vrt}")
raise SystemExit(f"[swe_lods] Unknown height_source '{swe_cfg.height_source}'.")
def _prepare_eroded_height_source(cfg: Config, swe_cfg: SweLodConfig, *, force_vrt: bool):
    """Convert river-erosion PNG tiles back to metric GeoTIFFs and open a VRT.

    The PNGs store heights as 16-bit values normalized between tile_min and
    tile_max (falling back to global_min/global_max) as recorded in the
    erosion manifest CSV. Converted tiles are cached in work_dir/swe_eroded;
    *force_vrt* removes and rebuilds them.
    """
    from csv import DictReader
    source_dir = swe_cfg.height_source_dir or cfg.river_erosion.output_dir
    manifest = swe_cfg.height_source_manifest or cfg.river_erosion.manifest_vr
    if not os.path.exists(manifest):
        raise SystemExit(f"[swe_lods] River erosion manifest missing: {manifest}")
    work_dir = os.path.join(cfg.work.work_dir, "swe_eroded")
    ensure_dir(work_dir)
    tif_paths = []
    with open(manifest, newline="", encoding="utf-8") as handle:
        reader = DictReader(handle)
        for row in reader:
            tile_id = (row.get("tile_id") or "").strip()
            if not tile_id:
                continue
            png_path = os.path.join(source_dir, f"{tile_id}.png")
            if not os.path.exists(png_path):
                print(f"[swe_lods] Missing eroded PNG {png_path}")
                continue
            tile_min = _parse_float(row.get("tile_min"))
            tile_max = _parse_float(row.get("tile_max"))
            global_min = _parse_float(row.get("global_min"))
            global_max = _parse_float(row.get("global_max"))
            if tile_min is None or tile_max is None:
                # Fall back to the global range when per-tile values are absent.
                tile_min = global_min
                tile_max = global_max
            if tile_min is None or tile_max is None:
                print(f"[swe_lods] Missing min/max for {tile_id}; skipping.")
                continue
            if tile_max <= tile_min:
                # Guard against a degenerate range so the scaling stays valid.
                tile_max = tile_min + 1e-3
            tif_path = os.path.join(work_dir, f"{tile_id}.tif")
            if force_vrt and os.path.exists(tif_path):
                safe_remove(tif_path)
            if not os.path.exists(tif_path):
                ds_png = open_dataset(png_path, f"[swe_lods] Could not open {png_path}")
                band = ds_png.GetRasterBand(1)
                raw = band.ReadAsArray().astype(np.float32)
                gt = ds_png.GetGeoTransform()
                proj = ds_png.GetProjection()
                ds_png = None
                # Denormalize 16-bit PNG values back to meters.
                height = tile_min + (raw / 65535.0) * (tile_max - tile_min)
                _write_geotiff(tif_path, height, gt, proj)
            tif_paths.append(tif_path)
    if not tif_paths:
        raise SystemExit("[swe_lods] No eroded height tiles available to build VRT.")
    vrt_path = os.path.join(work_dir, "swe_eroded.vrt")
    build_vrt(vrt_path, tif_paths, force=True)
    return open_dataset(vrt_path, f"[swe_lods] Could not open {vrt_path}")
def _collect_height_sources(raw_dir: str) -> Iterable[str]:
return sorted(
[
os.path.join(raw_dir, name)
for name in os.listdir(raw_dir)
if name.lower().endswith(".tif")
]
)
def _dataset_bounds(ds) -> tuple[float, float, float, float]:
gt = ds.GetGeoTransform()
ulx, xres, _, uly, _, yres = gt
xmax = ulx + xres * ds.RasterXSize
ymin = uly + yres * ds.RasterYSize
xmin = ulx
ymax = uly
return min(xmin, xmax), min(ymin, ymax), max(xmin, xmax), max(ymin, ymax)
def _has_min_coverage(
tile_bounds: tuple[float, float, float, float],
source_bounds: tuple[float, float, float, float],
out_w: int,
out_h: int,
) -> bool:
txmin, tymin, txmax, tymax = tile_bounds
sxmin, symin, sxmax, symax = source_bounds
overlap_w = max(0.0, min(txmax, sxmax) - max(txmin, sxmin))
overlap_h = max(0.0, min(tymax, symax) - max(tymin, symin))
if overlap_w <= 0.0 or overlap_h <= 0.0:
return False
tile_w = max(1e-6, txmax - txmin)
tile_h = max(1e-6, tymax - tymin)
covered_px_x = (overlap_w / tile_w) * float(out_w)
covered_px_y = (overlap_h / tile_h) * float(out_h)
# Require at least one output pixel in each axis, otherwise warps tend to
# generate empty/zero tiles at non-aligned outer borders.
return covered_px_x >= 1.0 and covered_px_y >= 1.0
def _iter_tiles(
origin_x: float,
origin_y: float,
xmin: float,
ymin: float,
xmax: float,
ymax: float,
tile_size: float,
):
start_x = int(math.floor((xmin - origin_x) / tile_size))
end_x = int(math.ceil((xmax - origin_x) / tile_size))
start_y = int(math.floor((ymin - origin_y) / tile_size))
end_y = int(math.ceil((ymax - origin_y) / tile_size))
for ty in range(start_y, end_y):
for tx in range(start_x, end_x):
tile_min_x = origin_x + tx * tile_size
tile_min_y = origin_y + ty * tile_size
tile_max_x = tile_min_x + tile_size
tile_max_y = tile_min_y + tile_size
if tile_max_x <= xmin or tile_min_x >= xmax or tile_max_y <= ymin or tile_min_y >= ymax:
continue
tile_id = f"{tx}_{ty}"
yield tile_id, (tile_min_x, tile_min_y, tile_max_x, tile_max_y)
def _warp_array(
    ds,
    bounds: tuple[float, float, float, float],
    width: int,
    height: int,
    resample: str,
    dst_nodata: float | None,
) -> np.ndarray:
    """Warp band 1 of *ds* into a (height, width) float32 array covering *bounds*.

    *resample* is a GDAL resampleAlg name; *dst_nodata* is passed through as
    the destination NODATA value. Raises RuntimeError when GDAL returns no
    dataset.
    """
    warp_opts = gdal.WarpOptions(
        format="MEM",
        outputBounds=bounds,
        width=width,
        height=height,
        resampleAlg=resample,
        dstNodata=dst_nodata,
    )
    warped = gdal.Warp("", ds, options=warp_opts)
    if warped is None:
        raise RuntimeError("GDAL Warp returned None.")
    band = warped.GetRasterBand(1)
    data = band.ReadAsArray()
    return data.astype(np.float32, copy=False)
def _build_porosity(
porosity_ds,
bounds: tuple[float, float, float, float],
resolution: int,
cfg: SweLodConfig,
) -> np.ndarray:
if porosity_ds is None:
return np.ones((resolution, resolution), dtype=np.float32)
porosity = _warp_array(
porosity_ds,
bounds,
resolution,
resolution,
cfg.porosity_resample,
cfg.porosity_nodata,
)
porosity = np.clip(porosity, 0.0, 1.0)
if cfg.solid_bias <= 0.0:
return porosity
denom = porosity + (1.0 - porosity) * (1.0 + cfg.solid_bias)
with np.errstate(divide="ignore", invalid="ignore"):
biased = np.divide(porosity, denom, out=np.zeros_like(porosity), where=denom != 0)
return np.clip(biased, 0.0, 1.0)
def _write_exr(path: str, data: np.ndarray, prefer_float16: bool) -> None:
    """Write *data* as a single-band EXR via GDAL.

    Half-float output is used only when requested AND the installed GDAL
    exposes GDT_Float16; the array itself is always handed over as float32.
    """
    driver = gdal.GetDriverByName("EXR")
    if driver is None:
        raise SystemExit("[swe_lods] GDAL EXR driver not available.")
    use_half = prefer_float16 and hasattr(gdal, "GDT_Float16")
    pixel_type = gdal.GDT_Float16 if use_half else gdal.GDT_Float32
    rows, cols = data.shape
    ds = driver.Create(path, cols, rows, 1, pixel_type)
    if ds is None:
        raise RuntimeError(f"Could not create EXR output at {path}")
    band = ds.GetRasterBand(1)
    band.WriteArray(data.astype(np.float32, copy=False))
    band.FlushCache()
    ds.FlushCache()
    ds = None
def _write_geotiff(path: str, data: np.ndarray, geo_transform, projection: str) -> None:
    """Write *data* as a single-band float32 GeoTIFF with the given georeferencing.

    *projection* may be empty, in which case no SRS is attached.
    """
    driver = gdal.GetDriverByName("GTiff")
    if driver is None:
        raise SystemExit("[swe_lods] GDAL GTiff driver not available.")
    rows, cols = data.shape
    ds = driver.Create(path, cols, rows, 1, gdal.GDT_Float32)
    if ds is None:
        raise RuntimeError(f"Could not create GeoTIFF at {path}")
    ds.SetGeoTransform(geo_transform)
    if projection:
        ds.SetProjection(projection)
    band = ds.GetRasterBand(1)
    band.WriteArray(data.astype(np.float32, copy=False))
    band.FlushCache()
    ds.FlushCache()
    ds = None
def _rasterize_lpo(
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
bounds: tuple[float, float, float, float],
resolution: int,
density_threshold: int,
height_percentile: float,
min_height: float,
) -> tuple[np.ndarray, np.ndarray]:
xmin, ymin, xmax, ymax = bounds
resolution = max(2, int(resolution))
step = (xmax - xmin) / (resolution - 1)
if step <= 0:
return np.ones((resolution, resolution), dtype=np.float32), np.zeros((resolution, resolution), dtype=np.float32)
mask = z >= min_height
if mask.any():
x = x[mask]
y = y[mask]
z = z[mask]
if x.size == 0:
porosity = np.ones((resolution, resolution), dtype=np.float32)
building = np.zeros((resolution, resolution), dtype=np.float32)
return porosity, building
ix = np.floor((x - xmin) / step).astype(np.int32)
iy = np.floor((y - ymin) / step).astype(np.int32)
ix = np.clip(ix, 0, resolution - 1)
iy = np.clip(iy, 0, resolution - 1)
counts = np.zeros((resolution, resolution), dtype=np.int32)
np.add.at(counts, (iy, ix), 1)
porosity = np.where(counts >= max(1, density_threshold), 0.0, 1.0).astype(np.float32)
building = np.zeros((resolution * resolution,), dtype=np.float32)
cell_index = iy * resolution + ix
order = np.argsort(cell_index)
cell_sorted = cell_index[order]
z_sorted = z[order]
unique_cells, start_idx = np.unique(cell_sorted, return_index=True)
for idx, cell in enumerate(unique_cells):
start = start_idx[idx]
end = start_idx[idx + 1] if idx + 1 < len(start_idx) else len(cell_sorted)
vals = z_sorted[start:end]
if vals.size == 0:
continue
if height_percentile >= 100.0:
height = float(np.max(vals))
else:
height = float(np.percentile(vals, height_percentile))
building[cell] = height
building = building.reshape((resolution, resolution))
return porosity, building
def _geotransform_from_bounds(bounds: tuple[float, float, float, float], resolution: int):
xmin, ymin, xmax, ymax = bounds
step = (xmax - xmin) / (resolution - 1)
return (xmin, step, 0.0, ymax, 0.0, -step)
def _resolve_height_manifest(cfg: Config, swe_cfg: SweLodConfig) -> str:
if swe_cfg.height_source_manifest:
return swe_cfg.height_source_manifest
if swe_cfg.height_source.lower() == "river_erosion":
return cfg.river_erosion.manifest_vr
return cfg.export.manifest_path
def _read_manifest_tiles(path: str) -> list[tuple[str, tuple[float, float, float, float]]]:
    """Parse (tile_id, (xmin, ymin, xmax, ymax)) rows from a manifest CSV.

    Rows with a blank tile_id or any unparsable corner value are skipped.
    """
    import csv

    tiles: list[tuple[str, tuple[float, float, float, float]]] = []
    with open(path, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            tile_id = (record.get("tile_id") or "").strip()
            if not tile_id:
                continue
            corners = tuple(
                _parse_float(record.get(field)) for field in ("xmin", "ymin", "xmax", "ymax")
            )
            if any(value is None for value in corners):
                continue
            tiles.append((tile_id, corners))
    return tiles
def _parse_float(value: str | None) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _write_tile_index(path: str, rows: list[dict]) -> None:
with open(path, "w", encoding="utf-8") as fh:
fh.write("lod,tile_x,tile_y,xmin,ymin,xmax,ymax,tile_size_m,resolution,height_path,porosity_path,building_path\n")
for row in rows:
fh.write(
f"{row['lod']},{row['tile_x']},{row['tile_y']},"
f"{row['xmin']},{row['ymin']},{row['xmax']},{row['ymax']},"
f"{row['tile_size_m']},{row['resolution']},"
f"{row['height_path']},{row['porosity_path']},{row['building_path']}\n"
)
def _write_manifest(
path: str,
cfg: SweLodConfig,
origin_x: float,
origin_y: float,
xmin: float,
ymin: float,
xmax: float,
ymax: float,
) -> None:
payload = {
"origin_x": origin_x,
"origin_y": origin_y,
"bounds": [xmin, ymin, xmax, ymax],
"solid_bias": cfg.solid_bias,
"lods": [asdict(level) for level in cfg.lods],
}
with open(path, "w", encoding="utf-8") as fh:
json.dump(payload, fh, indent=2)
def _open_optional_raster(path: str | None, label: str):
if not path:
return None
if not os.path.exists(path):
print(f"[swe_lods] {label} raster not found: {path}.")
return None
try:
return open_dataset(path, f"[swe_lods] Could not open {label} raster {path}")
except SystemExit as exc:
print(exc)
return None
def _write_boundary_manifest(
    cfg: Config,
    swe_cfg: SweLodConfig,
    tile_rows: list[dict],
    *,
    force_vrt: bool,
) -> None:
    """Write the JSON boundary manifest (inflows, source areas, sinks).

    For every exported tile, the configured id masks are resampled onto the
    tile grid, per-id EXR rasters are written next to the height/porosity
    tiles, and per-id pixel statistics are merged with the TOML parameter
    files. Legacy "source"/"sources" keys are emitted alongside the newer
    boundary_inflow naming so existing tooling keeps working.
    """
    # The legacy source mask directory doubles as the boundary-inflow mask dir.
    boundary_inflow_mask_dir = swe_cfg.boundary_inflow_mask_dir or swe_cfg.source_mask_dir
    source_area_mask_dir = swe_cfg.source_area_mask_dir
    sink_mask_dir = swe_cfg.sink_mask_dir
    boundary_inflow_ds = _open_boundary_mask_dataset(
        cfg,
        mask_dir=boundary_inflow_mask_dir,
        kind="boundary_inflow",
        force_vrt=force_vrt,
    )
    source_area_ds = _open_boundary_mask_dataset(
        cfg,
        mask_dir=source_area_mask_dir,
        kind="source_area",
        force_vrt=force_vrt,
    )
    sink_ds = _open_boundary_mask_dataset(
        cfg,
        mask_dir=sink_mask_dir,
        kind="sink",
        force_vrt=force_vrt,
    )
    # Parameter files accept several section spellings for backward compatibility.
    boundary_inflow_params = _load_boundary_params_multi(
        swe_cfg.boundary_inflow_params_toml or swe_cfg.source_params_toml,
        kinds=("boundary_inflows", "boundary_inflow", "sources", "source"),
    )
    source_area_params = _load_boundary_params_multi(
        swe_cfg.source_area_params_toml,
        kinds=("source_areas", "source_area", "sources", "source"),
    )
    sink_params = _load_boundary_params_multi(
        swe_cfg.sink_params_toml,
        kinds=("sinks", "sink", "boundary_outflows", "boundary_outflow", "sink_areas", "sink_area"),
        merge=True,
    )
    tiles_payload = []
    boundary_inflow_stats: dict[int, dict] = {}
    source_area_stats: dict[int, dict] = {}
    sink_stats: dict[int, dict] = {}
    for row in tile_rows:
        bounds = (float(row["xmin"]), float(row["ymin"]), float(row["xmax"]), float(row["ymax"]))
        resolution = int(row["resolution"])
        lod = str(row["lod"])
        tile_x = int(row["tile_x"])
        tile_y = int(row["tile_y"])
        # Resample each available id mask onto this tile's grid.
        boundary_inflow_arr = (
            _warp_id_array(boundary_inflow_ds, bounds, resolution, resolution)
            if boundary_inflow_ds is not None
            else None
        )
        source_area_arr = (
            _warp_id_array(source_area_ds, bounds, resolution, resolution)
            if source_area_ds is not None
            else None
        )
        sink_arr = _warp_id_array(sink_ds, bounds, resolution, resolution) if sink_ds is not None else None
        boundary_inflow_ids = _ids_to_entries(boundary_inflow_arr)
        source_area_ids = _ids_to_entries(source_area_arr)
        sink_ids = _ids_to_entries(sink_arr)
        _accumulate_id_stats(boundary_inflow_stats, boundary_inflow_ids, lod)
        _accumulate_id_stats(source_area_stats, source_area_ids, lod)
        _accumulate_id_stats(sink_stats, sink_ids, lod)
        boundary_cells = _boundary_cells_from_ids(boundary_inflow_arr)
        boundary_sink_cells = _boundary_cells_from_ids(sink_arr)
        source_area_cells = _cell_groups_from_ids(source_area_arr)
        sink_cells = _cell_groups_from_ids(sink_arr)
        boundary_inflow_id_path = ""
        source_area_id_path = ""
        source_id_path = ""
        sink_id_path = ""
        lod_dir = os.path.join(swe_cfg.out_dir, lod)
        if boundary_inflow_arr is not None:
            # The same raster is written twice: "source_ids" is the legacy name.
            source_dir = os.path.join(lod_dir, "source_ids")
            boundary_dir = os.path.join(lod_dir, "boundary_inflow_ids")
            ensure_dir(source_dir)
            ensure_dir(boundary_dir)
            source_id_path = os.path.join(source_dir, f"source_ids_{tile_x}_{tile_y}.exr")
            boundary_inflow_id_path = os.path.join(boundary_dir, f"boundary_inflow_ids_{tile_x}_{tile_y}.exr")
            _write_exr(source_id_path, boundary_inflow_arr.astype(np.float32, copy=False), swe_cfg.prefer_float16)
            _write_exr(boundary_inflow_id_path, boundary_inflow_arr.astype(np.float32, copy=False), swe_cfg.prefer_float16)
        if source_area_arr is not None:
            source_area_dir = os.path.join(lod_dir, "source_area_ids")
            ensure_dir(source_area_dir)
            source_area_id_path = os.path.join(source_area_dir, f"source_area_ids_{tile_x}_{tile_y}.exr")
            _write_exr(source_area_id_path, source_area_arr.astype(np.float32, copy=False), swe_cfg.prefer_float16)
        if sink_arr is not None:
            sink_dir = os.path.join(lod_dir, "sink_ids")
            ensure_dir(sink_dir)
            sink_id_path = os.path.join(sink_dir, f"sink_ids_{tile_x}_{tile_y}.exr")
            _write_exr(sink_id_path, sink_arr.astype(np.float32, copy=False), swe_cfg.prefer_float16)
        tiles_payload.append(
            {
                "lod": lod,
                "tile_x": tile_x,
                "tile_y": tile_y,
                "tile_size_m": float(row["tile_size_m"]),
                "resolution": resolution,
                "bounds": [bounds[0], bounds[1], bounds[2], bounds[3]],
                "source_ids": boundary_inflow_ids,
                "sink_ids": sink_ids,
                "boundary_inflow_ids": boundary_inflow_ids,
                "source_area_ids": source_area_ids,
                "source_id_path": source_id_path,
                "sink_id_path": sink_id_path,
                "boundary_inflow_id_path": boundary_inflow_id_path,
                "source_area_id_path": source_area_id_path,
                "boundary_cells": boundary_cells,
                "boundary_sink_cells": boundary_sink_cells,
                "source_area_cells": source_area_cells,
                "sink_cells": sink_cells,
            }
        )
    payload = {
        "schema_version": 2,
        "id_encoding": "id = B + 256*G + 65536*R",
        "boundary_inflow_mask_dir": boundary_inflow_mask_dir,
        "source_area_mask_dir": source_area_mask_dir,
        "sink_mask_dir": sink_mask_dir,
        "boundary_inflow_params_toml": swe_cfg.boundary_inflow_params_toml,
        "source_area_params_toml": swe_cfg.source_area_params_toml,
        "sink_params_toml": swe_cfg.sink_params_toml,
        # Legacy aliases kept for existing tooling.
        "source_mask_dir": boundary_inflow_mask_dir,
        "source_params_toml": swe_cfg.boundary_inflow_params_toml or swe_cfg.source_params_toml,
        "sources": _merge_stats_with_params(boundary_inflow_stats, boundary_inflow_params),
        "sinks": _merge_stats_with_params(sink_stats, sink_params),
        "boundaries": _merge_boundary_definitions(
            boundary_inflow_stats,
            boundary_inflow_params,
            source_area_stats,
            source_area_params,
            sink_stats,
            sink_params,
        ),
        "tiles": tiles_payload,
    }
    with open(swe_cfg.boundary_manifest_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    print(f"[swe_lods] Wrote boundary manifest: {swe_cfg.boundary_manifest_path}")
def _open_boundary_mask_dataset(
cfg: Config,
*,
mask_dir: str,
kind: str,
force_vrt: bool,
):
if not mask_dir or not os.path.isdir(mask_dir):
print(f"[swe_lods] No {kind} mask directory: {mask_dir}")
return None
sources = sorted(
[
os.path.join(mask_dir, name)
for name in os.listdir(mask_dir)
if name.lower().endswith(".png")
]
)
if not sources:
print(f"[swe_lods] No {kind} mask PNGs found in {mask_dir}")
return None
out_dir = os.path.join(cfg.work.work_dir, "swe_boundaries")
ensure_dir(out_dir)
vrt_path = os.path.join(out_dir, f"{kind}_ids.vrt")
build_vrt(vrt_path, sources, force=force_vrt or True)
return _open_optional_raster(vrt_path, f"{kind} IDs")
def _warp_id_array(
    ds,
    bounds: tuple[float, float, float, float],
    width: int,
    height: int,
) -> np.ndarray:
    """Nearest-neighbour warp of an id mask into a uint32 (height, width) array.

    RGB masks are decoded as id = B + 256*G + 65536*R; single-band masks are
    rounded and used directly. Raises RuntimeError when the warp fails.
    """
    # Keep id=0 as regular background; do not reserve it as NODATA.
    warp_opts = gdal.WarpOptions(
        format="MEM",
        outputBounds=bounds,
        width=width,
        height=height,
        resampleAlg="near",
    )
    warped = gdal.Warp("", ds, options=warp_opts)
    if warped is None:
        raise RuntimeError("GDAL Warp for ID mask returned None.")
    if warped.RasterCount >= 3:
        r = np.rint(warped.GetRasterBand(1).ReadAsArray()).astype(np.uint32)
        g = np.rint(warped.GetRasterBand(2).ReadAsArray()).astype(np.uint32)
        b = np.rint(warped.GetRasterBand(3).ReadAsArray()).astype(np.uint32)
        ids = b + (256 * g) + (65536 * r)
    else:
        ids = np.rint(warped.GetRasterBand(1).ReadAsArray()).astype(np.uint32)
    return ids
def _sample_boundary_ids(ds, bounds: tuple[float, float, float, float], resolution: int) -> list[dict]:
if ds is None:
return []
ids = _warp_id_array(ds, bounds, resolution, resolution)
return _ids_to_entries(ids)
def _ids_to_entries(ids: np.ndarray | None) -> list[dict]:
if ids is None:
return []
u, c = np.unique(ids, return_counts=True)
out = []
for ident, count in zip(u.tolist(), c.tolist()):
if ident == 0:
continue
out.append({"id": int(ident), "pixels": int(count)})
return out
def _accumulate_id_stats(stats: dict[int, dict], ids: list[dict], lod: str) -> None:
for entry in ids:
ident = int(entry["id"])
pixels = int(entry["pixels"])
node = stats.setdefault(
ident,
{
"id": ident,
"tile_count": 0,
"total_pixels": 0,
"lod_pixels": {},
},
)
node["tile_count"] += 1
node["total_pixels"] += pixels
node["lod_pixels"][lod] = int(node["lod_pixels"].get(lod, 0) + pixels)
def _boundary_cells_from_ids(ids: np.ndarray | None) -> list[dict]:
if ids is None:
return []
resolution = int(ids.shape[0])
ghost_resolution = resolution + 2
cells_per_id: dict[int, set[int]] = {}
for y in range(resolution):
for x in range(resolution):
ident = int(ids[y, x])
if ident <= 0:
continue
cell_set = cells_per_id.setdefault(ident, set())
if x == 0:
cell_set.add(((y + 1) * ghost_resolution) + 0)
if x == resolution - 1:
cell_set.add(((y + 1) * ghost_resolution) + (ghost_resolution - 1))
if y == 0:
cell_set.add((0 * ghost_resolution) + (x + 1))
if y == resolution - 1:
cell_set.add(((ghost_resolution - 1) * ghost_resolution) + (x + 1))
out = []
for ident in sorted(cells_per_id.keys()):
cells = sorted(cells_per_id[ident])
out.append(
{
"id": int(ident),
"count": len(cells),
"cells": [int(cell) for cell in cells],
}
)
return out
def _cell_groups_from_ids(ids: np.ndarray | None) -> list[dict]:
if ids is None:
return []
out = []
unique_ids = np.unique(ids)
for ident in unique_ids.tolist():
ident = int(ident)
if ident <= 0:
continue
flat = np.flatnonzero(ids == ident)
cells = [int(v) for v in flat.tolist()]
if not cells:
continue
out.append(
{
"id": ident,
"count": len(cells),
"cells": cells,
}
)
return out
def _load_boundary_params_toml(path: str, *, kind: str) -> dict[int, dict]:
    """Load per-id boundary parameters from the TOML section named *kind*.

    Falls back to the singular section name when *kind* ends in "s".
    Two layouts are accepted (see inline comment below); entries must carry
    a positive integer id. Returns {id: params-without-"id"}; an empty dict
    when the file or section is missing.
    """
    if not path or not os.path.exists(path):
        return {}
    with open(path, "rb") as fh:
        data = tomllib.load(fh)
    section = data.get(kind)
    if section is None and kind.endswith("s"):
        # Singular fallback, e.g. "sources" -> "source".
        section = data.get(kind[:-1])
    if section is None:
        return {}
    out: dict[int, dict] = {}
    if isinstance(section, list):
        # Array-of-tables layout: [[sources]] id = 1 ...
        for idx, item in enumerate(section):
            if not isinstance(item, dict):
                continue
            ident = _parse_int_id(item.get("id"))
            if ident is None or ident <= 0:
                continue
            payload = {k: v for k, v in item.items() if k != "id"}
            _set_boundary_param(out, ident, payload, kind=kind, path=path, entry_name=f"list[{idx}]")
    elif isinstance(section, dict):
        for key, item in section.items():
            if not isinstance(item, dict):
                continue
            # Backward compatible:
            # - numeric dict key style: [sources] 1 = {...}
            # - named subtables style: [sources.source0] id = 1
            ident = _parse_int_id(key)
            if ident is None:
                ident = _parse_int_id(item.get("id"))
            if ident is None or ident <= 0:
                continue
            payload = {k: v for k, v in item.items() if k != "id"}
            _set_boundary_param(out, ident, payload, kind=kind, path=path, entry_name=str(key))
    return out
def _load_boundary_params_multi(path: str, *, kinds: tuple[str, ...], merge: bool = False) -> dict[int, dict]:
if not path:
return {}
if not merge:
for kind in kinds:
out = _load_boundary_params_toml(path, kind=kind)
if out:
return out
return {}
out: dict[int, dict] = {}
merged_kind = "|".join(kinds)
for kind in kinds:
section = _load_boundary_params_toml(path, kind=kind)
for ident, payload in section.items():
_set_boundary_param(
out,
ident,
payload,
kind=merged_kind,
path=path,
entry_name=f"{kind}.{ident}",
)
return out
def _parse_int_id(value) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None
def _set_boundary_param(
out: dict[int, dict],
ident: int,
payload: dict,
*,
kind: str,
path: str,
entry_name: str,
) -> None:
if ident in out:
print(
f"[swe_lods] Warning: duplicate {kind} params for id={ident} in {path} "
f"(entry '{entry_name}'); overriding previous."
)
out[ident] = payload
def _merge_stats_with_params(stats: dict[int, dict], params: dict[int, dict]) -> list[dict]:
all_ids = sorted(set(stats.keys()) | set(params.keys()))
out = []
for ident in all_ids:
node = {
"id": int(ident),
"tile_count": 0,
"total_pixels": 0,
"lod_pixels": {},
"params": {},
}
if ident in stats:
node["tile_count"] = int(stats[ident]["tile_count"])
node["total_pixels"] = int(stats[ident]["total_pixels"])
node["lod_pixels"] = dict(stats[ident]["lod_pixels"])
if ident in params:
node["params"] = dict(params[ident])
out.append(node)
return out
def _merge_boundary_definitions(
    boundary_inflow_stats: dict[int, dict],
    boundary_inflow_params: dict[int, dict],
    source_area_stats: dict[int, dict],
    source_area_params: dict[int, dict],
    sink_stats: dict[int, dict],
    sink_params: dict[int, dict],
) -> list[dict]:
    """Concatenate definitions for all three boundary kinds in a fixed order."""
    sections = (
        ("boundary_inflow", boundary_inflow_stats, boundary_inflow_params),
        ("source_area", source_area_stats, source_area_params),
        ("sink", sink_stats, sink_params),
    )
    definitions: list[dict] = []
    for kind, stats, params in sections:
        definitions.extend(_build_boundary_definitions(kind, stats, params))
    return definitions
def _build_boundary_definitions(kind: str, stats: dict[int, dict], params: dict[int, dict]) -> list[dict]:
    """Definition rows for one boundary *kind*, joining stats, params, and state."""
    definitions = []
    for ident in sorted(set(stats) | set(params)):
        stat = stats.get(ident)
        raw_params = params.get(ident, {})
        definitions.append(
            {
                "kind": kind,
                "id": int(ident),
                "tile_count": int(stat["tile_count"]) if stat else 0,
                "total_pixels": int(stat["total_pixels"]) if stat else 0,
                "lod_pixels": dict(stat["lod_pixels"]) if stat else {},
                "params": dict(raw_params) if ident in params else {},
                "default_state": _default_boundary_state(kind, raw_params),
            }
        )
    return definitions
def _default_boundary_state(kind: str, params: dict) -> dict:
    """Initial runtime state for a boundary id, derived from its TOML params.

    Sinks in "free_outflow" mode default to enabled; sink depth rates are
    normalized to be negative (water removal). Each numeric field falls back
    to a legacy parameter name, then to 0.0.
    """
    mode = str(params.get("mode", "")).strip().lower()
    enabled_default = kind == "sink" and mode == "free_outflow"

    def _value(primary: str, legacy: str) -> float:
        # Prefer the new key, fall back to the legacy key, then 0.0.
        return _parse_float_or_default(
            params.get(primary),
            _parse_float_or_default(params.get(legacy), 0.0),
        )

    depth_rate = _value("depth_rate_mps", "base_depth_rate_mps")
    if kind == "sink" and depth_rate > 0.0:
        depth_rate = -depth_rate
    return {
        "enabled": _parse_bool(params.get("enabled"), default=enabled_default),
        "water_level_m": float(_value("water_level_m", "base_level_offset_m")),
        "velocity_u_mps": float(_value("velocity_u_mps", "u_mps")),
        "velocity_v_mps": float(_value("velocity_v_mps", "v_mps")),
        "depth_rate_mps": float(depth_rate),
    }
def _parse_bool(value, *, default: bool) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
raw = value.strip().lower()
if raw in {"1", "true", "yes", "on"}:
return True
if raw in {"0", "false", "no", "off"}:
return False
return default
def _parse_float_or_default(value, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return float(default)
def _cleanup_patterns(raw_dir: str) -> Iterable[str]:
return [
os.path.join("work", "*_tmp.tif"),
os.path.join("work", "*_tmp.tif.aux.xml"),
os.path.join("work", "*.aux.xml"),
os.path.join(raw_dir, "*.aux.xml"),
]