# Files
# GeoData/geodata_download.py
#
# 570 lines
# 19 KiB
# Python
#
#!/usr/bin/env python3
"""Download configured geodata tiles into raw/ based on a TOML config."""
from __future__ import annotations
import argparse
import os
import shutil
import sys
import time
import queue
import threading
from concurrent.futures import Future, as_completed
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse
try:
import tomllib
except ImportError: # pragma: no cover - tomllib is required
raise SystemExit("tomllib is required (Python 3.11+).")
import requests
# Config filename used when --config is not given on the command line.
DEFAULT_CONFIG = "geodata_download.toml"
# All downloads are forced under this directory to match the rest of the pipeline
# (see run_download, which warns and ignores download.output_directory).
DEFAULT_OUTPUT_DIR = "raw"
# Maps dataset keys (and their aliases) to the output subdirectory under raw/.
# Keys not listed here fall back to the dataset's own "output_subdir" setting.
OUTPUT_SUBDIRS = {
    "dgm1": "dgm1",
    "dom1": "dom1",
    "dop20": "dop20",
    "geb3dlo": os.path.join("citygml", "lod2"),
    "citygml": os.path.join("citygml", "lod2"),
    "bdom20rgbi": "bdom20rgbi",
    "lpg": "lpolpg",
    "lpo": "lpolpg",
    "lpolpg": "lpolpg",
}
# Maps the per-file "type" field from a dataset's [[files]] entries to a
# subdirectory name; unknown types use the type string itself as the folder.
FILE_TYPE_SUBDIRS = {
    "image": "jp2",
    "worldfile": "j2w",
    "metadata": "meta",
}
@dataclass(frozen=True)
class DownloadTask:
    """One file to fetch: its dataset, source URL, and destination path."""
    # Dataset key from the config (e.g. "dgm1"); used for log attribution.
    dataset: str
    # Fully formatted download URL.
    url: str
    # Path on disk the finished file is written to.
    output_path: str
class DownloadLogger:
    """Writes formatted log lines to stdout and/or an append-only log file."""

    def __init__(
        self,
        log_file: Optional[str] = None,
        log_format: Optional[str] = None,
        report_progress: bool = True,
    ) -> None:
        self._log_file = log_file
        self._log_format = log_format
        self._report_progress = report_progress

    def log(self, level: str, message: str) -> None:
        """Emit one formatted line to every configured sink."""
        rendered = self._format(level, message)
        if self._report_progress:
            print(rendered)
        if not self._log_file:
            return
        # Open per call so external log rotation never holds a stale handle.
        with open(self._log_file, "a", encoding="utf-8") as sink:
            sink.write(rendered + "\n")

    def _format(self, level: str, message: str) -> str:
        """Render a line via the user template, falling back to "[LEVEL] msg"."""
        fallback = f"[{level}] {message}"
        if not self._log_format:
            return fallback
        try:
            return self._log_format.format(
                timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
                level=level,
                message=message,
            )
        except Exception:
            # A malformed user-supplied template must never crash logging.
            return fallback
