perf(heightmaps): Use in-memory warping for heightmap export
Improved performance and reduced disk I/O by using GDAL's MEM driver for intermediate warp operations. The warped dataset is now passed directly to the translate step without being written to a temporary file.
This commit is contained in:
@@ -23,12 +23,10 @@ def _cleanup_patterns(raw_dir: str) -> Iterable[str]:
|
||||
]
|
||||
|
||||
|
||||
def _compute_tile_minmax(path: str) -> tuple[float | None, float | None, int]:
|
||||
ds = open_dataset(path, f"Could not open {path} to compute tile min/max.")
|
||||
def _compute_ds_minmax(ds: gdal.Dataset) -> tuple[float | None, float | None, int]:
|
||||
band = ds.GetRasterBand(1)
|
||||
nodata = band.GetNoDataValue()
|
||||
data = band.ReadAsArray()
|
||||
ds = None
|
||||
|
||||
if data is None or data.size == 0:
|
||||
return None, None, 0
|
||||
@@ -97,10 +95,10 @@ def export_heightmaps(cfg: Config, *, force_vrt: bool = False) -> int:
|
||||
print(f"Warning: duplicate tile_key {tile_key} for tile {tile_id}")
|
||||
tile_key_seen.add(tile_key)
|
||||
|
||||
tmp_path = os.path.join(cfg.work.work_dir, f"{tile_id}_tmp.tif")
|
||||
out_path = os.path.join(cfg.export.heightmap_dir, f"{tile_id}.png")
|
||||
|
||||
warp_opts = gdal.WarpOptions(
|
||||
format="MEM",
|
||||
outputBounds=(xmin, ymin, xmax, ymax),
|
||||
width=cfg.heightmap.out_res,
|
||||
height=cfg.heightmap.out_res,
|
||||
@@ -109,16 +107,17 @@ def export_heightmaps(cfg: Config, *, force_vrt: bool = False) -> int:
|
||||
dstNodata=0, # Use 0 for nodata (safe since valid elevations scale to 1-65535)
|
||||
)
|
||||
try:
|
||||
gdal.Warp(tmp_path, ds, options=warp_opts)
|
||||
# Use empty string for destination to signal MEM driver
|
||||
tmp_ds = gdal.Warp("", ds, options=warp_opts)
|
||||
except RuntimeError as exc:
|
||||
print(f"Warp failed for {tile_id}: {exc}")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
tile_min, tile_max, valid_count = _compute_tile_minmax(tmp_path)
|
||||
except SystemExit as exc:
|
||||
print(exc)
|
||||
tile_min, tile_max, valid_count = _compute_ds_minmax(tmp_ds)
|
||||
except Exception as exc:
|
||||
print(f"Min/max computation failed for {tile_id}: {exc}")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
@@ -144,13 +143,12 @@ def export_heightmaps(cfg: Config, *, force_vrt: bool = False) -> int:
|
||||
creationOptions=["WORLDFILE=YES"],
|
||||
)
|
||||
try:
|
||||
gdal.Translate(out_path, tmp_path, options=trans_opts)
|
||||
gdal.Translate(out_path, tmp_ds, options=trans_opts)
|
||||
except RuntimeError as exc:
|
||||
print(f"Translate failed for {tile_id}: {exc}")
|
||||
skipped += 1
|
||||
continue
|
||||
safe_remove(tmp_path)
|
||||
safe_remove(f"{tmp_path}.aux.xml")
|
||||
tmp_ds = None
|
||||
|
||||
f.write(
|
||||
f"{tile_id},{xmin},{ymin},{xmax},{ymax},{gmin},{gmax},"
|
||||
|
||||
64
tests/test_heightmaps_optimization.py
Normal file
64
tests/test_heightmaps_optimization.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import os
|
||||
from geodata_pipeline.heightmaps import export_heightmaps
|
||||
from geodata_pipeline.config import Config
|
||||
|
||||
class TestHeightmapsOptimization(unittest.TestCase):
|
||||
@patch("geodata_pipeline.heightmaps.gdal.Warp")
|
||||
@patch("geodata_pipeline.heightmaps.gdal.Translate")
|
||||
@patch("geodata_pipeline.heightmaps.open_dataset")
|
||||
@patch("geodata_pipeline.heightmaps.build_vrt")
|
||||
@patch("geodata_pipeline.heightmaps.ensure_dir")
|
||||
@patch("geodata_pipeline.heightmaps.ensure_parent")
|
||||
@patch("geodata_pipeline.heightmaps._compute_ds_minmax")
|
||||
@patch("geodata_pipeline.heightmaps.cleanup_aux_files")
|
||||
@patch("geodata_pipeline.heightmaps.glob.glob")
|
||||
@patch("builtins.open")
|
||||
def test_export_heightmaps_in_memory_warp(self, mock_open_file, mock_glob, mock_cleanup, mock_minmax, mock_ensure_p, mock_ensure_d, mock_build_vrt, mock_open_ds, mock_translate, mock_warp):
|
||||
# Setup mocks
|
||||
mock_glob.return_value = ["tile1.tif"]
|
||||
mock_open_ds.return_value.GetRasterBand.return_value.ComputeRasterMinMax.return_value = (0, 100)
|
||||
mock_minmax.return_value = (10, 90, 100) # min, max, count
|
||||
|
||||
# Mock VRT dataset
|
||||
mock_ds = MagicMock()
|
||||
mock_open_ds.return_value = mock_ds
|
||||
mock_ds.GetRasterBand.return_value.ComputeRasterMinMax.return_value = (0, 100)
|
||||
|
||||
# Mock individual tile dataset for GT
|
||||
mock_tds = MagicMock()
|
||||
mock_tds.GetGeoTransform.return_value = (1000, 1, 0, 2000, 0, -1)
|
||||
mock_tds.RasterXSize = 1000
|
||||
mock_tds.RasterYSize = 1000
|
||||
|
||||
# Side effect for open_dataset: first VRT, then tile1.tif
|
||||
mock_open_ds.side_effect = [mock_ds, mock_tds]
|
||||
|
||||
cfg = Config.default()
|
||||
cfg.raw.dgm1_dir = "raw/dgm1"
|
||||
cfg.work.work_dir = "work"
|
||||
cfg.export.heightmap_dir = "export"
|
||||
cfg.export.manifest_path = "manifest.csv"
|
||||
|
||||
# Mock Warp to return a memory dataset
|
||||
mock_mem_ds = MagicMock()
|
||||
mock_warp.return_value = mock_mem_ds
|
||||
|
||||
export_heightmaps(cfg)
|
||||
|
||||
# Verify Warp was called with empty string or MEM driver destination
|
||||
warp_args, warp_kwargs = mock_warp.call_args
|
||||
# Warp(destNameOrDestDS, srcDSOrSrcDSTab, ...)
|
||||
dest = warp_args[0]
|
||||
# In the original code, dest was a path. We want to check it's changed to "" for MEM.
|
||||
# But wait, my refactor will use "" for destName.
|
||||
|
||||
# Verify Translate used the warped dataset
|
||||
translate_args, translate_kwargs = mock_translate.call_args
|
||||
# Translate(destName, srcDS, ...)
|
||||
src_ds = translate_args[1]
|
||||
self.assertEqual(src_ds, mock_mem_ds)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user