Source code for opendetect.download
from __future__ import annotations
import hashlib
import os
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from .registry import ModelSpec, get_model_spec, model_url
[docs]
def default_cache_dir() -> Path:
    """Return the directory where downloaded checkpoints are cached.

    Resolution order (empty environment variables count as unset):

    1. ``$OPENDETECT_CACHE_DIR`` itself, ``~``-expanded and resolved;
    2. ``$XDG_CACHE_HOME/opendetect/checkpoints``, with the base
       ``~``-expanded and resolved;
    3. ``~/.cache/opendetect/checkpoints``.

    The directory is only computed here, not created.
    """
    candidates = (
        ("OPENDETECT_CACHE_DIR", ()),
        ("XDG_CACHE_HOME", ("opendetect", "checkpoints")),
    )
    for env_name, suffix_parts in candidates:
        configured = os.getenv(env_name)
        if configured:
            return Path(configured).expanduser().resolve().joinpath(*suffix_parts)
    return Path.home().joinpath(".cache", "opendetect", "checkpoints")
def _print_progress(prefix: str, downloaded: int, total: int | None) -> None:
    """Redraw a one-line download progress indicator on stderr.

    Each message starts with a carriage return, so successive calls overwrite
    the same terminal line. If ``total`` is ``None`` or not positive, only the
    downloaded amount in MiB is shown. Otherwise the line shows a percentage
    (capped at 100) plus downloaded/total MiB.
    """
    bytes_per_mib = 1024 * 1024
    done_mib = downloaded / bytes_per_mib
    if total is not None and total > 0:
        percent = min(100.0, 100.0 * downloaded / total)
        line = f"\r{prefix}: {percent:6.2f}% ({done_mib:.1f}/{total / bytes_per_mib:.1f} MiB)"
    else:
        line = f"\r{prefix}: {done_mib:.1f} MiB"
    sys.stderr.write(line)
    sys.stderr.flush()
[docs]
def download_url_to_file(
    url: str,
    destination: Path,
    *,
    expected_sha256: str | None = None,
    show_progress: bool = True,
    timeout_sec: int = 60,
) -> Path:
    """Stream ``url`` into ``destination``, optionally verifying its SHA-256.

    The payload is first written to a temporary file in the destination
    directory. It is moved onto ``destination`` only after the whole body has
    been read and, when ``expected_sha256`` is given, the digest matched. A
    failed download or a checksum mismatch therefore never leaves a partial
    file at ``destination``, and the temporary file is always removed.

    Args:
        url: URL to fetch via ``urllib.request.urlopen``.
        destination: Target file path. ``~`` is expanded, the path is resolved
            and missing parent directories are created.
        expected_sha256: Optional hex digest, compared case-insensitively.
            Falsy values disable verification.
        show_progress: Write a progress line to stderr while downloading.
        timeout_sec: Timeout passed to ``urlopen``.

    Returns:
        The resolved destination path.

    Raises:
        RuntimeError: If ``expected_sha256`` is given and the digest differs.
        Errors from ``urlopen``, reading the response, or filesystem access
        propagate unchanged.
    """
    destination = destination.expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    req = Request(url, headers={"User-Agent": "opendetect"})
    # The response owns a socket (or file handle). The ``with`` block closes it
    # even when reading, hashing or renaming fails.
    with urlopen(req, timeout=timeout_sec) as response:
        total: int | None = None
        content_length = response.headers.get("Content-Length")
        if content_length and content_length.isdigit():
            total = int(content_length)

        # Stage next to the destination so the final replace() is a
        # same-directory rename.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=destination.parent)
        hasher = hashlib.sha256() if expected_sha256 else None
        downloaded = 0
        try:
            while True:
                chunk = response.read(8192)
                if not chunk:
                    break
                tmp_file.write(chunk)
                downloaded += len(chunk)
                if hasher is not None:
                    hasher.update(chunk)
                if show_progress:
                    _print_progress(destination.name, downloaded, total)
            tmp_file.close()
            if show_progress:
                sys.stderr.write("\n")
            if hasher is not None:
                digest = hasher.hexdigest()
                if digest.lower() != expected_sha256.lower():
                    raise RuntimeError(
                        f"SHA256 mismatch for {destination.name}: "
                        f"expected {expected_sha256}, got {digest}"
                    )
            Path(tmp_file.name).replace(destination)
        finally:
            # After a successful replace() the temp path is gone and this is a
            # no-op. Otherwise it discards the partial or unverified download.
            tmp_file.close()
            if os.path.exists(tmp_file.name):
                os.remove(tmp_file.name)
    return destination