class DownloadProgress:
    """Thread-safe, throttled single-line progress bar rendered to stderr."""

    def __init__(self, total: int, enabled: bool) -> None:
        self._total = max(total, 1)  # guard against division by zero
        self._enabled = enabled
        self._downloaded = 0
        self._missing = 0
        self._failed = 0
        self._bytes = 0
        self._lock = threading.Lock()
        self._start = time.time()
        self._last_label = ""
        self._last_render = 0.0
        self._min_interval = 0.25  # seconds between throttled redraws

    def add_bytes(self, bytes_delta: int, label: Optional[str] = None) -> None:
        """Account for freshly received bytes; redraw at most every 0.25s."""
        if not self._enabled or bytes_delta <= 0:
            return
        with self._lock:
            self._bytes += bytes_delta
            if label:
                self._last_label = os.path.basename(label)
            self._render_locked(force=False)

    def set_counts(self, downloaded: int, missing: int, failed: int, label: Optional[str] = None) -> None:
        """Replace the per-task tallies and force an immediate redraw."""
        if not self._enabled:
            return
        with self._lock:
            self._downloaded, self._missing, self._failed = downloaded, missing, failed
            if label:
                self._last_label = os.path.basename(label)
            self._render_locked(force=True)

    def finish(self) -> None:
        """Terminate the in-place progress line with a newline."""
        if not self._enabled:
            return
        sys.stderr.write("\n")
        sys.stderr.flush()

    def _render_locked(self, force: bool) -> None:
        """Draw the bar. Caller must hold self._lock."""
        now = time.time()
        if not force and (now - self._last_render) < self._min_interval:
            return
        self._last_render = now
        elapsed = max(now - self._start, 0.001)
        done = self._downloaded + self._missing + self._failed
        tasks_per_sec = done / elapsed
        left = max(self._total - done, 0)
        eta_seconds = int(left / tasks_per_sec) if tasks_per_sec > 0 else 0
        megabytes = self._bytes / (1024 * 1024)
        gigabytes = megabytes / 1024
        mb_per_sec = megabytes / elapsed
        width = 28
        filled = int(width * done / self._total)
        bar = "#" * filled + "-" * (width - filled)
        text = (
            f"\r[{bar}] {done}/{self._total} "
            f"{tasks_per_sec:.2f}/s eta {eta_seconds}s ok={self._downloaded} miss={self._missing} fail={self._failed} "
            f"{gigabytes:.1f}GB {mb_per_sec:.1f}MB/s "
            f"{self._last_label}"
        )
        # Clamp to 200 chars so the \r overwrite never wraps the terminal line.
        sys.stderr.write(text[:200])
        sys.stderr.flush()
class DaemonThreadPool:
    """Minimal daemon-thread pool that supports submit() + shutdown()."""

    def __init__(self, max_workers: int) -> None:
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")
        self._tasks: queue.Queue = queue.Queue()
        self._threads: list[threading.Thread] = []
        self._shutdown = False
        for idx in range(max_workers):
            worker = threading.Thread(
                name=f"download-worker-{idx}",
                target=self._worker,
                daemon=True,  # never block interpreter exit
            )
            worker.start()
            self._threads.append(worker)

    def submit(self, fn, *args, **kwargs) -> Future:
        """Queue fn(*args, **kwargs) and return a Future for its result."""
        if self._shutdown:
            raise RuntimeError("Thread pool already shutdown")
        fut: Future = Future()
        self._tasks.put((fut, fn, args, kwargs))
        return fut

    def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None:
        """Stop accepting work; optionally cancel queued tasks and join workers."""
        self._shutdown = True
        if cancel_futures:
            # Drain everything still queued and mark it cancelled.
            while True:
                try:
                    pending = self._tasks.get_nowait()
                except queue.Empty:
                    break
                if pending is not None:
                    pending[0].cancel()
        # One sentinel per worker unblocks every thread exactly once.
        for _ in self._threads:
            self._tasks.put(None)
        if wait:
            for worker in self._threads:
                worker.join()

    def _worker(self) -> None:
        """Worker loop: execute queued tasks until a None sentinel arrives."""
        while True:
            entry = self._tasks.get()
            if entry is None:
                return
            fut, fn, args, kwargs = entry
            if not fut.set_running_or_notify_cancel():
                continue  # cancelled before it started
            try:
                outcome = fn(*args, **kwargs)
            except BaseException as exc:
                fut.set_exception(exc)
            else:
                fut.set_result(outcome)
def _load_toml(path: str) -> dict:
    """Parse the TOML file at *path*; exit with a clear error if it is missing."""
    try:
        fh = open(path, "rb")
    except FileNotFoundError:
        raise SystemExit(f"Config not found: {path}")
    with fh:
        return tomllib.load(fh)
