"""Enhanced tree detection with LPO refinement and canopy color sampling.

Extends the standard tree detection by:

1. Excluding detected street furniture
2. Refining tree heights using LPO point cloud
3. Sampling canopy colors from DOP20 orthophotos
"""

from __future__ import annotations

import csv
import glob
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
from osgeo import gdal

from .config import Config
from .gdal_utils import open_dataset
from .pointcloud import find_pointcloud_file, has_lpo_data, read_pointcloud_file
from .street_furniture import load_furniture_detections


@dataclass
class EnhancedTree:
    """A single detected tree with refined height and sampled canopy color."""

    tile_id: str
    x_local: float
    y_local: float
    z_ground: float
    height: float
    radius: float
    confidence: float
    canopy_r: int = 128  # Default green-ish
    canopy_g: int = 160
    canopy_b: int = 80

    def to_csv_row(self) -> list:
        """Serialize this detection as a flat list of CSV cell strings."""
        metric_fields = (self.x_local, self.y_local, self.z_ground,
                         self.height, self.radius)
        row = [self.tile_id]
        row.extend(f"{value:.2f}" for value in metric_fields)
        row.append(f"{self.confidence:.3f}")
        row.extend(str(channel) for channel in
                   (self.canopy_r, self.canopy_g, self.canopy_b))
        return row


def _create_furniture_mask(
|
|
detections: list,
|
|
shape: Tuple[int, int],
|
|
geotransform: Tuple[float, ...],
|
|
tile_xmin: float,
|
|
tile_ymin: float,
|
|
buffer_m: float = 2.0,
|
|
) -> np.ndarray:
|
|
"""Create raster mask from furniture detections.
|
|
|
|
Args:
|
|
detections: List of FurnitureDetection objects
|
|
shape: Output raster shape (rows, cols)
|
|
geotransform: GDAL geotransform
|
|
tile_xmin, tile_ymin: Tile origin
|
|
buffer_m: Buffer around each detection
|
|
|
|
Returns:
|
|
Boolean mask where True = furniture location
|
|
"""
|
|
mask = np.zeros(shape, dtype=bool)
|
|
pixel_size = abs(geotransform[1])
|
|
|
|
for det in detections:
|
|
# Convert local coords to world coords
|
|
world_x = det.x_local + tile_xmin
|
|
world_y = det.y_local + tile_ymin
|
|
|
|
# Convert to pixel coords
|
|
col = int((world_x - geotransform[0]) / geotransform[1])
|
|
row = int((world_y - geotransform[3]) / geotransform[5])
|
|
|
|
# Apply buffer
|
|
buffer_pixels = int(buffer_m / pixel_size) + 1
|
|
|
|
row_min = max(0, row - buffer_pixels)
|
|
row_max = min(shape[0], row + buffer_pixels + 1)
|
|
col_min = max(0, col - buffer_pixels)
|
|
col_max = min(shape[1], col + buffer_pixels + 1)
|
|
|
|
mask[row_min:row_max, col_min:col_max] = True
|
|
|
|
return mask
|
|
|
|
|
|
def _detect_local_maxima(
|
|
chm: np.ndarray,
|
|
min_height: float,
|
|
window_size: int = 5,
|
|
) -> List[Tuple[int, int, float]]:
|
|
"""Detect local maxima in canopy height model.
|
|
|
|
Args:
|
|
chm: Canopy height model array
|
|
min_height: Minimum tree height
|
|
window_size: Search window size
|
|
|
|
Returns:
|
|
List of (row, col, height) tuples
|
|
"""
|
|
from scipy import ndimage
|
|
|
|
# Apply minimum height threshold
|
|
chm_thresh = np.where(chm >= min_height, chm, 0)
|
|
|
|
# Find local maxima using maximum filter
|
|
local_max = ndimage.maximum_filter(chm_thresh, size=window_size)
|
|
peaks = (chm_thresh == local_max) & (chm_thresh > 0)
|
|
|
|
# Get peak locations
|
|
rows, cols = np.where(peaks)
|
|
heights = chm[rows, cols]
|
|
|
|
return list(zip(rows, cols, heights))
|
|
|
|
|
|
def _sample_ortho_color(
    ortho_path: str,
    world_x: float,
    world_y: float,
    radius: float,
) -> Optional[Tuple[int, int, int]]:
    """Sample average color from orthophoto within radius.

    Args:
        ortho_path: Path to orthophoto JPEG
        world_x, world_y: World coordinates
        radius: Sample radius in meters

    Returns:
        Tuple of (R, G, B) or None if sampling fails
    """
    ds = open_dataset(ortho_path, required=False)
    if ds is None:
        return None

    gt = ds.GetGeoTransform()
    pixel_size = abs(gt[1])

    # Convert the sample centre to pixel coords
    center_col = int((world_x - gt[0]) / gt[1])
    center_row = int((world_y - gt[3]) / gt[5])

    # Square read window that covers the circular sample area,
    # clamped to the raster extent
    radius_pixels = int(radius / pixel_size) + 1
    col_min = max(0, center_col - radius_pixels)
    col_max = min(ds.RasterXSize, center_col + radius_pixels + 1)
    row_min = max(0, center_row - radius_pixels)
    row_max = min(ds.RasterYSize, center_row + radius_pixels + 1)

    if col_min >= col_max or row_min >= row_max:
        return None

    try:
        win_w = col_max - col_min
        win_h = row_max - row_min

        # Read RGB bands for the window
        r = ds.GetRasterBand(1).ReadAsArray(col_min, row_min, win_w, win_h)
        g = ds.GetRasterBand(2).ReadAsArray(col_min, row_min, win_w, win_h)
        b = ds.GetRasterBand(3).ReadAsArray(col_min, row_min, win_w, win_h)
        if r is None or g is None or b is None:
            return None

        # Circular mask centred on the sample point itself, not the window
        # midpoint: when the window is clipped at a raster edge the two
        # differ, which previously shifted the sampled circle off-target
        # (and even unclipped, the midpoint was off by half a pixel).
        y_indices, x_indices = np.ogrid[:r.shape[0], :r.shape[1]]
        center_y = center_row - row_min
        center_x = center_col - col_min
        dist_sq = (y_indices - center_y) ** 2 + (x_indices - center_x) ** 2
        circle_mask = dist_sq <= radius_pixels ** 2

        if not circle_mask.any():
            return None

        return (
            int(np.mean(r[circle_mask])),
            int(np.mean(g[circle_mask])),
            int(np.mean(b[circle_mask])),
        )

    except Exception:
        # Best-effort sampling: any GDAL read error (missing band, bad
        # window, corrupt file) just means "no colour available".
        return None


