Files
GeoData/tests/test_trees_water_mask.py

71 lines
2.7 KiB
Python

import numpy as np
from osgeo import gdal, osr
from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir
def _srs_wkt() -> str:
    """Return the WKT definition of EPSG:25832 (ETRS89 / UTM zone 32N).

    Used as the projection of every synthetic raster in these tests, so the
    mask PNGs and the reference ("like") datasets share one CRS.
    """
    reference = osr.SpatialReference()
    reference.ImportFromEPSG(25832)
    wkt = reference.ExportToWkt()
    return wkt
def _write_mask_png(
    path,
    block,
    width=4,
    height=4,
    geotransform=(0.5, 1.0, 0.0, 4.5, 0.0, -1.0),
) -> None:
    """Write a 3-band 8-bit PNG mask (plus world file) for use as a test fixture.

    Every band holds 255 on the pixels selected by *block* and 0 elsewhere.

    Args:
        path: Destination file path; converted with ``str()`` so a
            ``pathlib.Path`` (e.g. from pytest's ``tmp_path``) works.
        block: NumPy index expression such as ``np.s_[1:3, 1:3]`` selecting
            the pixels set to 255 in a ``(height, width)`` array.
        width: Raster width in pixels.
        height: Raster height in pixels.
        geotransform: GDAL geotransform of the raster. The default places the
            top-left corner at (0.5, 4.5) with 1-unit, north-up pixels.

    Raises:
        RuntimeError: If the GDAL PNG driver fails to create the file (only
            reachable when GDAL exceptions are disabled; otherwise GDAL raises
            its own error).
    """
    mem = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte)
    mem.SetGeoTransform(geotransform)
    mem.SetProjection(_srs_wkt())

    # Every band gets the same content, so build the array once.
    band_data = np.zeros((height, width), dtype=np.uint8)
    band_data[block] = 255
    for idx in range(1, 4):
        mem.GetRasterBand(idx).WriteArray(band_data)

    # WORLDFILE=YES writes the georeferencing sidecar next to the PNG.
    out = gdal.GetDriverByName("PNG").CreateCopy(
        str(path), mem, options=["WORLDFILE=YES"]
    )
    if out is None:
        raise RuntimeError(f"GDAL failed to write mask PNG {path}")
    # Drop the only reference explicitly so GDAL closes the file before the
    # caller tries to read it back.
    del out
def test_water_mask_candidates_default_exclude_visualization_files() -> None:
    """With ``prefer_viz=False`` no ``*_mask_viz`` / ``*_viz.png`` name is proposed."""
    candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False)
    # all() over an empty sequence is True; without this guard a regression
    # that returns no candidates at all would make the checks below pass.
    assert candidates, "expected at least one mask candidate"
    assert all(not name.endswith("_viz.png") for name in candidates)
    assert all("_mask_viz" not in name for name in candidates)
def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None:
    """``prefer_viz=True`` lists viz masks first, with the base-id variant second.

    The plain ``*_mask.png`` names for both the full tile id and the base id
    (tile id without the ``_1`` segment) must still appear somewhere in the list.
    """
    candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=True)

    first, second = candidates[0], candidates[1]
    assert first == "dgm01_32_324_5506_1_rp_mask_viz.png"
    assert second == "dgm01_32_324_5506_rp_mask_viz.png"

    for expected in (
        "dgm01_32_324_5506_1_rp_mask.png",
        "dgm01_32_324_5506_rp_mask.png",
    ):
        assert expected in candidates
def test_water_mask_from_dir_preserves_zero_background(tmp_path) -> None:
    """A 2x2 water block in a 4x4 mask loads as exactly four positive pixels."""
    tile_id = "dgm01_32_000_0000_1_rp"
    _write_mask_png(tmp_path / f"{tile_id}_mask.png", np.s_[1:3, 1:3])

    # Reference grid with the same size, geotransform and CRS as the mask.
    reference = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
    reference.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
    reference.SetProjection(_srs_wkt())

    loaded = _water_mask_from_dir(tile_id, reference, str(tmp_path))
    assert loaded is not None
    water_pixels = int(np.sum(loaded > 0))
    assert water_pixels == 4
def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path) -> None:
    """``prefer_viz`` chooses between the viz mask and the plain mask on disk."""
    tile_id = "dgm01_32_000_0000_1_rp"
    # Viz mask stored under the base id (no "_1" segment): a single pixel.
    _write_mask_png(tmp_path / "dgm01_32_000_0000_rp_mask_viz.png", np.s_[0:1, 0:1])
    # Plain mask stored under the full tile id: a 2x2 block (4 pixels).
    _write_mask_png(tmp_path / f"{tile_id}_mask.png", np.s_[1:3, 1:3])

    # Reference grid with the same size, geotransform and CRS as the masks.
    reference = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
    reference.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
    reference.SetProjection(_srs_wkt())

    mask_dir = str(tmp_path)
    viz_mask = _water_mask_from_dir(tile_id, reference, mask_dir, prefer_viz=True)
    plain_mask = _water_mask_from_dir(tile_id, reference, mask_dir, prefer_viz=False)

    assert viz_mask is not None
    assert plain_mask is not None
    assert int(np.sum(viz_mask > 0)) == 1
    assert int(np.sum(plain_mask > 0)) == 4