# Source code for opendetect.detector

from __future__ import annotations

import difflib
from pathlib import Path

import cv2
import numpy as np

from .download import download_model
from .models import available_models, create_model, default_input_size
from .models.base import DetectorModel
from .registry import ModelSpec, get_model_spec
from .types import Detections, LoadedModelInfo


class Detector:
    """Facade that builds a detection backend and runs inference with it.

    The constructor resolves ``model`` against the registry (falling back to
    a bare implementation name), optionally downloads the weights, creates
    the backend through ``create_model`` and records what was loaded in
    ``self.info``.
    """

    def __init__(
        self,
        model: str = "rfdetr-m",
        *,
        model_path: str | Path | None = None,
        input_size: tuple[int, int] | None = None,
        hardware_acceleration: bool = True,
        tensor_rt: bool = False,
        mixed_precision: bool = False,
        threshold: float = 0.3,
        num_select: int = 300,
        class_ids: list[int] | None = None,
        auto_download: bool = True,
        cache_dir: str | Path | None = None,
        show_download_progress: bool = False,
    ) -> None:
        """Resolve the model, obtain its weights and create the backend.

        Args:
            model: Registry model ID, or an implementation name listed by
                ``available_models()`` (matched after strip + lowercase).
            model_path: Explicit weights path. When given it takes priority
                over any registry download.
            input_size: Overrides the input size; otherwise the registry
                spec's size is used, or the implementation default.
            hardware_acceleration: Forwarded unchanged to ``create_model``.
            tensor_rt: Forwarded unchanged to ``create_model``.
            mixed_precision: Forwarded unchanged to ``create_model``.
            threshold: Forwarded unchanged to ``create_model``.
            num_select: Forwarded unchanged to ``create_model``.
            class_ids: Forwarded unchanged to ``create_model``.
            auto_download: When true and ``model`` is a registry ID without
                an explicit ``model_path``, fetch weights via
                ``download_model``.
            cache_dir: Optional cache directory passed to ``download_model``.
            show_download_progress: Passed to ``download_model`` as
                ``show_progress``.

        Raises:
            ValueError: If ``model`` is neither a registry ID nor a known
                implementation name.
        """
        spec = self._try_get_spec(model)
        if spec is not None:
            implementation = spec.implementation
        else:
            # Not a registry ID: treat the string as an implementation name.
            implementation = model.strip().lower()
            if implementation not in available_models():
                raise ValueError(
                    f"Unknown model '{model}'. Use one of registry IDs or implementations: {available_models()}"
                )

        # Priority: explicit argument > registry spec > implementation default.
        if input_size is not None:
            resolved_input_size = input_size
        elif spec is not None:
            resolved_input_size = spec.input_size
        else:
            resolved_input_size = default_input_size(implementation)

        resolved_model_path: Path | None
        source: str
        if model_path is not None:
            resolved_model_path = Path(model_path)
            source = "local-path"
        elif spec is not None and auto_download:
            resolved_model_path = download_model(
                spec,
                cache_dir=Path(cache_dir) if cache_dir is not None else None,
                show_progress=show_download_progress,
            )
            source = "registry-download"
        else:
            # No path at all: the backend decides where its weights come from.
            resolved_model_path = None
            source = "backend-default"

        self._backend: DetectorModel = create_model(
            implementation,
            input_size=resolved_input_size,
            model_path=resolved_model_path,
            threshold=threshold,
            num_select=num_select,
            class_ids=class_ids,
            hardware_acceleration=hardware_acceleration,
            tensor_rt=tensor_rt,
            mixed_precision=mixed_precision,
        )

        # Report what the backend actually ended up with, not what was asked.
        self.info = LoadedModelInfo(
            model_id=spec.model_id if spec is not None else None,
            implementation=implementation,
            input_size=tuple(self._backend.input_size),
            model_path=self._backend.model_path,
            source=source,
        )

    @property
    def backend(self) -> DetectorModel:
        """The backend instance created by ``create_model``."""
        return self._backend

    @property
    def class_names(self) -> list[str] | None:
        """Class names of the backend as strings, or ``None`` if unavailable.

        Uses the backend's ``class_names`` attribute when it is not ``None``;
        otherwise falls back to calling ``default_class_names()`` if the
        backend provides such a callable.
        """
        names = getattr(self._backend, "class_names", None)
        if names is not None:
            return [str(name) for name in names]

        default_names_fn = getattr(self._backend, "default_class_names", None)
        if callable(default_names_fn):
            fallback = default_names_fn()
            if fallback is not None:
                return [str(name) for name in fallback]
        return None

    @staticmethod
    def _try_get_spec(model: str) -> ModelSpec | None:
        """Return the registry spec for ``model``, or ``None`` if the lookup raises ``ValueError``."""
        try:
            return get_model_spec(model)
        except ValueError:
            return None

    @staticmethod
    def _normalize_class_name(name: str) -> str:
        """Lowercase ``name`` and drop all whitespace, hyphens and underscores."""
        # Hyphens/underscores become spaces, then every whitespace run is
        # removed by split() + join.
        return "".join(name.lower().replace("-", " ").replace("_", " ").split())

    def resolve_class_ids_from_names(self, names: list[str]) -> list[int]:
        """Map class names to class IDs, ignoring case and separators.

        Args:
            names: Class names to look up. Names are compared after
                ``_normalize_class_name`` is applied to both sides.

        Returns:
            The matching class IDs in first-seen order, without duplicates.

        Raises:
            ValueError: If the model exposes no class names, or if any name
                is unknown (the message lists close matches when found).
        """
        class_names = self.class_names
        if class_names is None:
            raise ValueError("Class names are not available for this model.")

        # On a normalized-name collision the later class ID wins.
        normalized_to_id: dict[str, int] = {}
        for class_id, class_name in enumerate(class_names):
            normalized_to_id[self._normalize_class_name(class_name)] = class_id

        # Dict keys give insertion-ordered de-duplication.
        resolved: dict[int, None] = {}
        unknown: list[str] = []
        for name in names:
            class_id = normalized_to_id.get(self._normalize_class_name(name))
            if class_id is None:
                unknown.append(name)
            else:
                resolved.setdefault(class_id, None)

        if unknown:
            # Fuzzy-match on normalized keys so suggestions follow the same
            # case/separator rules as the lookup itself.
            candidates = list(normalized_to_id)
            hints = []
            for value in unknown:
                match = difflib.get_close_matches(
                    self._normalize_class_name(value), candidates, n=1, cutoff=0.6
                )
                if match:
                    suggestion = class_names[normalized_to_id[match[0]]]
                    hints.append(f"{value!r} -> {suggestion!r}")
            hint_text = f" Suggestions: {', '.join(hints)}." if hints else ""
            raise ValueError(
                f"Unknown class name(s): {', '.join(repr(item) for item in unknown)}.{hint_text}"
            )
        return list(resolved)

    def set_class_filter(self, class_ids: list[int] | None) -> list[int] | None:
        """Replace the backend's class filter.

        Args:
            class_ids: Class IDs to keep. ``None`` or an empty list clears
                the filter. Values are converted with ``int`` and
                de-duplicated in first-seen order.

        Returns:
            The filter now stored on the backend (``None`` when cleared).

        Raises:
            ValueError: If class names are available and any ID falls
                outside ``0..len(class_names) - 1``.
        """
        resolved: list[int] | None = None
        if class_ids:
            resolved = list(dict.fromkeys(int(class_id) for class_id in class_ids))
            class_names = self.class_names
            # Range checking is only possible when the class count is known.
            if class_names is not None:
                max_id = len(class_names) - 1
                invalid = [idx for idx in resolved if idx < 0 or idx > max_id]
                if invalid:
                    raise ValueError(
                        f"Class IDs out of range: {invalid}. Valid range: 0..{max_id}."
                    )

        # Both current backends support dynamic class filter updates.
        self._backend.class_ids = resolved if resolved else None
        self._backend.class_ids_np = (
            np.asarray(self._backend.class_ids, dtype=np.int64)
            if self._backend.class_ids is not None
            else None
        )
        return self._backend.class_ids

    def list_classes(self) -> list[tuple[int, str]]:
        """Return ``(class_id, class_name)`` pairs, or ``[]`` when names are unavailable."""
        class_names = self.class_names
        if class_names is None:
            return []
        return list(enumerate(class_names))

    @staticmethod
    def _validate_image(image: np.ndarray) -> None:
        """Check that ``image`` is an array of shape ``[H, W, 3]``.

        Raises:
            TypeError: If ``image`` has no ``ndim``/``shape`` (for example
                ``None`` returned by a failed image read).
            ValueError: If the array is not 3-dimensional with 3 channels.
        """
        if not hasattr(image, "ndim") or not hasattr(image, "shape"):
            raise TypeError(
                f"Expected image to be a numpy.ndarray, got {type(image).__name__}."
            )
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError("Expected image shape [H, W, 3].")

    def _to_rgb(self, image: np.ndarray, color: str) -> np.ndarray:
        """Validate ``image`` and return it in RGB channel order.

        An ``"rgb"`` image is returned as-is (same object); a ``"bgr"``
        image is converted with ``cv2.cvtColor``. ``color`` is
        case-insensitive.

        Raises:
            ValueError: If ``color`` is neither ``'rgb'`` nor ``'bgr'``.
        """
        self._validate_image(image)
        mode = color.lower()
        if mode == "rgb":
            return image
        if mode == "bgr":
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        raise ValueError("color must be 'rgb' or 'bgr'.")

    def predict(self, image: np.ndarray, *, color: str = "bgr") -> Detections:
        """Run the backend on one image.

        Args:
            image: Image of shape ``[H, W, 3]``.
            color: Channel order of ``image``: ``'bgr'`` or ``'rgb'``.

        Returns:
            Whatever the backend's ``predict_rgb_frame`` returns.
        """
        image_rgb = self._to_rgb(image, color)
        return self._backend.predict_rgb_frame(image_rgb)

    def annotate(
        self,
        image: np.ndarray,
        detections: Detections | None = None,
        *,
        color: str = "bgr",
    ) -> np.ndarray:
        """Draw detections on a copy of ``image``.

        Args:
            image: Image to annotate; it is copied or converted before
                being handed to the backend's drawing routine.
            detections: Detections to draw. When ``None`` they are computed
                with ``predict`` first.
            color: Channel order of ``image``; the result uses the same
                order.

        Returns:
            The annotated image.

        Raises:
            ValueError: If ``color`` is neither ``'rgb'`` nor ``'bgr'``.
        """
        mode = color.lower()
        if detections is None:
            detections = self.predict(image, color=color)

        if mode == "bgr":
            frame_bgr = image.copy()
            return self._backend.draw_detections_on_bgr_frame(frame_bgr, detections)

        if mode == "rgb":
            # The backend draws on BGR frames, so convert there and back.
            frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            annotated_bgr = self._backend.draw_detections_on_bgr_frame(
                frame_bgr, detections
            )
            return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

        raise ValueError("color must be 'rgb' or 'bgr'.")

    def predict_and_annotate(
        self, image: np.ndarray, *, color: str = "bgr"
    ) -> tuple[Detections, np.ndarray]:
        """Run ``predict`` then ``annotate`` and return ``(detections, annotated)``."""
        detections = self.predict(image, color=color)
        annotated = self.annotate(image, detections=detections, color=color)
        return detections, annotated

    def infer_image_file(
        self,
        image_path: str | Path,
        *,
        output_path: str | Path | None = None,
    ) -> tuple[Detections, np.ndarray]:
        """Run detection on an image file and optionally save the result.

        Args:
            image_path: Image to read with ``cv2.imread``.
            output_path: When given, the annotated image is written here;
                missing parent directories are created.

        Returns:
            ``(detections, annotated)`` where ``annotated`` is in BGR order.

        Raises:
            RuntimeError: If the image cannot be read or written.
        """
        image_path = Path(image_path)
        frame_bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
        if frame_bgr is None:
            raise RuntimeError(f"Unable to open image: {image_path}")

        detections, annotated = self.predict_and_annotate(frame_bgr, color="bgr")

        if output_path is not None:
            out_path = Path(output_path)
            out_path.parent.mkdir(parents=True, exist_ok=True)
            if not cv2.imwrite(str(out_path), annotated):
                raise RuntimeError(f"Unable to write image: {out_path}")
        return detections, annotated

    def infer_video_file(
        self,
        video_path: str | Path,
        *,
        output_path: str | Path,
        max_frames: int | None = None,
    ) -> int:
        """Annotate a video file frame by frame and write the result.

        Args:
            video_path: Input video opened with ``cv2.VideoCapture``.
            output_path: Destination file; missing parent directories are
                created. A ``.avi`` suffix selects the ``MJPG`` codec, any
                other suffix selects ``mp4v``.
            max_frames: Stop after this many frames; ``None`` processes the
                whole video.

        Returns:
            The number of frames processed.

        Raises:
            RuntimeError: If the input video or the output writer cannot be
                opened.
        """
        video_path = Path(video_path)
        output_path = Path(output_path)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise RuntimeError(f"Unable to open video: {video_path}")

        # Everything after a successful open runs under try/finally so the
        # capture (and the writer, once created) is released on every path.
        writer = None
        try:
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            # ``not fps > 0`` also rejects NaN, which ``fps <= 0`` lets through.
            if not fps > 0:
                fps = 30.0

            output_path.parent.mkdir(parents=True, exist_ok=True)
            suffix = output_path.suffix.lower()
            fourcc = cv2.VideoWriter_fourcc(*("MJPG" if suffix == ".avi" else "mp4v"))
            writer = cv2.VideoWriter(
                str(output_path), fourcc, fps, (frame_width, frame_height)
            )
            if not writer.isOpened():
                raise RuntimeError(f"Unable to open video writer for: {output_path}")

            frames_processed = 0
            while max_frames is None or frames_processed < max_frames:
                ok, frame_bgr = cap.read()
                if not ok:
                    break
                _, annotated = self.predict_and_annotate(frame_bgr, color="bgr")
                writer.write(annotated)
                frames_processed += 1
            return frames_processed
        finally:
            cap.release()
            if writer is not None:
                writer.release()
def load_detector(
    model: str = "rfdetr-m",
    **kwargs: object,
) -> Detector:
    """Build a ``Detector`` for ``model``.

    Args:
        model: Passed to ``Detector`` as its ``model`` argument.
        **kwargs: Passed through unchanged as keyword arguments to
            ``Detector``.

    Returns:
        The constructed ``Detector``.
    """
    return Detector(model=model, **kwargs)