diff --git a/geodata_pipeline/buildings.py b/geodata_pipeline/buildings.py index 6bf0621..51df4ea 100644 --- a/geodata_pipeline/buildings.py +++ b/geodata_pipeline/buildings.py @@ -315,6 +315,43 @@ def _compose_glb( return b"".join([header, json_header, json_padded, bin_header, bin_padded]) +def _extract_texture_from_vrt(bounds: Tuple[float, float, float, float], vrt_ds: gdal.Dataset, width_px: int) -> bytes | None: + """Extract a JPEG texture for the given bounds from the VRT.""" + xmin, ymin, xmax, ymax = bounds + + # Use VSIMEM for in-memory processing + mem_path = "/vsimem/tile_tex.jpg" + + warp_opts = gdal.WarpOptions( + format="MEM", + outputBounds=(xmin, ymin, xmax, ymax), + width=width_px, + height=width_px, + resampleAlg="bilinear", + ) + + try: + tmp_ds = gdal.Warp("", vrt_ds, options=warp_opts) + if not tmp_ds: + return None + + gdal.Translate(mem_path, tmp_ds, format="JPEG", creationOptions=["QUALITY=90"]) + + f = gdal.VSIFOpenL(mem_path, "rb") + if f: + gdal.VSIFSeekL(f, 0, 2) + size = gdal.VSIFTellL(f) + gdal.VSIFSeekL(f, 0, 0) + data = gdal.VSIFReadL(1, size, f) + gdal.VSIFCloseL(f) + gdal.VSIUnlink(mem_path) + return data + except Exception as exc: + print(f"[buildings] failed to extract texture from VRT: {exc}") + + return None + + def _load_ortho(tile_id: str, path: str) -> bytes | None: if not os.path.exists(path): print(f"[buildings] missing ortho for {tile_id}: {path}") @@ -445,6 +482,15 @@ def export_buildings(cfg: Config) -> int: if not os.path.exists(cfg.export.manifest_path): raise SystemExit(f"Tile index missing: {cfg.export.manifest_path}. Run heightmap export first.") + # Open VRTs once for performance + dgm_vrt_ds = gdal.Open(cfg.work.heightmap_vrt) + ortho_vrt_ds = gdal.Open(cfg.work.ortho_vrt) + + if not dgm_vrt_ds: + print(f"[buildings] warning: could not open heightmap VRT {cfg.work.heightmap_vrt}") + if not ortho_vrt_ds: + print(f"[buildings] warning: could not open orthophoto VRT {cfg.work.ortho_vrt}") + import csv written = 0 @@ -487,40 +533,73 @@ def export_buildings(cfg: Config) -> int: if wall_faces.size: wall_faces = _decimate(wall_faces, wall_budget or len(wall_faces)) - # Ground snap (simple: clamp below-ground vertices up to DTM) - try: - dgm_ds = gdal.Open(cfg.work.heightmap_vrt) - if dgm_ds: - gt = dgm_ds.GetGeoTransform() - band = dgm_ds.GetRasterBand(1) - # Calculate srcwin for the tile - xmin, ymin, xmax, ymax = bounds - xoff = int((xmin - gt[0]) / gt[1]) - yoff = int((ymax - gt[3]) / gt[5]) - xsize = int((xmax - xmin) / gt[1]) + 1 - ysize = int((ymax - ymin) / abs(gt[5])) + 1 - - # Clamp to raster dimensions - xoff_clamped = max(0, min(xoff, dgm_ds.RasterXSize - 1)) - yoff_clamped = max(0, min(yoff, dgm_ds.RasterYSize - 1)) - xsize_clamped = max(1, min(xsize, dgm_ds.RasterXSize - xoff_clamped)) - ysize_clamped = max(1, min(ysize, dgm_ds.RasterYSize - yoff_clamped)) - - arr = band.ReadAsArray(xoff_clamped, yoff_clamped, xsize_clamped, ysize_clamped) - nodata = band.GetNoDataValue() - for idx, (vx, vy, vz) in enumerate(vertices): - wx = vx + xmin - wy = vy + ymin - col = int((wx - gt[0]) / gt[1]) - xoff_clamped - row = int((wy - gt[3]) / gt[5]) - yoff_clamped - if 0 <= row < arr.shape[0] and 0 <= col < arr.shape[1]: - g = float(arr[row, col]) - if nodata is not None and g == nodata: - continue - if vz < g: - vertices[idx, 2] = g - except Exception: - pass + # Ground snap (simple: clamp below-ground vertices up to DTM) + + try: + + if dgm_vrt_ds: + + gt = dgm_vrt_ds.GetGeoTransform() + + band = dgm_vrt_ds.GetRasterBand(1) + + # Calculate srcwin for the tile + + xmin, ymin, xmax, ymax = bounds + + xoff = int((xmin - gt[0]) / gt[1]) + + yoff = int((ymax - gt[3]) / gt[5]) + + xsize = int((xmax - xmin) / gt[1]) + 1 + + ysize = int((ymax - ymin) / abs(gt[5])) + 1 + + + + # Clamp to raster dimensions + + xoff_clamped = max(0, min(xoff, dgm_vrt_ds.RasterXSize - 1)) + + yoff_clamped = max(0, min(yoff, dgm_vrt_ds.RasterYSize - 1)) + + xsize_clamped = max(1, min(xsize, dgm_vrt_ds.RasterXSize - xoff_clamped)) + + ysize_clamped = max(1, min(ysize, dgm_vrt_ds.RasterYSize - yoff_clamped)) + + + + arr = band.ReadAsArray(xoff_clamped, yoff_clamped, xsize_clamped, ysize_clamped) + + nodata = band.GetNoDataValue() + + for idx, (vx, vy, vz) in enumerate(vertices): + + wx = vx + xmin + + wy = vy + ymin + + col = int((wx - gt[0]) / gt[1]) - xoff_clamped + + row = int((wy - gt[3]) / gt[5]) - yoff_clamped + + if 0 <= row < arr.shape[0] and 0 <= col < arr.shape[1]: + + g = float(arr[row, col]) + + if nodata is not None and g == nodata: + + continue + + if vz < g: + + vertices[idx, 2] = g + + except Exception: + + pass + + xmin, ymin, xmax, ymax = bounds w = xmax - xmin @@ -540,48 +619,52 @@ def export_buildings(cfg: Config) -> int: uv = np.zeros((len(vertices), 2), dtype=np.float32) uv[:, 0] = source_xy[:, 0] / w uv[:, 1] = 1.0 - (source_xy[:, 1] / h) + # Wall colors sampled from ortho if available (fallback constant) wall_color = np.zeros((len(vertices), 3), dtype=np.float32) + 0.75 - ortho_path = os.path.join(cfg.export.ortho_dir, f"{tile_id}.jpg") - ortho_bytes = _load_ortho(tile_id, ortho_path) - try: - # Sample walls from global VRT using windowed read to avoid memory issues and ensure coverage - ortho_ds = gdal.Open(cfg.work.ortho_vrt) - if ortho_ds: - gt_o = ortho_ds.GetGeoTransform() - xmin, ymin, xmax, ymax = bounds + # Extract high-quality roof texture from VRT (handles multi-tile correctly) + ortho_bytes = None + if ortho_vrt_ds: + try: + # Calculate resolution based on tile size relative to 1km base + tile_w_km = (xmax - xmin) / 1000.0 + target_res = int(round(cfg.ortho.out_res * tile_w_km)) + ortho_bytes = _extract_texture_from_vrt(bounds, ortho_vrt_ds, width_px=target_res) - # Calculate srcwin - xoff = int((xmin - gt_o[0]) / gt_o[1]) - yoff = int((ymax - gt_o[3]) / gt_o[5]) - xsize = int((xmax - xmin) / gt_o[1]) + 1 - ysize = int((ymax - ymin) / abs(gt_o[5])) + 1 - - # Clamp to raster dimensions - xoff_clamped = max(0, min(xoff, ortho_ds.RasterXSize - 1)) - yoff_clamped = max(0, min(yoff, ortho_ds.RasterYSize - 1)) - xsize_clamped = max(1, min(xsize, ortho_ds.RasterXSize - xoff_clamped)) - ysize_clamped = max(1, min(ysize, ortho_ds.RasterYSize - yoff_clamped)) - - bands = [ - ortho_ds.GetRasterBand(i + 1).ReadAsArray(xoff_clamped, yoff_clamped, xsize_clamped, ysize_clamped) - for i in range(min(3, ortho_ds.RasterCount)) - ] - - for idx, (vx, vy, _) in enumerate(vertices): - wx = vx + xmin - wy = vy + ymin - # Map global coord to window-local coord - col = int((wx - gt_o[0]) / gt_o[1]) - xoff_clamped - row = int((wy - gt_o[3]) / gt_o[5]) - yoff_clamped + if ortho_bytes: + # Also sample wall colors from the VRT (already implemented using windowed read) + gt_o = ortho_vrt_ds.GetGeoTransform() - if 0 <= row < ysize_clamped and 0 <= col < xsize_clamped: - rgb = [bands[i][row, col] for i in range(len(bands))] - if rgb: - wall_color[idx] = np.array(rgb[:3], dtype=np.float32) / 255.0 - except Exception: - pass + # Calculate srcwin for the tile + xoff_o = int((xmin - gt_o[0]) / gt_o[1]) + yoff_o = int((ymax - gt_o[3]) / gt_o[5]) + xsize_o = int((xmax - xmin) / gt_o[1]) + 1 + ysize_o = int((ymax - ymin) / abs(gt_o[5])) + 1 + + # Clamp to raster dimensions + xoff_o_clamped = max(0, min(xoff_o, ortho_vrt_ds.RasterXSize - 1)) + yoff_o_clamped = max(0, min(yoff_o, ortho_vrt_ds.RasterYSize - 1)) + xsize_o_clamped = max(1, min(xsize_o, ortho_vrt_ds.RasterXSize - xoff_o_clamped)) + ysize_o_clamped = max(1, min(ysize_o, ortho_vrt_ds.RasterYSize - yoff_o_clamped)) + + bands = [ + ortho_vrt_ds.GetRasterBand(i + 1).ReadAsArray(xoff_o_clamped, yoff_o_clamped, xsize_o_clamped, ysize_o_clamped) + for i in range(min(3, ortho_vrt_ds.RasterCount)) + ] + + for idx, (vx, vy, _) in enumerate(vertices): + wx = vx + xmin + wy = vy + ymin + col = int((wx - gt_o[0]) / gt_o[1]) - xoff_o_clamped + row = int((wy - gt_o[3]) / gt_o[5]) - yoff_o_clamped + + if 0 <= row < ysize_o_clamped and 0 <= col < xsize_o_clamped: + rgb = [bands[i][row, col] for i in range(len(bands))] + if rgb: + wall_color[idx] = np.array(rgb[:3], dtype=np.float32) / 255.0 + except Exception as exc: + print(f"[buildings] error sampling textures for {tile_id}: {exc}") glb_bytes = _compose_glb( gltf_vertices, diff --git a/tests/test_roof_textures.py b/tests/test_roof_textures.py new file mode 100644 index 0000000..241ed35 --- /dev/null +++ b/tests/test_roof_textures.py @@ -0,0 +1,71 @@ +import unittest +from unittest.mock import patch, MagicMock +import os +import numpy as np +from geodata_pipeline.buildings import export_buildings +from geodata_pipeline.config import Config + +class TestRoofTextures(unittest.TestCase): + @patch("geodata_pipeline.buildings.gdal.WarpOptions") + @patch("geodata_pipeline.buildings.gdal.Warp") + @patch("geodata_pipeline.buildings.gdal.Translate") + @patch("geodata_pipeline.buildings.gdal.Open") + @patch("geodata_pipeline.buildings._ensure_cityjson_for_tile") + @patch("geodata_pipeline.buildings._load_cityjson") + @patch("geodata_pipeline.buildings._collect_faces") + @patch("geodata_pipeline.buildings._compose_glb") + @patch("geodata_pipeline.buildings.ensure_dir") + @patch("geodata_pipeline.buildings.os.path.exists") + @patch("builtins.open") + def test_vrt_roof_texture_extraction(self, mock_open_file, mock_exists, mock_ensure_dir, mock_compose, mock_collect, mock_load_cj, mock_ensure_cj, mock_gdal_open, mock_translate, mock_warp, mock_warp_opts): + # Setup mocks + mock_exists.return_value = True + + # Mock manifest: 2km tile (1000, 1000) to (3000, 3000) + mock_handle = MagicMock() + mock_open_file.return_value.__enter__.return_value = mock_handle + mock_handle.__iter__.return_value = [ + "tile_id,xmin,ymin,xmax,ymax,global_min,global_max,out_res,tile_key,tile_min,tile_max\n", + "tile_2km,1000,1000,3000,3000,0,100,1025,1_1,0,100\n" + ] + + mock_ensure_cj.return_value = "dummy.json" + mock_load_cj.return_value = {"CityObjects": {}} + mock_collect.return_value = ( + [[10.0, 10.0, 5.0]], # vertices + [([0, 0, 0], "RoofSurface")] + ) + + cfg = Config.default() + cfg.work.ortho_vrt = "work/dop.vrt" + cfg.export.ortho_dir = "export/ortho_jpg" + cfg.export.manifest_path = "manifest.csv" + + # Mock GDAL + mock_ds = MagicMock() + mock_gdal_open.return_value = mock_ds + mock_ds.GetGeoTransform.return_value = (0, 1, 0, 5000, 0, -1) + mock_ds.RasterXSize = 5000 + mock_ds.RasterYSize = 5000 + mock_ds.RasterCount = 3 + mock_band = MagicMock() + mock_ds.GetRasterBand.return_value = mock_band + mock_band.ReadAsArray.return_value = np.zeros((10, 10)) + + # Mock Warp to return a dataset + mock_warp.return_value = MagicMock() + + export_buildings(cfg) + + # Verify WarpOptions was called with the 2km bounds + mock_warp_opts.assert_called() + # Find the call that used outputBounds + found_correct_bounds = False + for call in mock_warp_opts.call_args_list: + if call.kwargs.get("outputBounds") == (1000.0, 1000.0, 3000.0, 3000.0): + found_correct_bounds = True + break + self.assertTrue(found_correct_bounds, "WarpOptions not called with expected 2km bounds") + +if __name__ == "__main__": + unittest.main()