Files
GeoData/geodata_pipeline/trees_enhanced.py

421 lines
12 KiB
Python

"""Enhanced tree detection with LPO refinement and canopy color sampling.
Extends the standard tree detection by:
1. Excluding detected street furniture
2. Refining tree heights using LPO point cloud
3. Sampling canopy colors from DOP20 orthophotos
"""
from __future__ import annotations
import csv
import glob
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
from osgeo import gdal
from .config import Config
from .gdal_utils import open_dataset
from .pointcloud import find_pointcloud_file, has_lpo_data, read_pointcloud_file
from .street_furniture import load_furniture_detections
@dataclass
class EnhancedTree:
    """One detected tree together with the extra attributes computed here.

    ``x_local``/``y_local`` are tile-local coordinates (world coordinate minus
    the tile's xmin/ymin). ``confidence`` is a score capped at 1.0 by the
    detector. ``canopy_r``/``canopy_g``/``canopy_b`` are integer RGB components
    that keep their defaults when no orthophoto colour sample is taken.
    """

    tile_id: str
    x_local: float
    y_local: float
    z_ground: float
    height: float
    radius: float
    confidence: float
    # Fallback canopy colour (a muted green) used when no ortho sample exists.
    canopy_r: int = 128
    canopy_g: int = 160
    canopy_b: int = 80

    def to_csv_row(self) -> list:
        """Return the tree as a list of strings in CSV column order.

        Geometry fields use two decimals, confidence three, colours as ints.
        """
        geometry = (self.x_local, self.y_local, self.z_ground, self.height, self.radius)
        row = [self.tile_id]
        row.extend(format(value, ".2f") for value in geometry)
        row.append(format(self.confidence, ".3f"))
        row.extend(str(channel) for channel in (self.canopy_r, self.canopy_g, self.canopy_b))
        return row
def _create_furniture_mask(
detections: list,
shape: Tuple[int, int],
geotransform: Tuple[float, ...],
tile_xmin: float,
tile_ymin: float,
buffer_m: float = 2.0,
) -> np.ndarray:
"""Create raster mask from furniture detections.
Args:
detections: List of FurnitureDetection objects
shape: Output raster shape (rows, cols)
geotransform: GDAL geotransform
tile_xmin, tile_ymin: Tile origin
buffer_m: Buffer around each detection
Returns:
Boolean mask where True = furniture location
"""
mask = np.zeros(shape, dtype=bool)
pixel_size = abs(geotransform[1])
for det in detections:
# Convert local coords to world coords
world_x = det.x_local + tile_xmin
world_y = det.y_local + tile_ymin
# Convert to pixel coords
col = int((world_x - geotransform[0]) / geotransform[1])
row = int((world_y - geotransform[3]) / geotransform[5])
# Apply buffer
buffer_pixels = int(buffer_m / pixel_size) + 1
row_min = max(0, row - buffer_pixels)
row_max = min(shape[0], row + buffer_pixels + 1)
col_min = max(0, col - buffer_pixels)
col_max = min(shape[1], col + buffer_pixels + 1)
mask[row_min:row_max, col_min:col_max] = True
return mask
def _detect_local_maxima(
chm: np.ndarray,
min_height: float,
window_size: int = 5,
) -> List[Tuple[int, int, float]]:
"""Detect local maxima in canopy height model.
Args:
chm: Canopy height model array
min_height: Minimum tree height
window_size: Search window size
Returns:
List of (row, col, height) tuples
"""
from scipy import ndimage
# Apply minimum height threshold
chm_thresh = np.where(chm >= min_height, chm, 0)
# Find local maxima using maximum filter
local_max = ndimage.maximum_filter(chm_thresh, size=window_size)
peaks = (chm_thresh == local_max) & (chm_thresh > 0)
# Get peak locations
rows, cols = np.where(peaks)
heights = chm[rows, cols]
return list(zip(rows, cols, heights))
def _sample_ortho_color(
    ortho_path: str,
    world_x: float,
    world_y: float,
    radius: float,
) -> Optional[Tuple[int, int, int]]:
    """Average the RGB colour of an orthophoto inside a circle.

    The circle is centred on the pixel that contains ``(world_x, world_y)``.
    Its radius is ``radius`` converted to whole pixels plus one. Pixels outside
    the raster are not sampled.

    Args:
        ortho_path: Path to the orthophoto; bands 1-3 are read as R, G, B.
        world_x, world_y: Sample centre in the orthophoto's world coordinates.
        radius: Sample radius in the geotransform's units.

    Returns:
        ``(R, G, B)`` as truncated integer means, or None if the file cannot be
        opened, the sample window misses the raster, or reading fails.
    """
    ds = open_dataset(ortho_path, required=False)
    if ds is None:
        return None

    gt = ds.GetGeoTransform()
    radius_pixels = int(radius / abs(gt[1])) + 1

    # floor (not int()) so points just outside the left/top edge are not
    # truncated onto pixel 0.
    center_col = int(np.floor((world_x - gt[0]) / gt[1]))
    center_row = int(np.floor((world_y - gt[3]) / gt[5]))

    col_min = max(0, center_col - radius_pixels)
    col_max = min(ds.RasterXSize, center_col + radius_pixels + 1)
    row_min = max(0, center_row - radius_pixels)
    row_max = min(ds.RasterYSize, center_row + radius_pixels + 1)
    if col_min >= col_max or row_min >= row_max:
        return None

    try:
        r, g, b = (
            ds.GetRasterBand(band).ReadAsArray(
                col_min, row_min, col_max - col_min, row_max - row_min
            )
            for band in (1, 2, 3)
        )

        # Circle centre expressed in window coordinates. The previous
        # window-midpoint centre was off by half a pixel and shifted badly
        # whenever the window was clipped at the raster edge.
        y_idx, x_idx = np.ogrid[: r.shape[0], : r.shape[1]]
        dist_sq = (y_idx - (center_row - row_min)) ** 2 + (x_idx - (center_col - col_min)) ** 2
        circle_mask = dist_sq <= radius_pixels ** 2
        if not circle_mask.any():
            return None

        r_avg, g_avg, b_avg = (int(np.mean(ch[circle_mask])) for ch in (r, g, b))
        return (r_avg, g_avg, b_avg)
    except Exception:
        # Deliberately best-effort: colour sampling is optional, so a read
        # problem (e.g. a missing band) falls back to the caller's default.
        return None