def _parse_tile_ranges(cfg: dict) -> Dict[str, dict]:
    """Return the [tile_ranges] table, validating that it is a table."""
    ranges = cfg.get("tile_ranges", {})
    if isinstance(ranges, dict):
        return ranges
    raise SystemExit("tile_ranges must be a table in the TOML config.")
def _iter_tiles(range_cfg: dict) -> Iterable[Tuple[int, int]]:
    """Yield (x, y) tile coordinates over an inclusive, stepped 2-D range."""
    # Both ends are inclusive, hence the +1; steps default to 1.
    xs = range(int(range_cfg["x_start"]), int(range_cfg["x_end"]) + 1, int(range_cfg.get("x_step", 1)))
    ys = range(int(range_cfg["y_start"]), int(range_cfg["y_end"]) + 1, int(range_cfg.get("y_step", 1)))
    for x in xs:
        for y in ys:
            yield x, y
def _override_range(range_cfg: dict, start: Tuple[int, int], end: Tuple[int, int]) -> dict:
    """Return a copy of range_cfg with its corners replaced by start/end."""
    x0, y0 = start
    x1, y1 = end
    return {**range_cfg, "x_start": x0, "y_start": y0, "x_end": x1, "y_end": y1}
def _resolve_output_dir(dataset_key: str, dataset_cfg: dict) -> str:
    """Map a dataset key to its output subdirectory under the base dir."""
    # Well-known datasets have fixed folders; others may configure their own.
    known = OUTPUT_SUBDIRS.get(dataset_key)
    if known is not None:
        return known
    return dataset_cfg.get("output_subdir", dataset_key)
def _format_url(
    template: str,
    base_url: str,
    x: int,
    y: int,
    extra: dict,
) -> str:
    """Fill url_template placeholders from base_url, x, y and scalar extras.

    Non-scalar extras (tables, arrays) are skipped; a scalar extra named
    base_url/x/y deliberately overrides the default, as in the original.
    """
    scalars = {k: v for k, v in extra.items() if isinstance(v, (str, int, float))}
    return template.format(**{"base_url": base_url, "x": x, "y": y, **scalars})
def _build_tasks(
    cfg: dict,
    datasets: Dict[str, dict],
    tile_ranges: Dict[str, dict],
    base_output_dir: str,
    start_override: Optional[Tuple[int, int]],
    end_override: Optional[Tuple[int, int]],
) -> List[DownloadTask]:
    """Expand every selected dataset x tile into concrete DownloadTasks."""
    base_url = cfg.get("download", {}).get("base_url", "").rstrip("/")
    tasks: List[DownloadTask] = []

    def add_task(dataset_key: str, url: str, out_dir: str) -> None:
        # The on-disk filename is the last segment of the URL path.
        filename = os.path.basename(urlparse(url).path)
        tasks.append(DownloadTask(dataset_key, url, os.path.join(out_dir, filename)))

    for dataset_key, dataset_cfg in datasets.items():
        tile_range_key = dataset_cfg.get("tile_range")
        if not tile_range_key or tile_range_key not in tile_ranges:
            raise SystemExit(f"{dataset_key}: tile_range not found: {tile_range_key}")
        range_cfg = tile_ranges[tile_range_key]
        if start_override and end_override:
            # CLI overrides replace the configured corners for every dataset.
            range_cfg = _override_range(range_cfg, start_override, end_override)
        dataset_out_dir = os.path.join(base_output_dir, _resolve_output_dir(dataset_key, dataset_cfg))
        file_cfgs = dataset_cfg.get("files")
        for x, y in _iter_tiles(range_cfg):
            if file_cfgs is not None:
                # Multi-file datasets: one task per [[files]] entry, each in a
                # type-specific subfolder (jp2/j2w/meta/...).
                for file_cfg in file_cfgs:
                    file_type = file_cfg.get("type", "file")
                    out_dir = os.path.join(dataset_out_dir, FILE_TYPE_SUBDIRS.get(file_type, file_type))
                    add_task(dataset_key, _format_url(file_cfg["url_template"], base_url, x, y, file_cfg), out_dir)
            else:
                add_task(dataset_key, _format_url(dataset_cfg["url_template"], base_url, x, y, dataset_cfg), dataset_out_dir)
    return tasks
