Prefer manual _mask_viz for tree water masks

This commit is contained in:
2026-02-11 13:43:10 +01:00
parent 227b89cccb
commit 3422979ebf
2 changed files with 100 additions and 17 deletions

View File

@@ -143,33 +143,38 @@ def _bridge_mask_from_chm(
return bridge.astype(np.uint8) if np.any(bridge) else None
def _water_mask_candidates(tile_id: str) -> List[str]:
def _water_mask_candidates(tile_id: str, prefer_viz: bool = False) -> List[str]:
base_id = tile_id.replace("_1_rp", "_rp")
raw = [
f"{tile_id}.png",
f"{tile_id}_mask.png",
f"{tile_id}_mask_viz.png",
f"{tile_id}_viz.png",
]
tile_ids = [tile_id]
if base_id != tile_id:
tile_ids.append(base_id)
raw: List[str] = []
if prefer_viz:
raw.extend(f"{name}_mask_viz.png" for name in tile_ids)
for name in tile_ids:
raw.extend(
[
f"{base_id}.png",
f"{base_id}_mask.png",
f"{base_id}_mask_viz.png",
f"{base_id}_viz.png",
f"{name}.png",
f"{name}_mask.png",
]
)
# Preserve order while removing accidental duplicates.
return list(dict.fromkeys(raw))
def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) -> np.ndarray | None:
def _water_mask_from_dir(
tile_id: str,
like_ds: gdal.Dataset,
search_dir: str,
*,
prefer_viz: bool = False,
) -> np.ndarray | None:
if not os.path.isdir(search_dir):
return None
mask_path = None
for candidate in _water_mask_candidates(tile_id):
for candidate in _water_mask_candidates(tile_id, prefer_viz=prefer_viz):
candidate_path = os.path.join(search_dir, candidate)
if os.path.exists(candidate_path):
mask_path = candidate_path
@@ -197,18 +202,23 @@ def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) -
"width": width,
"height": height,
"resampleAlg": "near",
"dstNodata": 0,
# Keep literal 0-values in binary masks; dstNodata=0 turns source zeros into ones.
}
if like_proj:
warp_kwargs["dstSRS"] = like_proj
if not src_proj:
warp_kwargs["srcSRS"] = like_proj
warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs))
try:
warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs))
except RuntimeError as exc:
print(f"[trees] warning: failed to warp water mask '{mask_path}': {exc}")
return None
if warped is None or warped.RasterCount == 0:
return None
if warped.RasterCount >= 3:
# Manual water masks encode water in the blue channel.
blue = warped.GetRasterBand(3).ReadAsArray()
mask = (blue > 0).astype(np.uint8)
else:
@@ -225,8 +235,11 @@ def _water_mask(
) -> Tuple[np.ndarray | None, str]:
# Preferred source order:
# 1) curated raw masks, 2) generated river masks, 3) LiDAR classification fallback.
for search_dir, label in (("raw/water_masks", "raw"), ("work/river_masks", "river")):
mask = _water_mask_from_dir(tile_id, like_ds, search_dir)
for search_dir, label, prefer_viz in (
("raw/water_masks", "raw", True),
("work/river_masks", "river", False),
):
mask = _water_mask_from_dir(tile_id, like_ds, search_dir, prefer_viz=prefer_viz)
if mask is not None:
dilate_px = max(1, int(round(1.5 / max(cfg.trees.grid_res_m, 0.1))))
mask = _dilate_mask(mask, dilate_px)

View File

@@ -0,0 +1,70 @@
import numpy as np
from osgeo import gdal, osr
from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir
def _srs_wkt() -> str:
srs = osr.SpatialReference()
srs.ImportFromEPSG(25832)
return srs.ExportToWkt()
def _write_mask_png(path, block, width=4, height=4) -> None:
src = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte)
src.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
src.SetProjection(_srs_wkt())
for idx in range(1, 4):
band = np.zeros((height, width), dtype=np.uint8)
band[block] = 255
src.GetRasterBand(idx).WriteArray(band)
gdal.GetDriverByName("PNG").CreateCopy(str(path), src, options=["WORLDFILE=YES"])
def test_water_mask_candidates_default_exclude_visualization_files() -> None:
candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False)
assert all(not name.endswith("_viz.png") for name in candidates)
assert all("_mask_viz" not in name for name in candidates)
def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None:
candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=True)
assert candidates[0] == "dgm01_32_324_5506_1_rp_mask_viz.png"
assert candidates[1] == "dgm01_32_324_5506_rp_mask_viz.png"
assert "dgm01_32_324_5506_1_rp_mask.png" in candidates
assert "dgm01_32_324_5506_rp_mask.png" in candidates
def test_water_mask_from_dir_preserves_zero_background(tmp_path) -> None:
tile_id = "dgm01_32_000_0000_1_rp"
mask_path = tmp_path / f"{tile_id}_mask.png"
_write_mask_png(mask_path, np.s_[1:3, 1:3])
like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
like.SetProjection(_srs_wkt())
mask = _water_mask_from_dir(tile_id, like, str(tmp_path))
assert mask is not None
assert int(np.sum(mask > 0)) == 4
def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path) -> None:
tile_id = "dgm01_32_000_0000_1_rp"
viz_path = tmp_path / "dgm01_32_000_0000_rp_mask_viz.png"
mask_path = tmp_path / f"{tile_id}_mask.png"
_write_mask_png(viz_path, np.s_[0:1, 0:1])
_write_mask_png(mask_path, np.s_[1:3, 1:3])
like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
like.SetProjection(_srs_wkt())
viz_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=True)
plain_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=False)
assert viz_mask is not None
assert plain_mask is not None
assert int(np.sum(viz_mask > 0)) == 1
assert int(np.sum(plain_mask > 0)) == 4