def _extract_onnx_from_zip(zip_path: Path, output_dir: Path) -> Path:
    """Pull the ONNX model out of ``zip_path`` and place it in ``output_dir``.

    The whole archive is unpacked into a hidden staging directory,
    ``output_dir/.extract-<zip stem>``. That directory is wiped before
    extraction and removed afterwards. The ``*.onnx`` files are collected
    recursively in sorted path order. The first one whose name ends with
    ``end2end.onnx`` is chosen; if none does, the first file overall is used.
    The chosen file is moved to ``output_dir/<zip stem>.onnx``.

    Returns:
        The path of the moved ``.onnx`` file.

    Raises:
        RuntimeError: If the archive contains no ``.onnx`` file.
    """
    staging_dir = output_dir / f".extract-{zip_path.stem}"
    # Clear leftovers from an earlier interrupted extraction.
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(staging_dir)
        onnx_files = sorted(staging_dir.rglob("*.onnx"))
        if not onnx_files:
            raise RuntimeError(f"No .onnx file found in archive: {zip_path}")
        chosen = next(
            (p for p in onnx_files if p.name.endswith("end2end.onnx")),
            onnx_files[0],
        )
        result = output_dir / f"{zip_path.stem}.onnx"
        result.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(chosen), str(result))
        return result
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)
[docs]
def download_model(
    model: str | ModelSpec,
    *,
    cache_dir: Path | str | None = None,
    force: bool = False,
    show_progress: bool = True,
    timeout_sec: int = 60,
) -> Path:
    """Ensure a model checkpoint is present in the local cache and return its path.

    If ``<cache_root>/<spec.artifact_path>`` already exists and ``force`` is
    false, it is returned without any network access. Otherwise the file at
    ``model_url(spec)`` is downloaded next to the target. If it is a ``.zip``,
    its ``.onnx`` payload is extracted and the archive is deleted. The result is
    then moved to the target path, replacing any existing file.

    Args:
        model: A model name (resolved with ``get_model_spec``) or a ``ModelSpec``.
        cache_dir: Cache root. ``~`` is expanded and the path resolved. Defaults
            to ``default_cache_dir()``.
        force: Re-download even if the cached artifact already exists.
        show_progress: Forwarded to ``download_url_to_file``.
        timeout_sec: Forwarded to ``download_url_to_file``.

    Returns:
        ``<cache_root>/<spec.artifact_path>``.

    Raises:
        RuntimeError: If a downloaded ``.zip`` contains no ``.onnx`` file, or on
            a checksum failure reported by ``download_url_to_file``. Errors from
            ``get_model_spec`` / the download itself propagate unchanged.
    """
    spec = model if isinstance(model, ModelSpec) else get_model_spec(model)
    cache_root = (
        Path(cache_dir).expanduser().resolve()
        if cache_dir is not None
        else default_cache_dir()
    )
    cache_root.mkdir(parents=True, exist_ok=True)
    target_path = cache_root / spec.artifact_path
    if target_path.exists() and not force:
        return target_path

    url = model_url(spec)
    # Keep the URL's own file name for the download so its suffix (e.g. ``.zip``)
    # tells us whether an extraction step is needed.
    destination_name = Path(urlparse(url).path).name or spec.filename
    download_path = target_path.with_name(destination_name)
    download_url_to_file(
        url,
        download_path,
        show_progress=show_progress,
        timeout_sec=timeout_sec,
    )

    final_path = download_path
    if download_path.suffix.lower() == ".zip":
        try:
            final_path = _extract_onnx_from_zip(download_path, download_path.parent)
        finally:
            # The archive is only an intermediate artifact. Remove it even when
            # extraction fails so failed attempts do not leave it in the cache.
            download_path.unlink(missing_ok=True)

    if final_path != target_path:
        target_path.parent.mkdir(parents=True, exist_ok=True)
        if target_path.exists():
            target_path.unlink()
        shutil.move(str(final_path), str(target_path))
    return target_path