def _select_datasets(cfg: dict, requested: Optional[List[str]]) -> Dict[str, dict]:
    """Pick datasets to download: the requested keys, or all enabled ones."""
    datasets = cfg.get("datasets", {})
    if not isinstance(datasets, dict) or not datasets:
        raise SystemExit("datasets must be defined in the TOML config.")
    if not requested:
        # Default: every dataset not explicitly disabled in the config.
        selected = {name: ds for name, ds in datasets.items() if ds.get("enabled", True)}
    else:
        unknown = [name for name in requested if name not in datasets]
        if unknown:
            raise SystemExit(f"Unknown dataset(s): {', '.join(unknown)}")
        selected = {name: datasets[name] for name in requested}
    if not selected:
        raise SystemExit("No datasets selected for download.")
    return selected
def _safe_remove_dir(base_dir: str, rel_dir: str) -> None:
    """Delete base_dir/rel_dir, refusing anything at or outside base_dir."""
    base = os.path.abspath(base_dir)
    target = os.path.abspath(os.path.join(base_dir, rel_dir))
    # Path-traversal guard: a rel_dir containing ".." must not escape base.
    if os.path.commonpath([target, base]) != base:
        raise SystemExit(f"Refusing to delete outside base dir: {target}")
    if target == base:
        raise SystemExit(f"Refusing to delete base dir: {target}")
    if os.path.exists(target):
        shutil.rmtree(target)