def _valid_elevation_mask(values: np.ndarray, nodata: Optional[float]) -> np.ndarray:
    """Return True where *values* holds real data: finite and not the nodata value.

    -9999 is used as the sentinel when the band declares no nodata value.
    """
    if nodata is None:
        nodata = -9999.0
    valid = np.isfinite(values)
    # A NaN/inf nodata is already excluded by the isfinite check. Compare in
    # float32 because the raster itself was cast to float32.
    if np.isfinite(nodata):
        valid &= values != np.float32(nodata)
    return valid


def _load_tile_chm(
    tile_id: str, cfg: Config
) -> Optional[Tuple[np.ndarray, np.ndarray, Tuple[float, ...]]]:
    """Load DOM1/DGM1 for *tile_id* and derive a canopy height model.

    Returns ``(chm, dgm1, geotransform)``, where ``chm`` is
    ``max(DOM1 - DGM1, 0)`` on pixels valid in both rasters and 0 elsewhere.
    Returns None when either raster is missing or unreadable, or when their
    grids differ in shape.
    """
    # sorted() makes the choice deterministic when several files match.
    dom1_files = sorted(glob.glob(os.path.join(cfg.pointcloud.dom1_dir, f"*{tile_id}*.tif")))
    if not dom1_files:
        print(f"[trees_enhanced] No DOM1 for {tile_id}")
        return None
    dgm1_files = sorted(glob.glob(os.path.join(cfg.raw.dgm1_dir, f"*{tile_id}*.tif")))
    if not dgm1_files:
        print(f"[trees_enhanced] No DGM1 for {tile_id}")
        return None

    dom1_ds = open_dataset(dom1_files[0], required=False)
    dgm1_ds = open_dataset(dgm1_files[0], required=False)
    if dom1_ds is None or dgm1_ds is None:
        return None

    dom1_band = dom1_ds.GetRasterBand(1)
    dgm1_band = dgm1_ds.GetRasterBand(1)
    dom1 = dom1_band.ReadAsArray().astype(np.float32)
    dgm1 = dgm1_band.ReadAsArray().astype(np.float32)
    if dom1.shape != dgm1.shape:
        print(
            f"[trees_enhanced] DOM1/DGM1 shape mismatch for {tile_id}: "
            f"{dom1.shape} vs {dgm1.shape}"
        )
        return None

    valid = _valid_elevation_mask(dom1, dom1_band.GetNoDataValue()) & _valid_elevation_mask(
        dgm1, dgm1_band.GetNoDataValue()
    )
    chm = np.clip(np.where(valid, dom1 - dgm1, 0), 0, None)
    return chm, dgm1, dom1_ds.GetGeoTransform()


