71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
import numpy as np
|
|
from osgeo import gdal, osr
|
|
|
|
from geodata_pipeline.trees import _water_mask_candidates, _water_mask_from_dir
|
|
|
|
|
|
def _srs_wkt(epsg: int = 25832) -> str:
    """Return the WKT definition of the spatial reference ``EPSG:<epsg>``.

    The default, EPSG:25832, is the CRS every test raster in this module is
    georeferenced with.

    Args:
        epsg: EPSG code to resolve.

    Returns:
        The SRS serialized as WKT.

    Raises:
        RuntimeError: if the EPSG code cannot be imported. With GDAL exceptions
            enabled, GDAL raises its own error before this check is reached.
    """
    srs = osr.SpatialReference()
    # Without osr.UseExceptions(), ImportFromEPSG signals failure through a
    # nonzero OGRErr return value instead of raising; ignoring it would let a
    # broken SRS flow silently into the test rasters.
    err = srs.ImportFromEPSG(epsg)
    if err != 0:
        raise RuntimeError(f"could not import EPSG:{epsg} (OGRErr {err})")
    return srs.ExportToWkt()
|
|
|
|
|
|
def _write_mask_png(path, block, width=4, height=4) -> None:
    """Write a small georeferenced 3-band PNG mask for the water-mask tests.

    Pixels selected by ``block`` are 255 in all three bands; every other pixel
    is 0. The raster uses the geotransform (0.5, 1.0, 0.0, 4.5, 0.0, -1.0),
    i.e. top-left corner at (0.5, 4.5) with 1-unit pixels and a negative
    y step. Its projection comes from ``_srs_wkt()``. ``WORLDFILE=YES`` asks
    the PNG driver to write a world file next to the image.

    Args:
        path: Destination PNG path (anything ``str()`` turns into a file path).
        block: NumPy index expression (e.g. ``np.s_[1:3, 1:3]``) selecting the
            mask pixels within a ``(height, width)`` array.
        width: Raster width in pixels.
        height: Raster height in pixels.

    Raises:
        RuntimeError: if GDAL reports that the PNG could not be written.
    """
    # The pattern is identical for every band, so build it once.
    band = np.zeros((height, width), dtype=np.uint8)
    band[block] = 255

    # The PNG is produced via CreateCopy from an in-memory source raster.
    src = gdal.GetDriverByName("MEM").Create("", width, height, 3, gdal.GDT_Byte)
    src.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
    src.SetProjection(_srs_wkt())
    for idx in range(1, src.RasterCount + 1):
        src.GetRasterBand(idx).WriteArray(band)

    dst = gdal.GetDriverByName("PNG").CreateCopy(str(path), src, options=["WORLDFILE=YES"])
    # With GDAL exceptions disabled a failed copy returns None instead of
    # raising; fail here rather than in the test that later reads the file.
    if dst is None:
        raise RuntimeError(f"GDAL failed to write mask PNG: {path}")
|
|
|
|
|
|
def test_water_mask_candidates_default_exclude_visualization_files() -> None:
    """With ``prefer_viz=False`` no visualization mask file names are offered."""
    candidates = _water_mask_candidates("dgm01_32_324_5506_1_rp", prefer_viz=False)

    # all()/not-any over an empty sequence is trivially satisfied, so without
    # this guard the exclusion check would pass vacuously if nothing came back.
    assert candidates, "expected at least one non-visualization candidate"

    viz_names = [
        name
        for name in candidates
        if name.endswith("_viz.png") or "_mask_viz" in name
    ]
    assert not viz_names, f"visualization masks offered: {viz_names}"
|
|
|
|
|
|
def test_water_mask_candidates_prefer_viz_and_base_id_fallback() -> None:
    """``prefer_viz=True`` ranks visualization masks first, tile id before base id.

    Here the base id is the tile id without its ``_1`` segment. The plain
    ``_mask.png`` names for both ids must still appear somewhere in the list.
    """
    tile_id = "dgm01_32_324_5506_1_rp"
    base_id = "dgm01_32_324_5506_rp"

    ranked = _water_mask_candidates(tile_id, prefer_viz=True)

    # Visualization variants lead: exact tile id first, then the base-id fallback.
    assert ranked[0] == f"{tile_id}_mask_viz.png"
    assert ranked[1] == f"{base_id}_mask_viz.png"

    # Plain masks remain available as later fallbacks.
    for plain_name in (f"{tile_id}_mask.png", f"{base_id}_mask.png"):
        assert plain_name in ranked
|
|
|
|
|
|
def test_water_mask_from_dir_preserves_zero_background(tmp_path) -> None:
    """Only the mask's marked pixels become nonzero; the background stays 0.

    A 4x4 mask PNG with a 2x2 block of 255 is written using the same
    geotransform and SRS as the ``like`` grid below. The mask returned for
    that grid is therefore expected to match the block pixel-for-pixel.
    """
    tile_id = "dgm01_32_000_0000_1_rp"
    water = np.s_[1:3, 1:3]
    _write_mask_png(tmp_path / f"{tile_id}_mask.png", water)

    like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
    like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
    like.SetProjection(_srs_wkt())

    mask = _water_mask_from_dir(tile_id, like, str(tmp_path))

    assert mask is not None
    assert int(np.sum(mask > 0)) == 4

    # The count alone would also pass if four *wrong* pixels were set; pin the
    # exact footprint so a shifted or flipped mask is caught too.
    expected = np.zeros((4, 4), dtype=bool)
    expected[water] = True
    np.testing.assert_array_equal(np.asarray(mask) > 0, expected)
|
|
|
|
|
|
def test_water_mask_from_dir_prefers_viz_when_requested(tmp_path) -> None:
    """The visualization mask is used only when ``prefer_viz=True`` is passed.

    Two masks with distinguishable footprints are placed in the directory:
    - a 1-pixel ``_mask_viz.png`` named after the base id;
    - a 4-pixel ``_mask.png`` named after the full tile id.

    The nonzero-pixel count of each result reveals which file was chosen.
    """
    tile_id = "dgm01_32_000_0000_1_rp"

    # Written in this order: visualization mask first, plain mask second.
    fixtures = {
        tmp_path / "dgm01_32_000_0000_rp_mask_viz.png": np.s_[0:1, 0:1],
        tmp_path / f"{tile_id}_mask.png": np.s_[1:3, 1:3],
    }
    for png_path, footprint in fixtures.items():
        _write_mask_png(png_path, footprint)

    like = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, gdal.GDT_Byte)
    like.SetGeoTransform((0.5, 1.0, 0.0, 4.5, 0.0, -1.0))
    like.SetProjection(_srs_wkt())

    # prefer_viz flag -> expected number of nonzero pixels (True is queried first).
    expected_counts = {True: 1, False: 4}
    masks = {
        flag: _water_mask_from_dir(tile_id, like, str(tmp_path), prefer_viz=flag)
        for flag in expected_counts
    }

    for flag in expected_counts:
        assert masks[flag] is not None
    for flag, count in expected_counts.items():
        assert int(np.sum(masks[flag] > 0)) == count
|