def _download_task(
    session: requests.Session,
    task: DownloadTask,
    timeout: int,
    verify: bool | str,
    retries: int,
    stop_event: threading.Event,
    progress: DownloadProgress,
) -> Tuple[str, DownloadTask, Optional[str]]:
    """Download one tile to task.output_path, writing via a .part temp file.

    Args:
        session: Shared requests session (connection pooling, User-Agent).
        task: The URL/destination pair to fetch.
        timeout: Per-request timeout in seconds.
        verify: TLS verification flag or CA-bundle path, as requests expects.
        retries: Number of retry attempts after the first try.
        stop_event: Cooperative cancellation flag checked between chunks.
        progress: Progress bar that receives byte counts.

    Returns:
        (status, task, detail) where status is "downloaded", "missing",
        "failed" or "aborted" and detail is an error description or None.
    """
    os.makedirs(os.path.dirname(task.output_path), exist_ok=True)
    tmp_path = f"{task.output_path}.part"
    try:
        if stop_event.is_set():
            return "aborted", task, "Interrupted"
        for attempt in range(retries + 1):
            if stop_event.is_set():
                return "aborted", task, "Interrupted"
            try:
                with session.get(task.url, stream=True, timeout=timeout, verify=verify) as resp:
                    # 404/410 are expected for tiles outside coverage; report
                    # them as "missing" rather than hard failures.
                    if resp.status_code in (404, 410):
                        return "missing", task, f"HTTP {resp.status_code}"
                    if resp.status_code >= 400:
                        return "failed", task, f"HTTP {resp.status_code}"
                    with open(tmp_path, "wb") as fh:
                        for chunk in resp.iter_content(chunk_size=1024 * 1024):
                            if stop_event.is_set():
                                return "aborted", task, "Interrupted"
                            if chunk:  # skip keep-alive chunks
                                fh.write(chunk)
                                progress.add_bytes(len(chunk), task.output_path)
                # Atomic rename so readers never observe a partial file.
                os.replace(tmp_path, task.output_path)
                return "downloaded", task, None
            except requests.RequestException as exc:
                if attempt >= retries:
                    return "failed", task, str(exc)
                # Linear backoff before the next attempt.
                time.sleep(1.0 + attempt * 0.5)
            except OSError as exc:
                return "failed", task, str(exc)
        return "failed", task, "Unknown error"
    finally:
        # Fix: the original only removed the temp file on the unreachable
        # fall-through path, leaving stale .part files behind on every
        # "missing"/"failed"/"aborted" early return. On success the temp
        # file has already been renamed away, so this is a no-op.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def run_download(
    config_path: str,
    requested_datasets: Optional[List[str]] = None,
    start_override: Optional[Tuple[int, int]] = None,
    end_override: Optional[Tuple[int, int]] = None,
    clean_downloads: bool = False,
    ca_bundle_override: Optional[str] = None,
) -> int:
    """Run all configured downloads and return a process exit code.

    Args:
        config_path: Path to the TOML configuration file.
        requested_datasets: Dataset keys to download; None means all enabled.
        start_override: Optional (x, y) tile-range start applied to every dataset.
        end_override: Optional (x, y) tile-range end applied to every dataset.
        clean_downloads: Delete dataset output folders first and re-download all.
        ca_bundle_override: CA bundle path taking precedence over the config.

    Returns:
        0 on success or nothing to do, 1 if any download failed,
        130 after a keyboard interrupt.
    """
    cfg = _load_toml(config_path)
    download_cfg = cfg.get("download", {})
    tile_ranges = _parse_tile_ranges(cfg)
    datasets = _select_datasets(cfg, requested_datasets)
    logging_cfg = cfg.get("logging", {})
    progress_enabled = bool(logging_cfg.get("report_progress", True))
    logger = DownloadLogger(
        logging_cfg.get("log_file"),
        logging_cfg.get("log_format"),
        progress_enabled,
    )
    # The pipeline expects everything under raw/; a different configured
    # output directory is warned about and ignored, never honored.
    configured_output_dir = download_cfg.get("output_directory")
    base_output_dir = DEFAULT_OUTPUT_DIR
    if configured_output_dir and os.path.normpath(configured_output_dir) != DEFAULT_OUTPUT_DIR:
        logger.log(
            "WARN",
            f"Ignoring download.output_directory={configured_output_dir}; using raw/ to match pipeline.",
        )
    if clean_downloads:
        # Remove each selected dataset's folder before re-downloading.
        for dataset_key, dataset_cfg in datasets.items():
            _safe_remove_dir(base_output_dir, _resolve_output_dir(dataset_key, dataset_cfg))
    tasks = _build_tasks(
        cfg,
        datasets,
        tile_ranges,
        base_output_dir,
        start_override,
        end_override,
    )
    if not tasks:
        logger.log("INFO", "No download tasks generated.")
        return 0
    # When not cleaning, files already present on disk are treated as done.
    skip_existing = not clean_downloads
    pending: List[DownloadTask] = []
    skipped = 0
    for task in tasks:
        if skip_existing and os.path.exists(task.output_path):
            skipped += 1
            continue
        pending.append(task)
    if skipped:
        logger.log("INFO", f"Skipped {skipped} existing file(s).")
    if not pending:
        logger.log("INFO", "Nothing to download after skipping existing files.")
        return 0
    # TLS verification: a CLI-provided bundle wins over the config one; a
    # missing bundle file falls back to the system trust store with a warning.
    verify_ssl = download_cfg.get("verify_ssl", True)
    ca_bundle = ca_bundle_override or download_cfg.get("ca_bundle")
    if verify_ssl and ca_bundle:
        if os.path.exists(ca_bundle):
            verify = ca_bundle
            source = "CLI" if ca_bundle_override else "config"
            logger.log("INFO", f"Using CA bundle ({source}): {ca_bundle}")
        else:
            verify = True
            logger.log("WARN", f"CA bundle not found, using system trust: {ca_bundle}")
    else:
        verify = bool(verify_ssl)
        if not verify_ssl:
            logger.log("WARN", "TLS verification disabled by config.")
    timeout = int(download_cfg.get("timeout_seconds", 300))
    retries = int(download_cfg.get("retry_attempts", 3))
    parallel = int(download_cfg.get("parallel_downloads", 4))
    user_agent = download_cfg.get("user_agent", "geodata-download/1.0")
    downloaded = 0
    missing = 0
    failed = 0
    progress = DownloadProgress(len(pending), progress_enabled)
    stop_event = threading.Event()
    interrupted = False
    with requests.Session() as session:
        session.headers.update({"User-Agent": user_agent})
        executor = DaemonThreadPool(max_workers=parallel)
        futures = [
            executor.submit(_download_task, session, task, timeout, verify, retries, stop_event, progress)
            for task in pending
        ]
        try:
            for future in as_completed(futures):
                status, task, detail = future.result()
                if status == "downloaded":
                    downloaded += 1
                elif status == "missing":
                    missing += 1
                    logger.log("WARN", f"Missing tile: {task.url} ({detail})")
                elif status == "aborted":
                    # Aborted tasks count toward the failed tally (exit code 1)
                    # unless the whole run was interrupted, which returns 130.
                    failed += 1
                else:
                    failed += 1
                    extra = f" ({detail})" if detail else ""
                    logger.log("ERROR", f"Failed: {task.url}{extra}")
                progress.set_counts(downloaded, missing, failed, task.output_path)
        except KeyboardInterrupt:
            # Signal workers to stop; in-flight downloads check stop_event
            # between chunks and bail out cooperatively.
            interrupted = True
            stop_event.set()
            logger.log("WARN", "Interrupted; stopping downloads.")
        finally:
            if interrupted:
                # Workers are daemon threads; do not block waiting for them.
                executor.shutdown(wait=False, cancel_futures=True)
            else:
                executor.shutdown(wait=True)
            progress.finish()
    if interrupted:
        return 130
    logger.log(
        "INFO",
        f"Done. Downloaded={downloaded}, Missing={missing}, Failed={failed}, Skipped={skipped}.",
    )
    return 1 if failed else 0