def _build_lpo_index(
    tile_id: str, bounds: Tuple[float, float, float, float], lpo_dir: str
) -> Optional[Tuple[object, np.ndarray]]:
    """Load the tile's LPO points and index their XY positions for radius queries.

    Returns ``(kdtree, z_values)``, or None when no usable LPO data exists.
    """
    if not has_lpo_data(tile_id, lpo_dir):
        return None
    lpo_file = find_pointcloud_file(lpo_dir, tile_id)
    if not lpo_file:
        return None
    lpo = read_pointcloud_file(lpo_file, bounds=bounds)
    if lpo is None or len(lpo) == 0:
        return None

    from scipy.spatial import cKDTree

    # A KD-tree turns the per-tree neighbour search from a full scan of every
    # point into a logarithmic lookup.
    xy = np.column_stack((np.asarray(lpo.x, dtype=np.float64), np.asarray(lpo.y, dtype=np.float64)))
    return cKDTree(xy), np.asarray(lpo.z, dtype=np.float64)


def _lpo_refined_height(
    lpo_index: Tuple[object, np.ndarray],
    world_x: float,
    world_y: float,
    z_ground: float,
    search_radius: float,
) -> Optional[float]:
    """Return the 95th-percentile LPO z within *search_radius* minus ground, or None."""
    kdtree, z_values = lpo_index
    neighbours = kdtree.query_ball_point((world_x, world_y), r=search_radius)
    if not neighbours:
        return None
    return float(np.percentile(z_values[neighbours], 95)) - z_ground


def detect_trees_in_tile(
    tile_id: str,
    xmin: float,
    ymin: float,
    xmax: float,
    ymax: float,
    cfg: Config,
) -> List[EnhancedTree]:
    """Detect trees in a single tile with enhanced processing.

    Pipeline:
    1. Build a CHM from DOM1 - DGM1.
    2. Optionally zero out buffered furniture locations.
    3. Pick local maxima as treetops, keeping the ``max_trees`` tallest.
    4. Optionally refine heights from LPO points and sample canopy colour from
       the tile orthophoto.

    Tree positions are pixel centres. The radius heuristic is applied to the
    final (possibly LPO-refined) height.

    Args:
        tile_id: Tile identifier.
        xmin, ymin, xmax, ymax: Tile bounds in world coordinates.
        cfg: Configuration.

    Returns:
        List of enhanced tree detections; empty if inputs are missing.
    """
    et_cfg = cfg.trees_enhanced
    tree_cfg = cfg.trees

    loaded = _load_tile_chm(tile_id, cfg)
    if loaded is None:
        return []
    chm, dgm1, gt = loaded

    if et_cfg.exclude_furniture:
        furniture = load_furniture_detections(tile_id, cfg)
        if furniture:
            furniture_mask = _create_furniture_mask(
                furniture, chm.shape, gt, xmin, ymin, buffer_m=2.0
            )
            chm = np.where(furniture_mask, 0, chm)

    peaks = _detect_local_maxima(chm, tree_cfg.min_height_m, window_size=5)
    if len(peaks) > tree_cfg.max_trees:
        # Keep the tallest candidates.
        peaks = sorted(peaks, key=lambda p: p[2], reverse=True)[: tree_cfg.max_trees]

    ortho_path = os.path.join(cfg.export.ortho_dir, f"{tile_id}.jpg")
    sample_color = et_cfg.sample_canopy_color and os.path.exists(ortho_path)

    lpo_index = None
    if et_cfg.use_lpo_refinement:
        lpo_index = _build_lpo_index(tile_id, (xmin, ymin, xmax, ymax), cfg.pointcloud.lpo_dir)

    trees = []
    for row, col, chm_height in peaks:
        # Pixel centre (not the top-left corner) in world coordinates.
        world_x = gt[0] + (col + 0.5) * gt[1]
        world_y = gt[3] + (row + 0.5) * gt[5]
        z_ground = float(dgm1[row, col])
        height = float(chm_height)

        # Base confidence grows with CHM height and saturates at 30.
        confidence = 0.6 + 0.4 * min(1.0, height / 30.0)

        if lpo_index is not None:
            refined = _lpo_refined_height(
                lpo_index, world_x, world_y, z_ground, et_cfg.lpo_search_radius_m
            )
            if refined is not None and refined > tree_cfg.min_height_m:
                height = refined
                confidence += 0.1

        # Simple crown-radius heuristic, applied to the final height so the
        # exported height/radius pair stays consistent.
        radius = height * 0.25

        canopy_r, canopy_g, canopy_b = 128, 160, 80  # default colour
        if sample_color:
            # NOTE(review): _sample_ortho_color reopens the orthophoto for every
            # tree; consider caching the dataset if this becomes a bottleneck.
            color = _sample_ortho_color(
                ortho_path, world_x, world_y, radius * et_cfg.canopy_sample_radius_factor
            )
            if color:
                canopy_r, canopy_g, canopy_b = color

        trees.append(
            EnhancedTree(
                tile_id=tile_id,
                x_local=world_x - xmin,
                y_local=world_y - ymin,
                z_ground=z_ground,
                height=height,
                radius=radius,
                confidence=min(confidence, 1.0),
                canopy_r=canopy_r,
                canopy_g=canopy_g,
                canopy_b=canopy_b,
            )
        )
    return trees