def detect_trees_in_tile(
    tile_id: str,
    xmin: float,
    ymin: float,
    xmax: float,
    ymax: float,
    cfg: Config,
) -> List[EnhancedTree]:
    """Detect trees in a single tile with enhanced processing.

    Pipeline: build a canopy height model (DOM1 - DGM1), mask out detected
    street furniture, find local maxima as tree candidates, then optionally
    refine heights from the LPO point cloud and sample canopy colours from
    the tile orthophoto.

    Args:
        tile_id: Tile identifier
        xmin, ymin, xmax, ymax: Tile bounds
        cfg: Configuration

    Returns:
        List of enhanced tree detections (empty if required rasters are missing)
    """
    pc_cfg = cfg.pointcloud
    et_cfg = cfg.trees_enhanced
    tree_cfg = cfg.trees

    # Load DOM1 (surface model) and DGM1 (terrain model) for this tile
    dom1_pattern = os.path.join(pc_cfg.dom1_dir, f"*{tile_id}*.tif")
    dom1_files = glob.glob(dom1_pattern)
    if not dom1_files:
        print(f"[trees_enhanced] No DOM1 for {tile_id}")
        return []

    dgm1_pattern = os.path.join(cfg.raw.dgm1_dir, f"*{tile_id}*.tif")
    dgm1_files = glob.glob(dgm1_pattern)
    if not dgm1_files:
        print(f"[trees_enhanced] No DGM1 for {tile_id}")
        return []

    dom1_ds = open_dataset(dom1_files[0], required=False)
    dgm1_ds = open_dataset(dgm1_files[0], required=False)
    if dom1_ds is None or dgm1_ds is None:
        return []

    dom1 = dom1_ds.GetRasterBand(1).ReadAsArray().astype(np.float32)
    dgm1 = dgm1_ds.GetRasterBand(1).ReadAsArray().astype(np.float32)
    gt = dom1_ds.GetGeoTransform()

    # Handle nodata.  NOTE: an explicit None check is required here —
    # `GetNoDataValue() or -9999` would wrongly discard a legitimate
    # nodata value of 0 (falsy).
    dom1_nodata = dom1_ds.GetRasterBand(1).GetNoDataValue()
    dom1_nodata = -9999 if dom1_nodata is None else dom1_nodata
    dgm1_nodata = dgm1_ds.GetRasterBand(1).GetNoDataValue()
    dgm1_nodata = -9999 if dgm1_nodata is None else dgm1_nodata
    valid_mask = (dom1 > dom1_nodata + 1) & (dgm1 > dgm1_nodata + 1)

    # Compute CHM: surface minus terrain, clamped to >= 0
    chm = np.where(valid_mask, dom1 - dgm1, 0)
    chm = np.clip(chm, 0, None)

    # Zero out the CHM wherever street furniture was detected so lamp
    # posts and signs are not picked up as trees
    if et_cfg.exclude_furniture:
        furniture = load_furniture_detections(tile_id, cfg)
        if furniture:
            furniture_mask = _create_furniture_mask(
                furniture, chm.shape, gt, xmin, ymin, buffer_m=2.0
            )
            chm = np.where(furniture_mask, 0, chm)

    # Detect tree peaks
    peaks = _detect_local_maxima(chm, tree_cfg.min_height_m, window_size=5)

    # Limit to max trees: sort by height descending, keep tallest
    if len(peaks) > tree_cfg.max_trees:
        peaks = sorted(peaks, key=lambda p: p[2], reverse=True)[:tree_cfg.max_trees]

    # Find ortho for color sampling (may be absent)
    ortho_path = os.path.join(cfg.export.ortho_dir, f"{tile_id}.jpg")
    has_ortho = os.path.exists(ortho_path)

    # Load LPO point cloud for height refinement (optional)
    lpo = None
    if et_cfg.use_lpo_refinement and has_lpo_data(tile_id, pc_cfg.lpo_dir):
        lpo_file = find_pointcloud_file(pc_cfg.lpo_dir, tile_id)
        if lpo_file:
            lpo = read_pointcloud_file(lpo_file, bounds=(xmin, ymin, xmax, ymax))

    trees = []

    for row, col, height in peaks:
        # Pixel -> world coordinates
        world_x = gt[0] + col * gt[1]
        world_y = gt[3] + row * gt[5]

        # Tile-local coordinates
        x_local = world_x - xmin
        y_local = world_y - ymin

        # Ground elevation under the peak
        z_ground = float(dgm1[row, col])

        # Estimate radius (simple heuristic: crown radius ~ 25% of height)
        radius = height * 0.25

        # Base confidence grows with height, saturating at 30 m
        confidence = 0.6 + 0.4 * min(1.0, height / 30.0)

        # Refine height from LPO points near the peak: 95th percentile of
        # nearby elevations minus ground, accepted only if still a tree
        if lpo is not None and len(lpo) > 0:
            search_radius = et_cfg.lpo_search_radius_m
            dist_sq = (lpo.x - world_x) ** 2 + (lpo.y - world_y) ** 2
            nearby_mask = dist_sq < search_radius ** 2

            if nearby_mask.any():
                nearby_z = lpo.z[nearby_mask]
                refined_height = float(np.percentile(nearby_z, 95)) - z_ground
                if refined_height > tree_cfg.min_height_m:
                    height = refined_height
                    confidence += 0.1

        # Sample canopy color from the orthophoto, falling back to a default
        canopy_r, canopy_g, canopy_b = 128, 160, 80  # Default
        if et_cfg.sample_canopy_color and has_ortho:
            sample_radius = radius * et_cfg.canopy_sample_radius_factor
            color = _sample_ortho_color(ortho_path, world_x, world_y, sample_radius)
            if color:
                canopy_r, canopy_g, canopy_b = color

        trees.append(EnhancedTree(
            tile_id=tile_id,
            x_local=x_local,
            y_local=y_local,
            z_ground=z_ground,
            height=height,
            radius=radius,
            confidence=min(confidence, 1.0),
            canopy_r=canopy_r,
            canopy_g=canopy_g,
            canopy_b=canopy_b,
        ))

    return trees


