Files
GeoData/scripts/mask_validate_ids.py

247 lines
8.3 KiB
Python

#!/usr/bin/env python3
"""Validate source/sink ID mask integrity.
ID encoding:
id = B + 256*G + 65536*R
"""
from __future__ import annotations
import argparse
import json
import re
from pathlib import Path
from typing import Dict, Optional
import numpy as np
from PIL import Image
# Captures the leading tile key of a mask file stem: the literal "dgm" plus
# digits, followed by three more underscore-separated digit groups
# (e.g. "dgm1_32_350_5600"). Anything after that prefix is ignored.
_TILE_KEY_RE = re.compile(r"^(dgm\d+_\d+_\d+_\d+)")
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace with the attributes ``source_dir``, ``sink_dir``,
        ``allowed_source_ids``, ``allowed_sink_ids``, ``max_id``, ``report``
        and ``fail_on_overlap``.
    """
    parser = argparse.ArgumentParser(description="Validate source/sink masks and report ID usage.")

    # Plain string options, registered in display order: (flag, default, help).
    text_options = (
        ("--source-dir", "raw/water_source_masks", "Directory with source masks."),
        ("--sink-dir", "raw/water_sink_masks", "Directory with sink masks."),
        ("--allowed-source-ids", "", "Comma-separated nonzero IDs allowed in source masks."),
        ("--allowed-sink-ids", "", "Comma-separated nonzero IDs allowed in sink masks."),
    )
    for flag, default_value, help_text in text_options:
        parser.add_argument(flag, default=default_value, help=help_text)

    parser.add_argument(
        "--max-id",
        type=int,
        default=0,
        help="If >0, fail when any ID exceeds this value.",
    )
    parser.add_argument(
        "--report",
        default="work/mask_master/validation_report.json",
        help="Output JSON report path.",
    )
    parser.add_argument(
        "--fail-on-overlap",
        action="store_true",
        help="Fail if source and sink overlap on same tile pixels.",
    )
    return parser.parse_args()
def parse_allowed_ids(text: str) -> Optional[set[int]]:
    """Parse a comma-separated list of integer IDs.

    Args:
        text: Raw option value, e.g. ``"1, 2,3"``. Whitespace around entries
            and empty entries are ignored.

    Returns:
        ``None`` when *text* is empty or whitespace only (no restriction
        configured); otherwise the set of parsed IDs, which is empty when
        *text* holds only separators.

    Raises:
        ValueError: If an entry is not a valid integer literal.
    """
    if not text.strip():
        return None
    tokens = (chunk.strip() for chunk in text.split(","))
    return {int(token) for token in tokens if token}
def tile_key_from_name(name: str) -> Optional[str]:
    """Extract the tile key from a mask file name.

    Args:
        name: File name or path; only its stem (final component without the
            last suffix) is inspected.

    Returns:
        The prefix captured by ``_TILE_KEY_RE``, or ``None`` when the stem
        does not start with a tile key.
    """
    match = _TILE_KEY_RE.match(Path(name).stem)
    if match is None:
        return None
    return match.group(1)
def decode_ids(rgb: np.ndarray) -> np.ndarray:
    """Decode per-pixel IDs from an RGB array.

    Args:
        rgb: Array whose last axis holds the channels in R, G, B order; only
            the first three channels are read.

    Returns:
        ``uint32`` array without the channel axis, where each element is
        ``B + 256*G + 65536*R``.
    """
    # Widen each channel before scaling so the weighted sum cannot wrap in a
    # narrower input dtype.
    red, green, blue = (rgb[..., channel].astype(np.uint32) for channel in range(3))
    return blue + (256 * green) + (65536 * red)
def analyze_mask(path: Path) -> dict:
    """Load one mask image and summarise the IDs it contains.

    Args:
        path: Image file to read; it is converted to RGB before decoding.

    Returns:
        Dict with the keys:

        * ``path`` - the input path as a string.
        * ``shape`` - ``[width, height]`` in pixels (columns first).
        * ``unique_count`` - number of distinct IDs, including 0 if present.
        * ``max_id`` - largest ID found (0 for an empty image).
        * ``nonzero_count`` - number of pixels whose ID is not 0.
        * ``nonzero_ids`` - ``(id, pixel_count)`` pairs for every nonzero ID,
          in ascending ID order.
        * ``ids_array`` - the full per-pixel ID array from ``decode_ids``.
    """
    # Open via a context manager so the file handle is released as soon as
    # the pixels have been copied into the NumPy array, instead of relying on
    # garbage collection of the lazily loaded image.
    with Image.open(path) as img:
        arr = np.array(img.convert("RGB"), dtype=np.uint8)
    ids = decode_ids(arr)
    # np.unique returns the IDs sorted ascending, so the last one is the max.
    u, c = np.unique(ids, return_counts=True)
    nonzero = [(int(i), int(n)) for i, n in zip(u.tolist(), c.tolist()) if i != 0]
    return {
        "path": str(path),
        # arr is (rows, cols, 3); report as [width, height].
        "shape": [int(arr.shape[1]), int(arr.shape[0])],
        "unique_count": int(len(u)),
        "max_id": int(u[-1]) if len(u) > 0 else 0,
        "nonzero_count": int(np.count_nonzero(ids)),
        "nonzero_ids": nonzero,
        "ids_array": ids,
    }
def collect_pngs(path: Path) -> list[Path]:
    """List the ``*.png`` files directly inside *path*, sorted.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        Sorted paths of regular files matching ``*.png``; an empty list when
        *path* does not exist.
    """
    if path.exists():
        return sorted(entry for entry in path.glob("*.png") if entry.is_file())
    return []
def unexpected_ids(found: list[tuple[int, int]], allowed: Optional[set[int]]) -> list[int]:
    """Return the IDs in *found* that are not permitted.

    Args:
        found: ``(id, pixel_count)`` pairs; only the ID is inspected.
        allowed: Permitted IDs, or ``None`` to disable the check.

    Returns:
        Ascending list of IDs from *found* that are missing from *allowed*;
        always empty when *allowed* is ``None``.
    """
    if allowed is None:
        return []
    offenders = [ident for ident, _pixels in found if ident not in allowed]
    offenders.sort()
    return offenders
def build_overview(files: list[dict]) -> list[dict]:
    """Aggregate per-file ID statistics into one record per ID.

    Args:
        files: Per-file report entries. Each may carry ``tile_key`` and
            ``nonzero_ids`` (a list of ``(id, pixel_count)`` pairs); missing
            keys are treated as empty.

    Returns:
        One dict per ID, ordered by ascending ID, with ``id``,
        ``total_pixels`` (sum over all files), ``file_count`` (number of
        entries listing the ID) and ``tile_keys`` (sorted, de-duplicated
        non-empty tile keys of those entries).
    """
    pixel_totals: Dict[int, int] = {}
    file_counts: Dict[int, int] = {}
    tiles_seen: Dict[int, set] = {}

    for record in files:
        tile_key = str(record.get("tile_key") or "").strip()
        for raw_id, raw_pixels in record.get("nonzero_ids", []):
            ident = int(raw_id)
            pixel_totals[ident] = pixel_totals.get(ident, 0) + int(raw_pixels)
            file_counts[ident] = file_counts.get(ident, 0) + 1
            # Register the ID even when the entry has no tile key.
            tiles = tiles_seen.setdefault(ident, set())
            if tile_key:
                tiles.add(tile_key)

    return [
        {
            "id": ident,
            "total_pixels": pixel_totals[ident],
            "file_count": file_counts[ident],
            "tile_keys": sorted(tiles_seen[ident]),
        }
        for ident in sorted(pixel_totals)
    ]
def _report_issue(report: dict, msg: str) -> None:
    """Append *msg* to ``report["issues"]`` and echo it to stdout."""
    report["issues"].append(msg)
    print(msg)


def _validate_mask_dir(
    label: str,
    report_key: str,
    directory: str,
    allowed: Optional[set[int]],
    max_id: int,
    report: dict,
) -> tuple[bool, Dict[str, np.ndarray]]:
    """Validate every PNG mask in *directory* and record it in *report*.

    Args:
        label: Prefix used in messages (``"Source"`` or ``"Sink"``).
        report_key: Report list that receives the per-file entries.
        directory: Directory scanned for ``*.png`` masks.
        allowed: Permitted nonzero IDs, or ``None`` to skip that check.
        max_id: Upper bound for IDs; values ``<= 0`` disable the check.
        report: Report dict, mutated in place.

    Returns:
        ``(failed, occupied_by_tile)`` where ``failed`` is True if any check
        failed, and ``occupied_by_tile`` maps each tile key to a boolean
        array marking pixels with a nonzero ID.
    """
    failed = False
    occupied_by_tile: Dict[str, np.ndarray] = {}
    for path in collect_pngs(Path(directory)):
        info = analyze_mask(path)
        key = tile_key_from_name(path.name)

        bad_ids = unexpected_ids(info["nonzero_ids"], allowed)
        if bad_ids:
            _report_issue(report, f"[mask_validate_ids] {label} {path.name} has unexpected IDs: {bad_ids}")
            failed = True
        if max_id > 0 and info["max_id"] > max_id:
            _report_issue(
                report,
                f"[mask_validate_ids] {label} {path.name} max_id={info['max_id']} exceeds max-id={max_id}",
            )
            failed = True

        report[report_key].append(
            {
                "path": info["path"],
                "tile_key": key or "",
                "shape": info["shape"],
                "unique_count": info["unique_count"],
                "max_id": info["max_id"],
                "nonzero_count": info["nonzero_count"],
                "nonzero_ids": info["nonzero_ids"],
            }
        )
        if key:
            # The overlap check only needs "is any ID set here", so keep a
            # boolean mask rather than the full uint32 ID array for every
            # tile. If several files share a tile key, the last one wins.
            occupied_by_tile[key] = info["ids_array"] > 0
    return failed, occupied_by_tile


def main() -> int:
    """Run the validation and write the JSON report.

    Validates all source and sink masks against the allowed-ID and max-ID
    settings, checks tiles present in both sets for shape mismatches and
    overlapping nonzero pixels, writes the report to ``--report`` and prints
    a short summary.

    Returns:
        Process exit status: 1 if any check failed (overlaps count only with
        ``--fail-on-overlap``), otherwise 0.
    """
    args = parse_args()
    allowed_source = parse_allowed_ids(args.allowed_source_ids)
    allowed_sink = parse_allowed_ids(args.allowed_sink_ids)
    report = {
        "schema_version": 1,
        "config": {
            "source_dir": args.source_dir,
            "sink_dir": args.sink_dir,
            "allowed_source_ids": sorted(allowed_source) if allowed_source is not None else [],
            "allowed_sink_ids": sorted(allowed_sink) if allowed_sink is not None else [],
            "max_id": args.max_id,
            "fail_on_overlap": bool(args.fail_on_overlap),
        },
        "source": [],
        "source_overview": [],
        "sink": [],
        "sink_overview": [],
        "overlaps": [],
        "issues": [],
    }

    source_failed, source_by_tile = _validate_mask_dir(
        "Source", "source", args.source_dir, allowed_source, args.max_id, report
    )
    sink_failed, sink_by_tile = _validate_mask_dir(
        "Sink", "sink", args.sink_dir, allowed_sink, args.max_id, report
    )
    failed = source_failed or sink_failed

    # Cross-check tiles that have both a source and a sink mask.
    for key in sorted(set(source_by_tile) & set(sink_by_tile)):
        src_mask = source_by_tile[key]
        sink_mask = sink_by_tile[key]
        if src_mask.shape != sink_mask.shape:
            _report_issue(
                report,
                f"[mask_validate_ids] Shape mismatch for tile {key}: source{src_mask.shape} sink{sink_mask.shape}",
            )
            failed = True
            continue
        overlap = int(np.count_nonzero(src_mask & sink_mask))
        if overlap > 0:
            report["overlaps"].append({"tile_key": key, "overlap_pixels": overlap})
            msg = f"[mask_validate_ids] Overlap on tile {key}: {overlap} pixels"
            print(msg)
            # Overlaps are always recorded and printed, but only become a
            # failure (and an issue entry) when explicitly requested.
            if args.fail_on_overlap:
                report["issues"].append(msg)
                failed = True

    report["source_overview"] = build_overview(report["source"])
    report["sink_overview"] = build_overview(report["sink"])

    out_path = Path(args.report)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(report, indent=2), encoding="utf-8")

    print(f"[mask_validate_ids] Report written: {out_path}")
    print(f"[mask_validate_ids] Source files: {len(report['source'])}, Sink files: {len(report['sink'])}")
    print(
        "[mask_validate_ids] Source IDs: "
        f"{len(report['source_overview'])}, Sink IDs: {len(report['sink_overview'])}"
    )
    print(f"[mask_validate_ids] Overlap records: {len(report['overlaps'])}")
    return 1 if failed else 0
if __name__ == "__main__":
raise SystemExit(main())