def export_trees_csv(
    trees: List[EnhancedTree],
    tile_id: str,
    cfg: Config,
) -> str:
    """Write *trees* to ``<cfg.trees_enhanced.csv_dir>/<tile_id>.csv``.

    The output directory is created if needed, and an existing file is
    overwritten. A header row is always written, even when *trees* is empty.
    The file is written as UTF-8 so its content does not depend on the
    platform's locale encoding.

    Args:
        trees: Tree detections; each must provide ``to_csv_row()``.
        tile_id: Tile identifier, used as the file stem.
        cfg: Configuration.

    Returns:
        Path to the written CSV file.
    """
    out_dir = cfg.trees_enhanced.csv_dir
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{tile_id}.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "tile_id", "x_local", "y_local", "z_ground",
            "height", "radius", "confidence",
            "canopy_r", "canopy_g", "canopy_b",
        ])
        writer.writerows(tree.to_csv_row() for tree in trees)
    return csv_path
def export_trees_enhanced(cfg: Config) -> int:
    """Run enhanced tree detection for every tile listed in the manifest.

    Reads tile ids and bounds from ``cfg.export.manifest_path``, a CSV with
    ``tile_id``, ``xmin``, ``ymin``, ``xmax`` and ``ymax`` columns. For each
    tile it calls detect_trees_in_tile and writes the result with
    export_trees_csv, printing progress along the way.

    Args:
        cfg: Configuration.

    Returns:
        0 on success, 1 if the manifest is missing or lists no tiles.
    """
    manifest_path = cfg.export.manifest_path
    if not os.path.exists(manifest_path):
        print(f"[trees_enhanced] ERROR: Manifest not found: {manifest_path}")
        print("[trees_enhanced] Run heightmap export first.")
        return 1

    bound_keys = ("xmin", "ymin", "xmax", "ymax")
    with open(manifest_path, "r") as manifest:
        tiles = [
            (entry["tile_id"], *(float(entry[key]) for key in bound_keys))
            for entry in csv.DictReader(manifest)
        ]

    if not tiles:
        print("[trees_enhanced] No tiles in manifest.")
        return 1

    os.makedirs(cfg.trees_enhanced.csv_dir, exist_ok=True)
    total_trees = 0
    for tile_id, *bounds in tiles:
        print(f"[trees_enhanced] Processing {tile_id}...")
        found = detect_trees_in_tile(tile_id, *bounds, cfg)
        csv_path = export_trees_csv(found, tile_id, cfg)
        print(f"[trees_enhanced] {tile_id}: {len(found)} trees -> {csv_path}")
        total_trees += len(found)

    print(f"[trees_enhanced] DONE. Total trees: {total_trees}")
    return 0