def export_trees_csv(
    trees: List[EnhancedTree],
    tile_id: str,
    cfg: Config,
) -> str:
    """Write the enhanced tree detections of one tile to a CSV file.

    Args:
        trees: List of tree detections
        tile_id: Tile identifier
        cfg: Configuration

    Returns:
        Path to CSV file
    """
    out_dir = cfg.trees_enhanced.csv_dir
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{tile_id}.csv")

    header = [
        "tile_id", "x_local", "y_local", "z_ground",
        "height", "radius", "confidence",
        "canopy_r", "canopy_g", "canopy_b"
    ]

    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(tree.to_csv_row() for tree in trees)

    return csv_path


def export_trees_enhanced(cfg: Config) -> int:
    """Run enhanced tree detection for every tile listed in the manifest.

    Args:
        cfg: Configuration

    Returns:
        Exit code (0 = success)
    """
    # The manifest is produced by the heightmap export; bail out if missing
    if not os.path.exists(cfg.export.manifest_path):
        print(f"[trees_enhanced] ERROR: Manifest not found: {cfg.export.manifest_path}")
        print("[trees_enhanced] Run heightmap export first.")
        return 1

    # Read tile ids and bounds from the manifest
    with open(cfg.export.manifest_path, "r") as handle:
        tiles = [
            {
                "tile_id": record["tile_id"],
                "xmin": float(record["xmin"]),
                "ymin": float(record["ymin"]),
                "xmax": float(record["xmax"]),
                "ymax": float(record["ymax"]),
            }
            for record in csv.DictReader(handle)
        ]

    if not tiles:
        print("[trees_enhanced] No tiles in manifest.")
        return 1

    os.makedirs(cfg.trees_enhanced.csv_dir, exist_ok=True)

    total_trees = 0
    for tile in tiles:
        tile_id = tile["tile_id"]
        print(f"[trees_enhanced] Processing {tile_id}...")

        trees = detect_trees_in_tile(
            tile_id,
            tile["xmin"],
            tile["ymin"],
            tile["xmax"],
            tile["ymax"],
            cfg,
        )

        csv_path = export_trees_csv(trees, tile_id, cfg)
        print(f"[trees_enhanced] {tile_id}: {len(trees)} trees -> {csv_path}")
        total_trees += len(trees)

    print(f"[trees_enhanced] DONE. Total trees: {total_trees}")
    return 0