Prefer manual _mask_viz for tree water masks
This commit is contained in:
@@ -143,33 +143,38 @@ def _bridge_mask_from_chm(
|
||||
return bridge.astype(np.uint8) if np.any(bridge) else None
|
||||
|
||||
|
||||
def _water_mask_candidates(tile_id: str) -> List[str]:
|
||||
def _water_mask_candidates(tile_id: str, prefer_viz: bool = False) -> List[str]:
|
||||
base_id = tile_id.replace("_1_rp", "_rp")
|
||||
raw = [
|
||||
f"{tile_id}.png",
|
||||
f"{tile_id}_mask.png",
|
||||
f"{tile_id}_mask_viz.png",
|
||||
f"{tile_id}_viz.png",
|
||||
]
|
||||
tile_ids = [tile_id]
|
||||
if base_id != tile_id:
|
||||
tile_ids.append(base_id)
|
||||
|
||||
raw: List[str] = []
|
||||
if prefer_viz:
|
||||
raw.extend(f"{name}_mask_viz.png" for name in tile_ids)
|
||||
for name in tile_ids:
|
||||
raw.extend(
|
||||
[
|
||||
f"{base_id}.png",
|
||||
f"{base_id}_mask.png",
|
||||
f"{base_id}_mask_viz.png",
|
||||
f"{base_id}_viz.png",
|
||||
f"{name}.png",
|
||||
f"{name}_mask.png",
|
||||
]
|
||||
)
|
||||
# Preserve order while removing accidental duplicates.
|
||||
return list(dict.fromkeys(raw))
|
||||
|
||||
|
||||
def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) -> np.ndarray | None:
|
||||
def _water_mask_from_dir(
|
||||
tile_id: str,
|
||||
like_ds: gdal.Dataset,
|
||||
search_dir: str,
|
||||
*,
|
||||
prefer_viz: bool = False,
|
||||
) -> np.ndarray | None:
|
||||
if not os.path.isdir(search_dir):
|
||||
return None
|
||||
|
||||
mask_path = None
|
||||
for candidate in _water_mask_candidates(tile_id):
|
||||
for candidate in _water_mask_candidates(tile_id, prefer_viz=prefer_viz):
|
||||
candidate_path = os.path.join(search_dir, candidate)
|
||||
if os.path.exists(candidate_path):
|
||||
mask_path = candidate_path
|
||||
@@ -197,18 +202,23 @@ def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) -
|
||||
"width": width,
|
||||
"height": height,
|
||||
"resampleAlg": "near",
|
||||
"dstNodata": 0,
|
||||
# Keep literal 0-values in binary masks; dstNodata=0 turns source zeros into ones.
|
||||
}
|
||||
if like_proj:
|
||||
warp_kwargs["dstSRS"] = like_proj
|
||||
if not src_proj:
|
||||
warp_kwargs["srcSRS"] = like_proj
|
||||
|
||||
warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs))
|
||||
try:
|
||||
warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs))
|
||||
except RuntimeError as exc:
|
||||
print(f"[trees] warning: failed to warp water mask '{mask_path}': {exc}")
|
||||
return None
|
||||
if warped is None or warped.RasterCount == 0:
|
||||
return None
|
||||
|
||||
if warped.RasterCount >= 3:
|
||||
# Manual water masks encode water in the blue channel.
|
||||
blue = warped.GetRasterBand(3).ReadAsArray()
|
||||
mask = (blue > 0).astype(np.uint8)
|
||||
else:
|
||||
@@ -225,8 +235,11 @@ def _water_mask(
|
||||
) -> Tuple[np.ndarray | None, str]:
|
||||
# Preferred source order:
|
||||
# 1) curated raw masks, 2) generated river masks, 3) LiDAR classification fallback.
|
||||
for search_dir, label in (("raw/water_masks", "raw"), ("work/river_masks", "river")):
|
||||
mask = _water_mask_from_dir(tile_id, like_ds, search_dir)
|
||||
for search_dir, label, prefer_viz in (
|
||||
("raw/water_masks", "raw", True),
|
||||
("work/river_masks", "river", False),
|
||||
):
|
||||
mask = _water_mask_from_dir(tile_id, like_ds, search_dir, prefer_viz=prefer_viz)
|
||||
if mask is not None:
|
||||
dilate_px = max(1, int(round(1.5 / max(cfg.trees.grid_res_m, 0.1))))
|
||||
mask = _dilate_mask(mask, dilate_px)
|
||||
|
||||
70
tests/test_trees_water_mask.py
Normal file
70
tests/test_trees_water_mask.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
from osgeo import gdal, osr
|
||||
|
||||
from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir
|
||||
|
||||
|
||||
def _srs_wkt() -> str:
|
||||
srs = osr.SpatialReference()
|
||||
srs.ImportFromEPSG(25832)
|
||||
return srs.ExportToWkt()
|
||||
|
||||
|
||||
def _write_mask_png(path, block, width=4, height=4) -> None:
|
||||
src = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte)
|
||||
src.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
|
||||
src.SetProjection(_srs_wkt())
|
||||
for idx in range(1, 4):
|
||||
band = np.zeros((height, width), dtype=np.uint8)
|
||||
band[block] = 255
|
||||
src.GetRasterBand(idx).WriteArray(band)
|
||||
gdal.GetDriverByName("PNG").CreateCopy(str(path), src, options=["WORLDFILE=YES"])
|
||||
|
||||
|
||||
def test_water_mask_candidates_default_exclude_visualization_files() -> None:
|
||||
candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False)
|
||||
assert all(not name.endswith("_viz.png") for name in candidates)
|
||||
assert all("_mask_viz" not in name for name in candidates)
|
||||
|
||||
|
||||
def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None:
|
||||
candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=True)
|
||||
assert candidates[0] == "dgm01_32_324_5506_1_rp_mask_viz.png"
|
||||
assert candidates[1] == "dgm01_32_324_5506_rp_mask_viz.png"
|
||||
assert "dgm01_32_324_5506_1_rp_mask.png" in candidates
|
||||
assert "dgm01_32_324_5506_rp_mask.png" in candidates
|
||||
|
||||
|
||||
def test_water_mask_from_dir_preserves_zero_background(tmp_path) -> None:
|
||||
tile_id = "dgm01_32_000_0000_1_rp"
|
||||
mask_path = tmp_path / f"{tile_id}_mask.png"
|
||||
|
||||
_write_mask_png(mask_path, np.s_[1:3, 1:3])
|
||||
|
||||
like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
|
||||
like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
|
||||
like.SetProjection(_srs_wkt())
|
||||
|
||||
mask = _water_mask_from_dir(tile_id, like, str(tmp_path))
|
||||
assert mask is not None
|
||||
assert int(np.sum(mask > 0)) == 4
|
||||
|
||||
|
||||
def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path) -> None:
|
||||
tile_id = "dgm01_32_000_0000_1_rp"
|
||||
viz_path = tmp_path / "dgm01_32_000_0000_rp_mask_viz.png"
|
||||
mask_path = tmp_path / f"{tile_id}_mask.png"
|
||||
|
||||
_write_mask_png(viz_path, np.s_[0:1, 0:1])
|
||||
_write_mask_png(mask_path, np.s_[1:3, 1:3])
|
||||
|
||||
like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
|
||||
like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
|
||||
like.SetProjection(_srs_wkt())
|
||||
|
||||
viz_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=True)
|
||||
plain_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=False)
|
||||
assert viz_mask is not None
|
||||
assert plain_mask is not None
|
||||
assert int(np.sum(viz_mask > 0)) == 1
|
||||
assert int(np.sum(plain_mask > 0)) == 4
|
||||
Reference in New Issue
Block a user