#!/usr/bin/env python3 """Validate source/sink ID mask integrity. ID encoding: id = B + 256*G + 65536*R """ from __future__ import annotations import argparse import json import re from pathlib import Path from typing import Dict, Optional import numpy as np from PIL import Image _TILE_KEY_RE = re.compile(r"^(dgm\d+_\d+_\d+_\d+)") def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Validate source/sink masks and report ID usage.") p.add_argument("--source-dir", default="raw/water_source_masks", help="Directory with source masks.") p.add_argument("--sink-dir", default="raw/water_sink_masks", help="Directory with sink masks.") p.add_argument("--allowed-source-ids", default="", help="Comma-separated nonzero IDs allowed in source masks.") p.add_argument("--allowed-sink-ids", default="", help="Comma-separated nonzero IDs allowed in sink masks.") p.add_argument("--max-id", type=int, default=0, help="If >0, fail when any ID exceeds this value.") p.add_argument( "--report", default="work/mask_master/validation_report.json", help="Output JSON report path.", ) p.add_argument("--fail-on-overlap", action="store_true", help="Fail if source and sink overlap on same tile pixels.") return p.parse_args() def parse_allowed_ids(text: str) -> Optional[set[int]]: text = text.strip() if not text: return None out: set[int] = set() for part in text.split(","): part = part.strip() if not part: continue out.add(int(part)) return out def tile_key_from_name(name: str) -> Optional[str]: stem = Path(name).stem m = _TILE_KEY_RE.match(stem) return m.group(1) if m else None def decode_ids(rgb: np.ndarray) -> np.ndarray: r = rgb[..., 0].astype(np.uint32) g = rgb[..., 1].astype(np.uint32) b = rgb[..., 2].astype(np.uint32) return b + (256 * g) + (65536 * r) def analyze_mask(path: Path) -> dict: arr = np.array(Image.open(path).convert("RGB"), dtype=np.uint8) ids = decode_ids(arr) u, c = np.unique(ids, return_counts=True) nonzero = [(int(i), int(n)) for i, n in zip(u.tolist(), c.tolist()) if i != 0] return { "path": str(path), "shape": [int(arr.shape[1]), int(arr.shape[0])], "unique_count": int(len(u)), "max_id": int(u[-1]) if len(u) > 0 else 0, "nonzero_count": int(np.count_nonzero(ids)), "nonzero_ids": nonzero, "ids_array": ids, } def collect_pngs(path: Path) -> list[Path]: if not path.exists(): return [] return sorted([p for p in path.glob("*.png") if p.is_file()]) def unexpected_ids(found: list[tuple[int, int]], allowed: Optional[set[int]]) -> list[int]: if allowed is None: return [] return sorted([i for i, _ in found if i not in allowed]) def build_overview(files: list[dict]) -> list[dict]: totals: Dict[int, dict] = {} for entry in files: tile_key = str(entry.get("tile_key") or "").strip() for ident, pixels in entry.get("nonzero_ids", []): node = totals.setdefault( int(ident), { "id": int(ident), "total_pixels": 0, "file_count": 0, "tile_keys": set(), }, ) node["total_pixels"] += int(pixels) node["file_count"] += 1 if tile_key: node["tile_keys"].add(tile_key) out = [] for ident in sorted(totals.keys()): node = totals[ident] out.append( { "id": int(node["id"]), "total_pixels": int(node["total_pixels"]), "file_count": int(node["file_count"]), "tile_keys": sorted(list(node["tile_keys"])), } ) return out def main() -> int: args = parse_args() allowed_source = parse_allowed_ids(args.allowed_source_ids) allowed_sink = parse_allowed_ids(args.allowed_sink_ids) report = { "schema_version": 1, "config": { "source_dir": args.source_dir, "sink_dir": args.sink_dir, "allowed_source_ids": sorted(list(allowed_source)) if allowed_source is not None else [], "allowed_sink_ids": sorted(list(allowed_sink)) if allowed_sink is not None else [], "max_id": args.max_id, "fail_on_overlap": bool(args.fail_on_overlap), }, "source": [], "source_overview": [], "sink": [], "sink_overview": [], "overlaps": [], "issues": [], } failed = False source_by_tile: Dict[str, np.ndarray] = {} sink_by_tile: Dict[str, np.ndarray] = {} for p in collect_pngs(Path(args.source_dir)): a = analyze_mask(p) key = tile_key_from_name(p.name) bad_ids = unexpected_ids(a["nonzero_ids"], allowed_source) if bad_ids: msg = f"[mask_validate_ids] Source {p.name} has unexpected IDs: {bad_ids}" report["issues"].append(msg) print(msg) failed = True if args.max_id > 0 and a["max_id"] > args.max_id: msg = f"[mask_validate_ids] Source {p.name} max_id={a['max_id']} exceeds max-id={args.max_id}" report["issues"].append(msg) print(msg) failed = True report["source"].append( { "path": a["path"], "tile_key": key or "", "shape": a["shape"], "unique_count": a["unique_count"], "max_id": a["max_id"], "nonzero_count": a["nonzero_count"], "nonzero_ids": a["nonzero_ids"], } ) if key: source_by_tile[key] = a["ids_array"] for p in collect_pngs(Path(args.sink_dir)): a = analyze_mask(p) key = tile_key_from_name(p.name) bad_ids = unexpected_ids(a["nonzero_ids"], allowed_sink) if bad_ids: msg = f"[mask_validate_ids] Sink {p.name} has unexpected IDs: {bad_ids}" report["issues"].append(msg) print(msg) failed = True if args.max_id > 0 and a["max_id"] > args.max_id: msg = f"[mask_validate_ids] Sink {p.name} max_id={a['max_id']} exceeds max-id={args.max_id}" report["issues"].append(msg) print(msg) failed = True report["sink"].append( { "path": a["path"], "tile_key": key or "", "shape": a["shape"], "unique_count": a["unique_count"], "max_id": a["max_id"], "nonzero_count": a["nonzero_count"], "nonzero_ids": a["nonzero_ids"], } ) if key: sink_by_tile[key] = a["ids_array"] shared_tiles = sorted(set(source_by_tile.keys()) & set(sink_by_tile.keys())) for key in shared_tiles: s = source_by_tile[key] k = sink_by_tile[key] if s.shape != k.shape: msg = f"[mask_validate_ids] Shape mismatch for tile {key}: source{s.shape} sink{k.shape}" report["issues"].append(msg) print(msg) failed = True continue overlap = int(np.count_nonzero((s > 0) & (k > 0))) if overlap > 0: entry = {"tile_key": key, "overlap_pixels": overlap} report["overlaps"].append(entry) msg = f"[mask_validate_ids] Overlap on tile {key}: {overlap} pixels" print(msg) if args.fail_on_overlap: report["issues"].append(msg) failed = True report["source_overview"] = build_overview(report["source"]) report["sink_overview"] = build_overview(report["sink"]) out_path = Path(args.report) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(report, indent=2), encoding="utf-8") print(f"[mask_validate_ids] Report written: {out_path}") print(f"[mask_validate_ids] Source files: {len(report['source'])}, Sink files: {len(report['sink'])}") print( "[mask_validate_ids] Source IDs: " f"{len(report['source_overview'])}, Sink IDs: {len(report['sink_overview'])}" ) print(f"[mask_validate_ids] Overlap records: {len(report['overlaps'])}") return 1 if failed else 0 if __name__ == "__main__": raise SystemExit(main())