diff --git a/geodata_pipeline/trees.py b/geodata_pipeline/trees.py index 0494b4a..ca9b6f5 100644 --- a/geodata_pipeline/trees.py +++ b/geodata_pipeline/trees.py @@ -143,33 +143,38 @@ def _bridge_mask_from_chm( return bridge.astype(np.uint8) if np.any(bridge) else None -def _water_mask_candidates(tile_id: str) -> List[str]: +def _water_mask_candidates(tile_id: str, prefer_viz: bool = False) -> List[str]: base_id = tile_id.replace("_1_rp", "_rp") - raw = [ - f"{tile_id}.png", - f"{tile_id}_mask.png", - f"{tile_id}_mask_viz.png", - f"{tile_id}_viz.png", - ] + tile_ids = [tile_id] if base_id != tile_id: + tile_ids.append(base_id) + + raw: List[str] = [] + if prefer_viz: + raw.extend(f"{name}_mask_viz.png" for name in tile_ids) + for name in tile_ids: raw.extend( [ - f"{base_id}.png", - f"{base_id}_mask.png", - f"{base_id}_mask_viz.png", - f"{base_id}_viz.png", + f"{name}.png", + f"{name}_mask.png", ] ) # Preserve order while removing accidental duplicates. return list(dict.fromkeys(raw)) -def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) -> np.ndarray | None: +def _water_mask_from_dir( + tile_id: str, + like_ds: gdal.Dataset, + search_dir: str, + *, + prefer_viz: bool = False, +) -> np.ndarray | None: if not os.path.isdir(search_dir): return None mask_path = None - for candidate in _water_mask_candidates(tile_id): + for candidate in _water_mask_candidates(tile_id, prefer_viz=prefer_viz): candidate_path = os.path.join(search_dir, candidate) if os.path.exists(candidate_path): mask_path = candidate_path @@ -197,18 +202,23 @@ def _water_mask_from_dir(tile_id: str, like_ds: gdal.Dataset, search_dir: str) - "width": width, "height": height, "resampleAlg": "near", - "dstNodata": 0, + # Keep literal 0-values in binary masks; dstNodata=0 turns source zeros into ones. } if like_proj: warp_kwargs["dstSRS"] = like_proj if not src_proj: warp_kwargs["srcSRS"] = like_proj - warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs)) + try: + warped = gdal.Warp("", src_ds, options=gdal.WarpOptions(**warp_kwargs)) + except RuntimeError as exc: + print(f"[trees] warning: failed to warp water mask '{mask_path}': {exc}") + return None if warped is None or warped.RasterCount == 0: return None if warped.RasterCount >= 3: + # Manual water masks encode water in the blue channel. blue = warped.GetRasterBand(3).ReadAsArray() mask = (blue > 0).astype(np.uint8) else: @@ -225,8 +235,11 @@ def _water_mask( ) -> Tuple[np.ndarray | None, str]: # Preferred source order: # 1) curated raw masks, 2) generated river masks, 3) LiDAR classification fallback. - for search_dir, label in (("raw/water_masks", "raw"), ("work/river_masks", "river")): - mask = _water_mask_from_dir(tile_id, like_ds, search_dir) + for search_dir, label, prefer_viz in ( + ("raw/water_masks", "raw", True), + ("work/river_masks", "river", False), + ): + mask = _water_mask_from_dir(tile_id, like_ds, search_dir, prefer_viz=prefer_viz) if mask is not None: dilate_px = max(1, int(round(1.5 / max(cfg.trees.grid_res_m, 0.1)))) mask = _dilate_mask(mask, dilate_px) diff --git a/tests/test_trees_water_mask.py b/tests/test_trees_water_mask.py new file mode 100644 index 0000000..35af17d --- /dev/null +++ b/tests/test_trees_water_mask.py @@ -0,0 +1,70 @@ +import numpy as np +from osgeo import gdal, osr + +from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir + + +def _srs_wkt() -> str: + srs = osr.SpatialReference() + srs.ImportFromEPSG(25832) + return srs.ExportToWkt() + + +def _write_mask_png(path, block, width=4, height=4) -> None: + src = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte) + src.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0)) + src.SetProjection(_srs_wkt()) + for idx in range(1, 4): + band = np.zeros((height, width), dtype=np.uint8) + band[block] = 255 + src.GetRasterBand(idx).WriteArray(band) + gdal.GetDriverByName("PNG").CreateCopy(str(path), src, options=["WORLDFILE=YES"]) + + +def test_water_mask_candidates_default_exclude_visualization_files() -> None: + candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False) + assert all(not name.endswith("_viz.png") for name in candidates) + assert all("_mask_viz" not in name for name in candidates) + + +def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None: + candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=True) + assert candidates[0] == "dgm01_32_324_5506_1_rp_mask_viz.png" + assert candidates[1] == "dgm01_32_324_5506_rp_mask_viz.png" + assert "dgm01_32_324_5506_1_rp_mask.png" in candidates + assert "dgm01_32_324_5506_rp_mask.png" in candidates + + +def test_water_mask_from_dir_preserves_zero_background(tmp_path) -> None: + tile_id = "dgm01_32_000_0000_1_rp" + mask_path = tmp_path / f"{tile_id}_mask.png" + + _write_mask_png(mask_path, np.s_[1:3, 1:3]) + + like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte) + like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0)) + like.SetProjection(_srs_wkt()) + + mask = _water_mask_from_dir(tile_id, like, str(tmp_path)) + assert mask is not None + assert int(np.sum(mask > 0)) == 4 + + +def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path) -> None: + tile_id = "dgm01_32_000_0000_1_rp" + viz_path = tmp_path / "dgm01_32_000_0000_rp_mask_viz.png" + mask_path = tmp_path / f"{tile_id}_mask.png" + + _write_mask_png(viz_path, np.s_[0:1, 0:1]) + _write_mask_png(mask_path, np.s_[1:3, 1:3]) + + like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte) + like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0)) + like.SetProjection(_srs_wkt()) + + viz_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=True) + plain_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=False) + assert viz_mask is not None + assert plain_mask is not None + assert int(np.sum(viz_mask > 0)) == 1 + assert int(np.sum(plain_mask > 0)) == 4