"""Tests for the water-mask lookup helpers in ``geodata_pipeline.trees``."""

from pathlib import Path

import numpy as np
from osgeo import gdal, osr

from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir

# Shared north-up geotransform for every synthetic raster in this module:
# origin at (0.5, 4.5), 1x1 pixels in CRS units, rows running downwards.
_GEOTRANSFORM = (0.5, 1.0, 0.0, 4.5, 0.0, -1.0)


def _srs_wkt() -> str:
    """Return the WKT of EPSG:25832, the CRS assigned to all test rasters."""
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(25832)
    return srs.ExportToWkt()


def _write_mask_png(path: Path, block, width: int = 4, height: int = 4) -> None:
    """Write a georeferenced 3-band PNG with 255 inside ``block`` and 0 elsewhere.

    Args:
        path: Destination ``.png`` path. The PNG is created with the
            ``WORLDFILE=YES`` creation option.
        block: NumPy index expression (e.g. ``np.s_[1:3, 1:3]``) selecting the
            pixels set to 255 in every band.
        width: Raster width in pixels.
        height: Raster height in pixels.
    """
    src = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte)
    assert src is not None, "failed to create in-memory source raster"
    src.SetGeoTransform(_GEOTRANSFORM)
    src.SetProjection(_srs_wkt())

    # Every band carries the same pattern, so build it once.
    band = np.zeros((height, width), dtype=np.uint8)
    band[block] = 255
    for idx in range(1, 4):
        src.GetRasterBand(idx).WriteArray(band)

    dst = gdal.GetDriverByName("PNG").CreateCopy(
        str(path), src, options=["WORLDFILE=YES"]
    )
    # Without gdal.UseExceptions(), GDAL reports failure by returning None;
    # fail here rather than letting a later lookup silently find no mask.
    assert dst is not None, f"failed to write mask PNG {path}"
    # Drop the last reference so the dataset is closed and flushed to disk
    # before the caller reads the file back.
    del dst


def _make_like_dataset(width: int = 4, height: int = 4) -> gdal.Dataset:
    """Return an in-memory 1-band template raster aligned with the mask PNGs."""
    like = gdal.GetDriverByName("MEM").Create("", width, height, 1, gdal.GDT_Byte)
    assert like is not None, "failed to create in-memory template raster"
    like.SetGeoTransform(_GEOTRANSFORM)
    like.SetProjection(_srs_wkt())
    return like


def test_water_mask_candidates_default_exclude_visualization_files() -> None:
    """With ``prefer_viz=False`` no ``*_viz`` file name is proposed."""
    candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False)
    # Guard against a vacuous pass: all() is True for an empty list.
    assert candidates
    assert all(not name.endswith("_viz.png") for name in candidates)
    assert all("_mask_viz" not in name for name in candidates)


def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None:
    """``prefer_viz=True`` lists viz names first, for the tile id then its base id."""
    candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=True)
    assert candidates[0] == "dgm01_32_324_5506_1_rp_mask_viz.png"
    assert candidates[1] == "dgm01_32_324_5506_rp_mask_viz.png"
    # The plain masks must still be present as later fallbacks.
    assert "dgm01_32_324_5506_1_rp_mask.png" in candidates
    assert "dgm01_32_324_5506_rp_mask.png" in candidates


def test_water_mask_from_dir_preserves_zero_background(tmp_path: Path) -> None:
    """Only the 255-valued 2x2 block is reported; zero pixels stay background."""
    tile_id = "dgm01_32_000_0000_1_rp"
    _write_mask_png(tmp_path / f"{tile_id}_mask.png", np.s_[1:3, 1:3])
    like = _make_like_dataset()

    mask = _water_mask_from_dir(tile_id, like, str(tmp_path))

    assert mask is not None
    assert np.count_nonzero(mask > 0) == 4


def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path: Path) -> None:
    """``prefer_viz`` selects which of two differing mask files is loaded.

    The viz mask (stored under the base-id fallback name) marks 1 pixel and the
    plain mask marks 4, so the pixel count identifies the file that was used.
    """
    tile_id = "dgm01_32_000_0000_1_rp"
    _write_mask_png(tmp_path / "dgm01_32_000_0000_rp_mask_viz.png", np.s_[0:1, 0:1])
    _write_mask_png(tmp_path / f"{tile_id}_mask.png", np.s_[1:3, 1:3])
    like = _make_like_dataset()

    viz_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=True)
    plain_mask = _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=False)

    assert viz_mask is not None
    assert plain_mask is not None
    assert np.count_nonzero(viz_mask > 0) == 1
    assert np.count_nonzero(plain_mask > 0) == 4