Add SWE boundary mask pipeline and mask tooling
This commit is contained in:
246
scripts/mask_validate_ids.py
Normal file
246
scripts/mask_validate_ids.py
Normal file
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Validate source/sink ID mask integrity.
|
||||
|
||||
ID encoding:
|
||||
id = B + 256*G + 65536*R
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
_TILE_KEY_RE = re.compile(r"^(dgm\d+_\d+_\d+_\d+)")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Validate source/sink masks and report ID usage.")
|
||||
p.add_argument("--source-dir", default="raw/water_source_masks", help="Directory with source masks.")
|
||||
p.add_argument("--sink-dir", default="raw/water_sink_masks", help="Directory with sink masks.")
|
||||
p.add_argument("--allowed-source-ids", default="", help="Comma-separated nonzero IDs allowed in source masks.")
|
||||
p.add_argument("--allowed-sink-ids", default="", help="Comma-separated nonzero IDs allowed in sink masks.")
|
||||
p.add_argument("--max-id", type=int, default=0, help="If >0, fail when any ID exceeds this value.")
|
||||
p.add_argument(
|
||||
"--report",
|
||||
default="work/mask_master/validation_report.json",
|
||||
help="Output JSON report path.",
|
||||
)
|
||||
p.add_argument("--fail-on-overlap", action="store_true", help="Fail if source and sink overlap on same tile pixels.")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def parse_allowed_ids(text: str) -> Optional[set[int]]:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return None
|
||||
out: set[int] = set()
|
||||
for part in text.split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
out.add(int(part))
|
||||
return out
|
||||
|
||||
|
||||
def tile_key_from_name(name: str) -> Optional[str]:
|
||||
stem = Path(name).stem
|
||||
m = _TILE_KEY_RE.match(stem)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def decode_ids(rgb: np.ndarray) -> np.ndarray:
|
||||
r = rgb[..., 0].astype(np.uint32)
|
||||
g = rgb[..., 1].astype(np.uint32)
|
||||
b = rgb[..., 2].astype(np.uint32)
|
||||
return b + (256 * g) + (65536 * r)
|
||||
|
||||
|
||||
def analyze_mask(path: Path) -> dict:
|
||||
arr = np.array(Image.open(path).convert("RGB"), dtype=np.uint8)
|
||||
ids = decode_ids(arr)
|
||||
u, c = np.unique(ids, return_counts=True)
|
||||
nonzero = [(int(i), int(n)) for i, n in zip(u.tolist(), c.tolist()) if i != 0]
|
||||
return {
|
||||
"path": str(path),
|
||||
"shape": [int(arr.shape[1]), int(arr.shape[0])],
|
||||
"unique_count": int(len(u)),
|
||||
"max_id": int(u[-1]) if len(u) > 0 else 0,
|
||||
"nonzero_count": int(np.count_nonzero(ids)),
|
||||
"nonzero_ids": nonzero,
|
||||
"ids_array": ids,
|
||||
}
|
||||
|
||||
|
||||
def collect_pngs(path: Path) -> list[Path]:
|
||||
if not path.exists():
|
||||
return []
|
||||
return sorted([p for p in path.glob("*.png") if p.is_file()])
|
||||
|
||||
|
||||
def unexpected_ids(found: list[tuple[int, int]], allowed: Optional[set[int]]) -> list[int]:
|
||||
if allowed is None:
|
||||
return []
|
||||
return sorted([i for i, _ in found if i not in allowed])
|
||||
|
||||
|
||||
def build_overview(files: list[dict]) -> list[dict]:
|
||||
totals: Dict[int, dict] = {}
|
||||
for entry in files:
|
||||
tile_key = str(entry.get("tile_key") or "").strip()
|
||||
for ident, pixels in entry.get("nonzero_ids", []):
|
||||
node = totals.setdefault(
|
||||
int(ident),
|
||||
{
|
||||
"id": int(ident),
|
||||
"total_pixels": 0,
|
||||
"file_count": 0,
|
||||
"tile_keys": set(),
|
||||
},
|
||||
)
|
||||
node["total_pixels"] += int(pixels)
|
||||
node["file_count"] += 1
|
||||
if tile_key:
|
||||
node["tile_keys"].add(tile_key)
|
||||
out = []
|
||||
for ident in sorted(totals.keys()):
|
||||
node = totals[ident]
|
||||
out.append(
|
||||
{
|
||||
"id": int(node["id"]),
|
||||
"total_pixels": int(node["total_pixels"]),
|
||||
"file_count": int(node["file_count"]),
|
||||
"tile_keys": sorted(list(node["tile_keys"])),
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
allowed_source = parse_allowed_ids(args.allowed_source_ids)
|
||||
allowed_sink = parse_allowed_ids(args.allowed_sink_ids)
|
||||
|
||||
report = {
|
||||
"schema_version": 1,
|
||||
"config": {
|
||||
"source_dir": args.source_dir,
|
||||
"sink_dir": args.sink_dir,
|
||||
"allowed_source_ids": sorted(list(allowed_source)) if allowed_source is not None else [],
|
||||
"allowed_sink_ids": sorted(list(allowed_sink)) if allowed_sink is not None else [],
|
||||
"max_id": args.max_id,
|
||||
"fail_on_overlap": bool(args.fail_on_overlap),
|
||||
},
|
||||
"source": [],
|
||||
"source_overview": [],
|
||||
"sink": [],
|
||||
"sink_overview": [],
|
||||
"overlaps": [],
|
||||
"issues": [],
|
||||
}
|
||||
|
||||
failed = False
|
||||
source_by_tile: Dict[str, np.ndarray] = {}
|
||||
sink_by_tile: Dict[str, np.ndarray] = {}
|
||||
|
||||
for p in collect_pngs(Path(args.source_dir)):
|
||||
a = analyze_mask(p)
|
||||
key = tile_key_from_name(p.name)
|
||||
bad_ids = unexpected_ids(a["nonzero_ids"], allowed_source)
|
||||
if bad_ids:
|
||||
msg = f"[mask_validate_ids] Source {p.name} has unexpected IDs: {bad_ids}"
|
||||
report["issues"].append(msg)
|
||||
print(msg)
|
||||
failed = True
|
||||
if args.max_id > 0 and a["max_id"] > args.max_id:
|
||||
msg = f"[mask_validate_ids] Source {p.name} max_id={a['max_id']} exceeds max-id={args.max_id}"
|
||||
report["issues"].append(msg)
|
||||
print(msg)
|
||||
failed = True
|
||||
report["source"].append(
|
||||
{
|
||||
"path": a["path"],
|
||||
"tile_key": key or "",
|
||||
"shape": a["shape"],
|
||||
"unique_count": a["unique_count"],
|
||||
"max_id": a["max_id"],
|
||||
"nonzero_count": a["nonzero_count"],
|
||||
"nonzero_ids": a["nonzero_ids"],
|
||||
}
|
||||
)
|
||||
if key:
|
||||
source_by_tile[key] = a["ids_array"]
|
||||
|
||||
for p in collect_pngs(Path(args.sink_dir)):
|
||||
a = analyze_mask(p)
|
||||
key = tile_key_from_name(p.name)
|
||||
bad_ids = unexpected_ids(a["nonzero_ids"], allowed_sink)
|
||||
if bad_ids:
|
||||
msg = f"[mask_validate_ids] Sink {p.name} has unexpected IDs: {bad_ids}"
|
||||
report["issues"].append(msg)
|
||||
print(msg)
|
||||
failed = True
|
||||
if args.max_id > 0 and a["max_id"] > args.max_id:
|
||||
msg = f"[mask_validate_ids] Sink {p.name} max_id={a['max_id']} exceeds max-id={args.max_id}"
|
||||
report["issues"].append(msg)
|
||||
print(msg)
|
||||
failed = True
|
||||
report["sink"].append(
|
||||
{
|
||||
"path": a["path"],
|
||||
"tile_key": key or "",
|
||||
"shape": a["shape"],
|
||||
"unique_count": a["unique_count"],
|
||||
"max_id": a["max_id"],
|
||||
"nonzero_count": a["nonzero_count"],
|
||||
"nonzero_ids": a["nonzero_ids"],
|
||||
}
|
||||
)
|
||||
if key:
|
||||
sink_by_tile[key] = a["ids_array"]
|
||||
|
||||
shared_tiles = sorted(set(source_by_tile.keys()) & set(sink_by_tile.keys()))
|
||||
for key in shared_tiles:
|
||||
s = source_by_tile[key]
|
||||
k = sink_by_tile[key]
|
||||
if s.shape != k.shape:
|
||||
msg = f"[mask_validate_ids] Shape mismatch for tile {key}: source{s.shape} sink{k.shape}"
|
||||
report["issues"].append(msg)
|
||||
print(msg)
|
||||
failed = True
|
||||
continue
|
||||
overlap = int(np.count_nonzero((s > 0) & (k > 0)))
|
||||
if overlap > 0:
|
||||
entry = {"tile_key": key, "overlap_pixels": overlap}
|
||||
report["overlaps"].append(entry)
|
||||
msg = f"[mask_validate_ids] Overlap on tile {key}: {overlap} pixels"
|
||||
print(msg)
|
||||
if args.fail_on_overlap:
|
||||
report["issues"].append(msg)
|
||||
failed = True
|
||||
|
||||
report["source_overview"] = build_overview(report["source"])
|
||||
report["sink_overview"] = build_overview(report["sink"])
|
||||
|
||||
out_path = Path(args.report)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
|
||||
print(f"[mask_validate_ids] Report written: {out_path}")
|
||||
print(f"[mask_validate_ids] Source files: {len(report['source'])}, Sink files: {len(report['sink'])}")
|
||||
print(
|
||||
"[mask_validate_ids] Source IDs: "
|
||||
f"{len(report['source_overview'])}, Sink IDs: {len(report['sink_overview'])}"
|
||||
)
|
||||
print(f"[mask_validate_ids] Overlap records: {len(report['overlaps'])}")
|
||||
|
||||
return 1 if failed else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user