def parse_args(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (defaults to sys.argv[1:])."""
    parser = argparse.ArgumentParser(description="Download geodata tiles from TOML config.")
    add = parser.add_argument
    add("--config", default=DEFAULT_CONFIG, help="Path to geodata_download.toml.")
    add("--datasets", help="Comma-separated dataset keys to download (default: enabled datasets).")
    add("--start", nargs=2, type=int, metavar=("X", "Y"), help="Override tile range start (x y).")
    add("--end", nargs=2, type=int, metavar=("X", "Y"), help="Override tile range end (x y).")
    add("--clean-downloads", action="store_true", help="Delete selected dataset folders before downloading.")
    add("--ca-bundle", help="Path to a CA bundle file to override the config.")
    return parser.parse_args(argv)
def main(argv: Optional[Iterable[str]] = None) -> int:
    """CLI entry point: parse arguments, validate overrides, run the download.

    Returns the process exit code (0 = success, 1 = failures, 130 = interrupt).
    """
    args = parse_args(argv)
    # Robustness fix: drop empty entries so "--datasets a,,b" or a trailing
    # comma does not produce a bogus "" key that _select_datasets rejects.
    datasets: Optional[List[str]] = None
    if args.datasets:
        names = [name.strip() for name in args.datasets.split(",") if name.strip()]
        datasets = names or None
    start = tuple(args.start) if args.start else None
    end = tuple(args.end) if args.end else None
    # A partial override is ambiguous; require both corners or neither.
    if (start is None) != (end is None):
        raise SystemExit("--start and --end must be provided together.")
    return run_download(
        config_path=args.config,
        requested_datasets=datasets,
        start_override=start,
        end_override=end,
        clean_downloads=args.clean_downloads,
        ca_bundle_override=args.ca_bundle,
    )
if __name__ == "__main__":
    # Propagate the download result as the process exit code.
    sys.exit(main())