workflows.auto_labeling

import datetime
import json
import logging
import os
import queue
import random
import re
import shutil
import signal
import subprocess
import sys
import time
from difflib import get_close_matches
from functools import partial
from pathlib import Path
from typing import Union

import albumentations as A
import fiftyone as fo
import numpy as np
import psutil
import torch
import torch.multiprocessing as mp
from accelerate.test_utils.testing import get_backend
from datasets import Split
from fiftyone import ViewField as F
from huggingface_hub import HfApi, hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForObjectDetection,
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)
from ultralytics import YOLO
from rfdetr import RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRLarge
import wandb
from config.config import (
    ACCEPTED_SPLITS,
    GLOBAL_SEED,
    HF_DO_UPLOAD,
    HF_ROOT,
    NUM_WORKERS,
    WANDB_ACTIVE,
    WORKFLOWS,
)
from utils.dataset_loader import get_supported_datasets
from utils.logging import configure_logging
from utils.sample_field_operations import add_sample_field


def get_dataset_and_model_from_hf_id(hf_id: str):
    """Extract dataset and model name from HuggingFace ID by matching against supported datasets."""
    # HF ID follows structure organization/dataset_model
    # Both dataset and model can contain "_" as well
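    # Illustrative example with hypothetical names: "org/fisheye8k_yolo11n"
    # yields dataset "fisheye8k" and model "yolo11n", provided "fisheye8k"
    # is among the supported dataset names.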

    # Remove organization (everything before the first "/")
    hf_id = hf_id.split("/", 1)[-1]

    # Find all dataset names that appear in hf_id
    supported_datasets = get_supported_datasets()
    matches = [
        dataset_name for dataset_name in supported_datasets if dataset_name in hf_id
    ]

    if not matches:
        logging.warning(
            f"Dataset name could not be extracted from Hugging Face ID {hf_id}"
        )
        dataset_name = "no_dataset_name"
    else:
        # Return the longest match (most specific)
        dataset_name = max(matches, key=len)

    # Get model name by removing dataset name from hf_id
    model_name = hf_id.replace(dataset_name, "").strip("_")
    if not model_name:
        logging.warning(
            f"Model name could not be extracted from Hugging Face ID {hf_id}"
        )
        model_name = "no_model_name"

    return dataset_name, model_name


# Handling timeouts
class TimeoutException(Exception):
    """Custom exception for handling dataloader timeouts."""

    pass


def timeout_handler(signum, frame):
    raise TimeoutException("Dataloader creation timed out")
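
# Typical use of the handler above (as done in model_inference below):
#   signal.signal(signal.SIGALRM, timeout_handler)
#   signal.alarm(timeout_seconds)  # raises TimeoutException when the alarm fires
#   ...                            # guarded work, e.g. fetching the next batch
#   signal.alarm(0)                # cancel the alarm
# Note: SIGALRM is only available on POSIX systems.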


class ZeroShotInferenceCollateFn:
    """Collate function for zero-shot inference that prepares batches for model input."""

    def __init__(
        self,
        hf_model_config_name,
        hf_processor,
        batch_size,
        object_classes,
        batch_classes,
    ):
        """Initialize the auto labeling model with the Hugging Face model config, processor, batch size, object classes, and batch classes."""
        try:
            self.hf_model_config_name = hf_model_config_name
            self.processor = hf_processor
            self.batch_size = batch_size
            self.object_classes = object_classes
            self.batch_classes = batch_classes
        except Exception as e:
            logging.error(f"Error in collate init of DataLoader: {e}")

    def __call__(self, batch):
        """Processes a batch of data by preparing images and labels for model input."""
        try:
            images, labels = zip(*batch)
            target_sizes = [tuple(img.shape[1:]) for img in images]

            # Adjustments for final batch
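            # The last batch of an epoch may contain fewer than batch_size
            # images, so the list of per-image text prompts must shrink to match.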
            n_images = len(images)
            if n_images < self.batch_size:
                self.batch_classes = [self.object_classes] * n_images

            # Apply PIL transformation for specific models
            if self.hf_model_config_name == "OmDetTurboConfig":
                images = [to_pil_image(image) for image in images]

            inputs = self.processor(
                text=self.batch_classes,
                images=images,
                return_tensors="pt",
                padding=True,  # Allow for differently sized images
            )

            return inputs, labels, target_sizes, self.batch_classes
        except Exception as e:
            logging.error(f"Error in collate function of DataLoader: {e}")


class ZeroShotObjectDetection:
    """Zero-shot object detection using various HuggingFace models with multi-GPU support."""

    def __init__(
        self,
        dataset_torch: torch.utils.data.Dataset,
        dataset_info,
        config,
        detections_path="./output/detections/",
        log_root="./logs/",
    ):
        """Initialize the zero-shot object detection labeler with dataset, configuration, and path settings."""
        self.dataset_torch = dataset_torch
        self.dataset_info = dataset_info
        self.dataset_name = dataset_info["name"]
        self.object_classes = config["object_classes"]
        self.detection_threshold = config["detection_threshold"]
        self.detections_root = os.path.join(detections_path, self.dataset_name)
        self.tensorboard_root = os.path.join(
            log_root, "tensorboard/zeroshot_object_detection"
        )

        logging.info(f"Zero-shot models will look for {self.object_classes}")

    def exclude_stored_predictions(
        self, dataset_v51: fo.Dataset, config, do_exclude=False
    ):
        """Checks for existing predictions and loads them from disk if available."""
        dataset_schema = dataset_v51.get_field_schema()
        models_splits_dict = {}
        for model_name, value in config["hf_models_zeroshot_objectdetection"].items():
            model_name_key = re.sub(r"[\W-]+", "_", model_name)
            pred_key = re.sub(
                r"[\W-]+", "_", "pred_zsod_" + model_name
            )  # zsod for Zero-Shot Object Detection
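            # The regex maps arbitrary model IDs to valid field names; e.g. the
            # (hypothetical) model "IDEA-Research/grounding-dino-base" becomes
            # "pred_zsod_IDEA_Research_grounding_dino_base".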
            # Check if data already stored in V51 dataset
            if pred_key in dataset_schema and do_exclude is True:
                logging.warning(
                    f"Skipping model {model_name}. Predictions already stored in Voxel51 dataset."
                )
            # Check if data already stored on disk
            elif (
                os.path.isdir(os.path.join(self.detections_root, model_name_key))
                and do_exclude is True
            ):
                try:
                    logging.info(f"Loading {model_name} predictions from disk.")
                    temp_dataset = fo.Dataset.from_dir(
                        dataset_dir=os.path.join(self.detections_root, model_name_key),
                        dataset_type=fo.types.COCODetectionDataset,
                        name="temp_dataset",
                        data_path="data.json",
                    )

                    # Copy all detections from stored dataset into our dataset
                    detections = temp_dataset.values("detections.detections")
                    add_sample_field(
                        dataset_v51,
                        pred_key,
                        fo.EmbeddedDocumentField,
                        embedded_doc_type=fo.Detections,
                    )
                    dataset_v51.set_values(f"{pred_key}.detections", detections)
                except Exception as e:
                    logging.error(
                        f"Data in {os.path.join(self.detections_root, model_name_key)} could not be loaded. Error: {e}"
                    )
                finally:
                    fo.delete_dataset("temp_dataset")
            # Assign model to be run
            else:
                models_splits_dict[model_name] = value

        logging.info(f"Models to be run: {models_splits_dict}")
        return models_splits_dict

    # Worker functions
    def update_queue_sizes_worker(
        self, queues, queue_sizes, largest_queue_index, max_queue_size
    ):
        """Monitor and manage multiple result queues for balanced processing."""
        experiment_name = f"queue_size_monitor_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        log_directory = os.path.join(
            self.tensorboard_root, self.dataset_name, experiment_name
        )
        wandb.tensorboard.patch(root_logdir=log_directory)
        if WANDB_ACTIVE:
            wandb.init(
                name=f"queue_size_monitor_{os.getpid()}",
                job_type="inference",
                project="Zero Shot Object Detection",
            )
        writer = SummaryWriter(log_dir=log_directory)

        step = 0

        while True:
            for i, q in enumerate(queues):
                queue_sizes[i] = q.qsize()
                writer.add_scalar(f"queue_size/items/{i}", queue_sizes[i], step)

            step += 1

            # Find the index of the largest queue
            max_size = max(queue_sizes)
            max_index = queue_sizes.index(max_size)

            # Calculate the total size of all queues
            total_size = sum(queue_sizes)

            # If total_size is greater than 0, calculate the probabilities
            if total_size > 0:
                # Normalize the queue sizes by the max_queue_size
                normalized_sizes = [size / max_queue_size for size in queue_sizes]

                # Calculate probabilities based on normalized sizes
                probabilities = [
                    size / sum(normalized_sizes) for size in normalized_sizes
                ]

                # Use random.choices with weights (probabilities)
                chosen_queue_index = random.choices(
                    range(len(queues)), weights=probabilities, k=1
                )[0]

                largest_queue_index.value = chosen_queue_index
            else:
                largest_queue_index.value = max_index

            time.sleep(0.1)

    def process_outputs_worker(
        self,
        result_queues,
        largest_queue_index,
        inference_finished,
        max_queue_size,
        wandb_activate=False,
    ):
        """Process model outputs from result queues and save to dataset."""
        configure_logging()
        logging.info(f"Process ID: {os.getpid()}. Results processing process started")
        dataset_v51 = fo.load_dataset(self.dataset_name)
        processing_successful = None

        # Logging
        experiment_name = f"post_process_{os.getpid()}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        log_directory = os.path.join(
            self.tensorboard_root, self.dataset_name, experiment_name
        )
        wandb.tensorboard.patch(root_logdir=log_directory)
        if WANDB_ACTIVE and wandb_activate:
            wandb.init(
                name=f"post_process_{os.getpid()}",
                job_type="inference",
                project="Zero Shot Object Detection",
            )
        writer = SummaryWriter(log_dir=log_directory)
        n_processed_images = 0

        logging.info(f"Post-Processor {os.getpid()} starting loop.")

        while True:
            results_queue = result_queues[largest_queue_index.value]
            writer.add_scalar(
                "post_processing/selected_queue",
                largest_queue_index.value,
                n_processed_images,
            )

            if results_queue.qsize() == max_queue_size:
                logging.warning(
                    f"Queue full: {results_queue.qsize()}. Consider increasing number of post-processing workers."
                )

            # Exit only when inference is finished and the queue is empty
            if inference_finished.value and results_queue.empty():
                dataset_v51.save()
                logging.info(
                    f"Post-processing worker {os.getpid()} has finished all outputs."
                )
                break

            # Process results from the queue if available
            if not results_queue.empty():
                try:
                    time_start = time.time()

                    result = results_queue.get_nowait()

                    processing_successful = self.process_outputs(
                        dataset_v51,
                        result,
                        self.object_classes,
                        self.detection_threshold,
                    )

                    # Performance logging
                    n_images = len(result["labels"])
                    time_end = time.time()
                    duration = time_end - time_start
                    batches_per_second = 1 / duration
                    frames_per_second = batches_per_second * n_images
                    n_processed_images += n_images
                    writer.add_scalar(
                        "post_processing/frames_per_second",
                        frames_per_second,
                        n_processed_images,
                    )

                    del result  # Explicit removal from device

                except Exception as e:
                    logging.error(f"Error in post-processing worker: {e}")
                    continue

            else:
                continue

        writer.close()
        wandb.finish(exit_code=0)
        return processing_successful  # Return last processing status

    def gpu_worker(
        self,
        gpu_id,
        cpu_cores,
        task_queue,
        results_queue,
        done_event,
        post_processing_finished,
        set_cpu_affinity=False,
    ):
        """Run model inference on specified GPU with dedicated CPU cores."""
        dataset_v51 = fo.load_dataset(
            self.dataset_name
        )  # NOTE Only for the case of sequential processing
        configure_logging()
        # Set CPU
        if set_cpu_affinity:
            # Allow only certain CPU cores
            psutil.Process().cpu_affinity(cpu_cores)
        logging.info(f"Available CPU cores: {psutil.Process().cpu_affinity()}")
        max_n_cpus = len(cpu_cores)
        torch.set_num_threads(max_n_cpus)

        # Set GPU
        logging.info(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
        device = torch.device(f"cuda:{gpu_id}")

        run_successful = None
        with torch.cuda.device(gpu_id):
            while True:
                if post_processing_finished.value and task_queue.empty():
                    # Keep alive until post-processing is done
                    break

                if task_queue.empty():
                    done_event.set()

                if not task_queue.empty():
                    try:
                        task_metadata = task_queue.get(
                            timeout=5
                        )  # Timeout to prevent indefinite blocking
                    except Exception as e:
                        break  # Exit if no more tasks
                    run_successful = self.model_inference(
                        task_metadata,
                        device,
                        self.dataset_torch,
                        dataset_v51,
                        self.object_classes,
                        results_queue,
                        self.tensorboard_root,
                    )
                    logging.info(
                        f"Worker for GPU {gpu_id} finished run successful: {run_successful}"
                    )
                else:
                    continue
        return run_successful  # Return last processing status

    def eval_and_export_worker(self, models_ready_queue, n_models):
        """Evaluate model performance and export results for completed models."""
        configure_logging()
        logging.info(f"Process ID: {os.getpid()}. Eval-and-export process started")

        dataset = fo.load_dataset(self.dataset_name)
        run_successful = None
        models_done = 0

        while True:
            if not models_ready_queue.empty():
                try:
                    message = models_ready_queue.get(
                        timeout=5
                    )  # Timeout to prevent indefinite blocking
                    model_name = message["model_name"]
                    pred_key = re.sub(r"[\W-]+", "_", "pred_zsod_" + model_name)
                    eval_key = re.sub(r"[\W-]+", "_", "eval_zsod_" + model_name)
                    dataset.reload()
                    run_successful = self.eval_and_export(
                        dataset, model_name, pred_key, eval_key
                    )
                    models_done += 1
                    logging.info(
                        f"Evaluation and export of {models_done}/{n_models} models done."
                    )
                except Exception as e:
                    logging.error(f"Error in eval-and-export worker: {e}")
                    continue

            if models_done == n_models:
                break

        return run_successful

    # Functionality functions
    def model_inference(
        self,
        metadata: dict,
        device: str,
        dataset: torch.utils.data.Dataset,
        dataset_v51: fo.Dataset,
        object_classes: list,
        results_queue: Union[queue.Queue, mp.Queue],
        root_log_dir: str,
        persistent_workers: bool = False,
    ):
        """Model inference method running zero-shot object detection on provided dataset and device, returning success status."""
        writer = None
        run_successful = True
        processor, model, inputs, outputs, result, dataloader = (
            None,
            None,
            None,
            None,
            None,
            None,
        )  # For finally block

        # Timeout handler
        dataloader_timeout = 60
        signal.signal(signal.SIGALRM, timeout_handler)

        try:
            # Metadata
            run_id = metadata["run_id"]
            model_name = metadata["model_name"]
            dataset_name = metadata["dataset_name"]
            is_subset = metadata["is_subset"]
            batch_size = metadata["batch_size"]

            logging.info(
                f"Process ID: {os.getpid()}, Run ID: {run_id}, Device: {device}, Model: {model_name}"
            )

            # Load the model
            logging.info(f"Loading model {model_name}")
            processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
            model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name)
            model = model.to(device, non_blocking=True)
            model.eval()
            hf_model_config = AutoConfig.from_pretrained(model_name)
            hf_model_config_name = type(hf_model_config).__name__
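            # Zero-shot detectors are prompted with text, so the list of class
            # names is replicated once per image in the batch.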
            batch_classes = [object_classes] * batch_size
            logging.info(f"Loaded model type {hf_model_config_name}")

            # Dataloader
            logging.info("Generating dataloader")
            if is_subset:
                chunk_index_start = metadata["chunk_index_start"]
                chunk_index_end = metadata["chunk_index_end"]
                logging.info(f"Length of dataset: {len(dataset)}")
                logging.info(f"Subset start index: {chunk_index_start}")
                logging.info(f"Subset stop index: {chunk_index_end}")
                dataset = Subset(dataset, range(chunk_index_start, chunk_index_end))

            zero_shot_inference_preprocessing = ZeroShotInferenceCollateFn(
                hf_model_config_name=hf_model_config_name,
                hf_processor=processor,
                object_classes=object_classes,
                batch_size=batch_size,
                batch_classes=batch_classes,
            )
            num_workers = WORKFLOWS["auto_labeling_zero_shot"]["n_worker_dataloader"]
            prefetch_factor = WORKFLOWS["auto_labeling_zero_shot"][
                "prefetch_factor_dataloader"
            ]
            dataloader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                persistent_workers=persistent_workers,
                pin_memory=True,
                prefetch_factor=prefetch_factor,
                collate_fn=zero_shot_inference_preprocessing,
            )

            dataloader_length = len(dataloader)
            if dataloader_length < 1:
                logging.error(
                    f"Dataloader has insufficient data: {dataloader_length} entries. Please check your dataset and DataLoader configuration."
                )

            # Logging
            experiment_name = f"{model_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{device}"
            log_directory = os.path.join(root_log_dir, dataset_name, experiment_name)
            wandb.tensorboard.patch(root_logdir=log_directory)
            if WANDB_ACTIVE:
                wandb.init(
                    name=f"{model_name}_{device}",
                    job_type="inference",
                    project="Zero Shot Object Detection",
                    config=metadata,
                )
            writer = SummaryWriter(log_dir=log_directory)

            # Inference Loop
            logging.info(f"{os.getpid()}: Starting inference loop")
            n_processed_images = 0
            for inputs, labels, target_sizes, batch_classes in tqdm(
                dataloader, desc="Inference Loop"
            ):
                signal.alarm(dataloader_timeout)
                try:
                    time_start = time.time()
                    n_images = len(labels)
                    inputs = inputs.to(device, non_blocking=True)

                    with torch.amp.autocast("cuda"), torch.inference_mode():
                        outputs = model(**inputs)

                    result = {
                        "inputs": inputs,
                        "outputs": outputs,
                        "processor": processor,
                        "target_sizes": target_sizes,
                        "labels": labels,
                        "model_name": model_name,
                        "hf_model_config_name": hf_model_config_name,
                        "batch_classes": batch_classes,
                    }

                    logging.debug(f"{os.getpid()}: Putting result into queue")

                    results_queue.put(
                        result, timeout=60
                    )  # Ditch data only after 60 seconds

                    # Logging
                    time_end = time.time()
                    duration = time_end - time_start
                    batches_per_second = 1 / duration
                    frames_per_second = batches_per_second * n_images
                    n_processed_images += n_images
                    logging.debug(
                        f"{os.getpid()}: Number of processed images: {n_processed_images}"
                    )
                    writer.add_scalar(
                        "inference/frames_per_second",
                        frames_per_second,
                        n_processed_images,
                    )

                except TimeoutException:
                    logging.warning(
                        "Dataloader loop got stuck. Continuing with next batch."
                    )
                    continue

                finally:
                    signal.alarm(0)  # Cancel the alarm

            # Flawless execution
            wandb_exit_code = 0

        except Exception as e:
            wandb_exit_code = 1
            run_successful = False
            logging.error(f"Error in Process {os.getpid()}: {e}")
        finally:
            try:
                wandb.finish(exit_code=wandb_exit_code)
            except Exception:
                pass

            # Explicit removal from device
            del (
                processor,
                model,
                inputs,
                outputs,
                result,
                dataloader,
            )

            torch.cuda.empty_cache()
            wandb.tensorboard.unpatch()
            if writer:
                writer.close()
            return run_successful

    def process_outputs(self, dataset_v51, result, object_classes, detection_threshold):
        """Process outputs from object detection models, extracting bounding boxes and labels to save to the dataset."""
        processing_successful = True  # Default; returned by the finally block even if no detections are produced
        try:
            inputs = result["inputs"]
            outputs = result["outputs"]
            target_sizes = result["target_sizes"]
            labels = result["labels"]
            model_name = result["model_name"]
            hf_model_config_name = result["hf_model_config_name"]
            batch_classes = result["batch_classes"]
            processor = result["processor"]

            # Processing output
            if hf_model_config_name == "GroundingDinoConfig":
                results = processor.post_process_grounded_object_detection(
                    outputs,
                    inputs.input_ids,
                    box_threshold=detection_threshold,
                    text_threshold=detection_threshold,
                )
            elif hf_model_config_name in ["Owlv2Config", "OwlViTConfig"]:
                results = processor.post_process_grounded_object_detection(
                    outputs=outputs,
                    threshold=detection_threshold,
                    target_sizes=target_sizes,
                    text_labels=batch_classes,
                )
            elif hf_model_config_name == "OmDetTurboConfig":
                results = processor.post_process_grounded_object_detection(
                    outputs,
                    text_labels=batch_classes,
                    threshold=detection_threshold,
                    nms_threshold=detection_threshold,
                    target_sizes=target_sizes,
                )
            else:
                logging.error(f"Invalid model name: {hf_model_config_name}")

            if not len(results) == len(target_sizes) == len(labels):
                logging.error(
                    f"Lengths of results, target_sizes, and labels do not match: {len(results)}, {len(target_sizes)}, {len(labels)}"
                )
            for result, size, target in zip(results, target_sizes, labels):
                boxes, scores, labels = (
                    result["boxes"],
                    result["scores"],
                    result["text_labels"],
                )

                img_height = size[0]
                img_width = size[1]

                detections = []
                for box, score, label in zip(boxes, scores, labels):
                    processing_successful = True
                    if hf_model_config_name == "GroundingDinoConfig":
                        # Outputs do not comply with given labels
                        # Grounding DINO outputs multiple pairs of object boxes and noun phrases for a given (Image, Text) pair
                        # There can be either multiple labels per output ("bike van"), incomplete ones ("motorcyc"), or broken ones ("##cic")
                        processed_label = label.split()[
                            0
                        ]  # Assume first output is the best output
                        if processed_label in object_classes:
                            label = processed_label
                            top_left_x = box[0].item()
                            top_left_y = box[1].item()
                            box_width = (box[2] - box[0]).item()
                            box_height = (box[3] - box[1]).item()
                        else:
                            matches = get_close_matches(
                                processed_label, object_classes, n=1, cutoff=0.6
                            )
                            selected_label = matches[0] if matches else None
                            if selected_label:
                                logging.debug(
                                    f"Mapped output '{processed_label}' to class '{selected_label}'"
                                )
                                label = selected_label
                                top_left_x = box[0].item()
                                top_left_y = box[1].item()
                                box_width = (box[2] - box[0]).item()
                                box_height = (box[3] - box[1]).item()
                            else:
                                logging.debug(
                                    f"Skipped detection with {hf_model_config_name} due to unclear output: {label}"
                                )
                                processing_successful = False

                    elif hf_model_config_name in [
                        "Owlv2Config",
                        "OwlViTConfig",
                        "OmDetTurboConfig",
                    ]:
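                        # These models return absolute (x1, y1, x2, y2) pixel boxes;
                        # convert to the relative [top-left-x, top-left-y, width, height]
                        # format in [0, 1] that Voxel51 expects.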
                        top_left_x = box[0].item() / img_width
                        top_left_y = box[1].item() / img_height
                        box_width = (box[2].item() - box[0].item()) / img_width
                        box_height = (box[3].item() - box[1].item()) / img_height

                    if (
                        processing_successful
                    ):  # Skip GroundingDinoConfig labels that could not be processed
                        detection = fo.Detection(
                            label=label,
                            bounding_box=[
                                top_left_x,
                                top_left_y,
                                box_width,
                                box_height,
                            ],
                            confidence=score.item(),
                        )
                        detection["bbox_area"] = (
                            detection["bounding_box"][2] * detection["bounding_box"][3]
                        )
                        detections.append(detection)

                # Attach label to V51 dataset
                pred_key = re.sub(
                    r"[\W-]+", "_", "pred_zsod_" + model_name
                )  # zsod: Zero-Shot Object Detection
                sample = dataset_v51[target["image_id"]]
                sample[pred_key] = fo.Detections(detections=detections)
                sample.save()

        except Exception as e:
            logging.error(f"Error in processing outputs: {e}")
            processing_successful = False
        finally:
            return processing_successful

    def eval_and_export(self, dataset_v51, model_name, pred_key, eval_key):
        """Populate dataset with evaluation results (if ground_truth available)"""
        try:
            dataset_v51.evaluate_detections(
                pred_key,
                gt_field="ground_truth",
                eval_key=eval_key,
                compute_mAP=True,
            )
        except Exception as e:
            logging.warning(f"Evaluation not possible: {e}")

        # Store labels https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.export
        model_name_key = re.sub(r"[\W-]+", "_", model_name)
        dataset_v51.export(
            export_dir=os.path.join(self.detections_root, model_name_key),
            dataset_type=fo.types.COCODetectionDataset,
            data_path="data.json",
            export_media=None,  # "manifest",
            label_field=pred_key,
            progress=True,
        )
        return True


class UltralyticsObjectDetection:
    """Object detection using Ultralytics YOLO models with training and inference support."""

    def __init__(self, dataset, config):
        """Initialize with dataset, config, and setup paths for model and data."""
        self.dataset = dataset
        self.config = config
        self.ultralytics_data_path = os.path.join(
            config["export_dataset_root"], config["v51_dataset_name"]
        )

        self.hf_hub_model_id = (
            f"{HF_ROOT}/"
            + f"{config['v51_dataset_name']}_{config['model_name']}".replace("/", "_")
        )

        self.export_root = "output/models/ultralytics/"
        self.export_folder = os.path.join(
            self.export_root, self.config["v51_dataset_name"]
        )

        self.model_path = os.path.join(
            self.export_folder, self.config["model_name"], "weights", "best.pt"
        )

    @staticmethod
    def export_data(
        dataset, dataset_info, export_dataset_root, label_field="ground_truth"
    ):
        """Export dataset to YOLO format for Ultralytics training."""
        ultralytics_data_path = os.path.join(export_dataset_root, dataset_info["name"])
        # Delete export directory if it already exists
        if os.path.exists(ultralytics_data_path):
            shutil.rmtree(ultralytics_data_path)

        logging.info("Exporting data for training with Ultralytics")
        classes = dataset.distinct(f"{label_field}.detections.label")

        # Make directory
        os.makedirs(ultralytics_data_path, exist_ok=False)

        for split in ACCEPTED_SPLITS:
            split_view = dataset.match_tags(split)

            if split == "val" or split == "train":  # YOLO expects train and val
                split_view.export(
                    export_dir=ultralytics_data_path,
                    dataset_type=fo.types.YOLOv5Dataset,
                    label_field=label_field,
                    classes=classes,
                    split=split,
                )

    def train(self):
        """Train the YOLO model for object detection using Ultralytics and optionally upload to Hugging Face."""
        model = YOLO(self.config["model_name"], task="detect")
        # https://docs.ultralytics.com/modes/train/#train-settings

        # Use all available GPUs
        device = "0"  # Default to GPU 0
        if torch.cuda.device_count() > 1:
            device = ",".join(map(str, range(torch.cuda.device_count())))
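        # e.g. device = "0,1,2,3" on a machine with four visible CUDA devices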

        results = model.train(
            data=f"{self.ultralytics_data_path}/dataset.yaml",
            epochs=self.config["epochs"],
            project=self.export_folder,
            name=self.config["model_name"],
            patience=self.config["patience"],
            batch=self.config["batch_size"],
            imgsz=self.config["img_size"],
            multi_scale=self.config["multi_scale"],
            cos_lr=self.config["cos_lr"],
            seed=GLOBAL_SEED,
            optimizer="AdamW",  # "auto" as default
            pretrained=True,
            exist_ok=True,
            amp=True,
            device=device,
        )
        metrics = model.val()
        logging.info(f"Model Performance: {metrics}")

        # Upload model to Hugging Face
        if HF_DO_UPLOAD:
            logging.info(f"Uploading model {self.model_path} to Hugging Face.")
            api = HfApi()
            api.create_repo(
                self.hf_hub_model_id, private=True, repo_type="model", exist_ok=True
            )
            api.upload_file(
                path_or_fileobj=self.model_path,
                path_in_repo="best.pt",
                repo_id=self.hf_hub_model_id,
                repo_type="model",
            )

    def inference(self, gt_field="ground_truth"):
        """Performs inference using the YOLO model on a dataset, with options to evaluate results."""
        logging.info(f"Running inference on dataset {self.config['v51_dataset_name']}")
        inference_settings = self.config["inference_settings"]

        dataset_name = None
        model_name = self.config["model_name"]

        model_hf = inference_settings["model_hf"]
        if model_hf is not None:
            # Use model manually defined in config.
            # This way, models trained on a different dataset can be used for inference.
            dataset_name, _ = get_dataset_and_model_from_hf_id(model_hf)

            # Set up directories
            download_dir = os.path.join(
                self.export_root, dataset_name, model_name, "weights"
            )
            os.makedirs(download_dir, exist_ok=True)

            self.model_path = os.path.join(download_dir, "best.pt")

            file_path = hf_hub_download(
                repo_id=model_hf,
                filename="best.pt",
                local_dir=download_dir,
            )
        else:
            # Automatically determine model based on dataset
            dataset_name = self.config["v51_dataset_name"]

            try:
                if os.path.exists(self.model_path):
                    file_path = self.model_path
                    logging.info(f"Loading model {model_name} from disk: {file_path}")
                else:
                    download_dir = self.model_path.replace("best.pt", "")
                    os.makedirs(download_dir, exist_ok=True)
                    logging.info(
                        f"Downloading model {self.hf_hub_model_id} from Hugging Face to {download_dir}"
                    )
                    file_path = hf_hub_download(
                        repo_id=self.hf_hub_model_id,
                        filename="best.pt",
                        local_dir=download_dir,
                    )
            except Exception as e:
                logging.error(f"Failed to load or download model: {str(e)}.")
                return False

        pred_key = f"pred_od_{model_name}-{dataset_name}"
        logging.info(f"Using model {self.model_path} for inference.")
        model = YOLO(self.model_path)

        detection_threshold = inference_settings["detection_threshold"]
        if inference_settings["inference_on_test"] is True:
            dataset_eval_view = self.dataset.match_tags("test")
            if len(dataset_eval_view) == 0:
                logging.error("Dataset is missing split 'test'")
            dataset_eval_view.apply_model(
                model, label_field=pred_key, confidence_thresh=detection_threshold
            )
        else:
            self.dataset.apply_model(
                model, label_field=pred_key, confidence_thresh=detection_threshold
            )

        if inference_settings["do_eval"]:
            eval_key = f"eval_{self.config['model_name']}_{dataset_name}"

            if inference_settings["inference_on_test"] is True:
                dataset_view = self.dataset.match_tags(["test"])
            else:
                dataset_view = self.dataset

            results = dataset_view.evaluate_detections(
                pred_key,
                gt_field=gt_field,
                eval_key=eval_key,
                compute_mAP=True,
            )

            results.print_report()


def transform_batch_standalone(
    batch,
    image_processor,
    do_convert_annotations=True,
    return_pixel_mask=False,
):
    """Format annotations in COCO style for the object detection task. Defined outside the class so it can be pickled."""
    images = []
    annotations = []

    for image_path, annotation in zip(batch["image_path"], batch["objects"]):
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)
        images.append(image_np)

        coco_annotations = []
        for i, bbox in enumerate(annotation["bbox"]):
            # Conversion from HF dataset bounding boxes to DETR:
            # Input: HF dataset bbox is COCO (top_left_x, top_left_y, width, height) in absolute coordinates
            # Output:
            # DETR expects COCO (top_left_x, top_left_y, width, height) in absolute coordinates if 'do_convert_annotations == True'
            # DETR expects YOLO (center_x, center_y, width, height) in relative coordinates between [0,1] if 'do_convert_annotations == False'
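            # Worked example with made-up numbers: COCO bbox [10, 20, 50, 40] in a
            # 200x100 (width x height) image becomes center_x = (10 + 50 / 2) / 200 = 0.175,
            # center_y = (20 + 40 / 2) / 100 = 0.4, width = 50 / 200 = 0.25, height = 40 / 100 = 0.4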

            if not do_convert_annotations:
                x, y, w, h = bbox
                img_height, img_width = image_np.shape[:2]
                center_x = (x + w / 2) / img_width
                center_y = (y + h / 2) / img_height
                width = w / img_width
                height = h / img_height
                bbox = [center_x, center_y, width, height]

                # Ensure bbox values are within the expected range
                assert all(0 <= coord <= 1 for coord in bbox), f"Invalid bbox: {bbox}"

                logging.debug(
                    f"Converted {[x, y, w, h]} to {[center_x, center_y, width, height]} with 'do_convert_annotations' = {do_convert_annotations}"
                )

            coco_annotation = {
                "image_id": annotation["image_id"],
                "bbox": bbox,
                "category_id": annotation["category_id"][i],
                "area": annotation["area"][i],
                "iscrowd": 0,
            }
            coco_annotations.append(coco_annotation)
        detr_annotation = {
            "image_id": annotation["image_id"],
            "annotations": coco_annotations,
        }
        annotations.append(detr_annotation)

    # Apply the image processor once to the full batch: resizing, rescaling, normalization
    result = image_processor(
        images=images, annotations=annotations, return_tensors="pt"
    )

    if not return_pixel_mask:
        result.pop("pixel_mask", None)

    return result


class HuggingFaceObjectDetection:
    """Object detection using HuggingFace models with support for training and inference."""

    def __init__(
        self,
        dataset,
        config,
        output_model_path="./output/models/object_detection_hf",
        output_detections_path="./output/detections/",
        gt_field="ground_truth",
    ):
        """Initialize with dataset, config, and optional output paths."""
        self.dataset = dataset
        self.config = config
        self.model_name = config["model_name"]
        self.model_name_key = re.sub(r"[\W-]+", "_", self.model_name)
        self.dataset_name = config["v51_dataset_name"]
        # When True, the HF image processor converts COCO boxes (top_left_x, top_left_y, width, height)
        # in absolute coordinates to (center_x, center_y, width, height) normalized to [0, 1],
        # as expected by DETR-style models:
        # https://github.com/huggingface/transformers/blob/v4.48.2/src/transformers/models/conditional_detr/image_processing_conditional_detr.py#L1497
        self.do_convert_annotations = True

        self.detections_root = os.path.join(
            output_detections_path, self.dataset_name, self.model_name_key
        )

        self.model_root = os.path.join(
            output_model_path, self.dataset_name, self.model_name_key
        )

        self.hf_hub_model_id = (
            f"{HF_ROOT}/" + f"{self.dataset_name}_{self.model_name}".replace("/", "_")
        )

        self.categories = dataset.distinct(f"{gt_field}.detections.label")
        self.id2label = {index: x for index, x in enumerate(self.categories, start=0)}
        self.label2id = {v: k for k, v in self.id2label.items()}
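        # e.g. for hypothetical categories ["car", "person"]:
        # id2label = {0: "car", 1: "person"} and label2id = {"car": 0, "person": 1}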

    def collate_fn(self, batch):
        """Collate function for batching data during training and inference."""
        data = {}
        data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
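        # "labels" stays a Python list (one annotation dict per image) because
        # the number of boxes varies per image and cannot be stacked into a tensor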
        data["labels"] = [x["labels"] for x in batch]
        if "pixel_mask" in batch[0]:
            data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
        return data

    def train(self, hf_dataset, overwrite_output=True):
        """Train models for object detection tasks with support for custom image sizes and transformations."""
        torch.cuda.empty_cache()
        img_size_target = self.config.get("image_size", None)
        if img_size_target is None:
            image_processor = AutoProcessor.from_pretrained(
                self.model_name,
                do_resize=False,
                do_pad=True,
                use_fast=True,
                do_convert_annotations=self.do_convert_annotations,
            )
        else:
            logging.warning(f"Resizing images to target size {img_size_target}.")
            image_processor = AutoProcessor.from_pretrained(
                self.model_name,
                do_resize=True,
                size={
                    "max_height": img_size_target[1],
                    "max_width": img_size_target[0],
                },
                do_pad=True,
                pad_size={"height": img_size_target[1], "width": img_size_target[0]},
                use_fast=True,
                do_convert_annotations=self.do_convert_annotations,
            )

        train_transform_batch = partial(
            transform_batch_standalone,
            image_processor=image_processor,
            do_convert_annotations=self.do_convert_annotations,
        )
        val_test_transform_batch = partial(
            transform_batch_standalone,
            image_processor=image_processor,
            do_convert_annotations=self.do_convert_annotations,
        )

        hf_dataset[Split.TRAIN] = hf_dataset[Split.TRAIN].with_transform(
            train_transform_batch
        )
        hf_dataset[Split.VALIDATION] = hf_dataset[Split.VALIDATION].with_transform(
            val_test_transform_batch
        )
        hf_dataset[Split.TEST] = hf_dataset[Split.TEST].with_transform(
            val_test_transform_batch
        )

        hf_model_config = AutoConfig.from_pretrained(self.model_name)
        hf_model_config_name = type(hf_model_config).__name__

        if type(hf_model_config) in AutoModelForObjectDetection._model_mapping:
            model = AutoModelForObjectDetection.from_pretrained(
                self.model_name,
                id2label=self.id2label,
                label2id=self.label2id,
                ignore_mismatched_sizes=True,
            )
        else:
            model = None
            logging.error(
                "Hugging Face AutoModel does not support " + str(type(hf_model_config))
            )

        if (
            overwrite_output
            and os.path.exists(self.model_root)
            and os.listdir(self.model_root)
        ):
            logging.warning(
                f"Training will overwrite existing results in {self.model_root}"
            )

        training_args = TrainingArguments(
            run_name=self.model_name,
            output_dir=self.model_root,
            overwrite_output_dir=overwrite_output,
            num_train_epochs=self.config["epochs"],
            fp16=True,
            per_device_train_batch_size=self.config["batch_size"],
            auto_find_batch_size=True,
            dataloader_num_workers=min(self.config["n_worker_dataloader"], NUM_WORKERS),
            learning_rate=self.config["learning_rate"],
            lr_scheduler_type="cosine",
            weight_decay=self.config["weight_decay"],
            max_grad_norm=self.config["max_grad_norm"],
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            load_best_model_at_end=True,
            eval_strategy="epoch",
            save_strategy="best",
            save_total_limit=1,
            remove_unused_columns=False,
            eval_do_concat_batches=False,
            save_safetensors=False,  # Does not work with all models
            hub_model_id=self.hf_hub_model_id,
            hub_private_repo=True,
            push_to_hub=HF_DO_UPLOAD,
            seed=GLOBAL_SEED,
            data_seed=GLOBAL_SEED,
        )

        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=self.config["early_stop_patience"],
            early_stopping_threshold=self.config["early_stop_threshold"],
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=hf_dataset[Split.TRAIN],
            eval_dataset=hf_dataset[Split.VALIDATION],
            tokenizer=image_processor,
            data_collator=self.collate_fn,
            callbacks=[early_stopping_callback],
            # compute_metrics=eval_compute_metrics_fn,
        )

        logging.info(f"Starting training of model {self.model_name}.")
        trainer.train()
        if HF_DO_UPLOAD:
            trainer.push_to_hub()

        metrics = trainer.evaluate(eval_dataset=hf_dataset[Split.TEST])
        logging.info(f"Model training completed. Evaluation results: {metrics}")

    def inference(self, inference_settings, load_from_hf=True, gt_field="ground_truth"):
        """Performs model inference on a dataset, loading from Hugging Face or disk, and optionally evaluates detection results."""

        model_hf = inference_settings["model_hf"]
        dataset_name = None
        if model_hf is not None:
            self.hf_hub_model_id = model_hf
            dataset_name, model_name = get_dataset_and_model_from_hf_id(model_hf)
        else:
            dataset_name = self.dataset_name
        torch.cuda.empty_cache()
        # Load trained model from Hugging Face
        load_from_hf_successful = None
        if load_from_hf:
            try:
                logging.info(f"Loading model from Hugging Face: {self.hf_hub_model_id}")
                image_processor = AutoProcessor.from_pretrained(self.hf_hub_model_id)
                model = AutoModelForObjectDetection.from_pretrained(
                    self.hf_hub_model_id
                )
                load_from_hf_successful = True
            except Exception as e:
                load_from_hf_successful = False
                logging.warning(
                    f"Model {self.model_name} could not be loaded from Hugging Face {self.hf_hub_model_id}. Attempting loading from disk."
                )
        if not load_from_hf or load_from_hf_successful is False:
            try:
                # Select folders in self.model_root that include 'checkpoint-'
                checkpoint_dirs = [
                    d
                    for d in os.listdir(self.model_root)
                    if "checkpoint-" in d
                    and os.path.isdir(os.path.join(self.model_root, d))
                ]

                if not checkpoint_dirs:
                    logging.error(
                        f"No checkpoint directory found in {self.model_root}!"
                    )
                    model_path = None
                else:
                    # Sort by modification time (latest first)
                    checkpoint_dirs.sort(
                        key=lambda d: os.path.getmtime(
                            os.path.join(self.model_root, d)
                        ),
                        reverse=True,
                    )

                    if len(checkpoint_dirs) > 1:
                        logging.warning(
                            f"Multiple checkpoint directories found: {checkpoint_dirs}. Selecting the latest one: {checkpoint_dirs[0]}."
                        )

                    selected_checkpoint = checkpoint_dirs[0]
                    logging.info(
                        f"Loading model from disk: {self.model_root}/{selected_checkpoint}"
                    )
                    model_path = os.path.join(self.model_root, selected_checkpoint)

                image_processor = AutoProcessor.from_pretrained(model_path)
                model = AutoModelForObjectDetection.from_pretrained(model_path)
            except Exception as e:
                logging.error(
                    f"Model {self.model_name} could not be loaded from {self.model_root}: {e}. Inference not possible."
                )
1297
1298        device, _, _ = get_backend()
1299        logging.info(f"Using device {device} for inference.")
1300        model = model.to(device)
1301        model.eval()
1302
1303        pred_key = f"pred_od_{self.model_name_key}-{dataset_name}"
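        # Illustrative example (names are assumptions): with model_name_key
        # "detr_resnet_50" and dataset "mcity_fisheye", predictions land in the
        # Voxel51 field "pred_od_detr_resnet_50-mcity_fisheye".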
1304
1305        if inference_settings["inference_on_test"] is True:
1306            INFERENCE_SPLITS = ["test"]
1307            dataset_eval_view = self.dataset.match_tags(INFERENCE_SPLITS)
1308        else:
1309            dataset_eval_view = self.dataset
1310
1311        detection_threshold = inference_settings["detection_threshold"]
1312
1313        with torch.amp.autocast("cuda"), torch.inference_mode():
1314            for sample in dataset_eval_view.iter_samples(progress=True, autosave=True):
1315                image_width = sample.metadata.width
1316                image_height = sample.metadata.height
1317                img_filepath = sample.filepath
1318
1319                image = Image.open(img_filepath)
1320                inputs = image_processor(images=[image], return_tensors="pt")
1321                outputs = model(**inputs.to(device))
1322                target_sizes = torch.tensor([[image.size[1], image.size[0]]])
1323
1324                results = image_processor.post_process_object_detection(
1325                    outputs, threshold=detection_threshold, target_sizes=target_sizes
1326                )[0]
1327
1328                detections = []
1329                for score, label, box in zip(
1330                    results["scores"], results["labels"], results["boxes"]
1331                ):
1332                    # Bbox is in absolute pixel coordinates [x1, y1, x2, y2]
1333                    box = box.tolist()
1334                    text_label = model.config.id2label[label.item()]
1335
1336                    # Voxel51 requires relative coordinates between 0 and 1
1337                    top_left_x = box[0] / image_width
1338                    top_left_y = box[1] / image_height
1339                    box_width = (box[2] - box[0]) / image_width
1340                    box_height = (box[3] - box[1]) / image_height
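                    # Worked example: on a 640x480 image, box [64, 48, 192, 144]
                    # maps to [64/640, 48/480, 128/640, 96/480] = [0.1, 0.1, 0.2, 0.2].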
1341                    detection = fo.Detection(
1342                        label=text_label,
1343                        bounding_box=[
1344                            top_left_x,
1345                            top_left_y,
1346                            box_width,
1347                            box_height,
1348                        ],
1349                        confidence=score.item(),
1350                    )
1351                    detections.append(detection)
1352
1353                sample[pred_key] = fo.Detections(detections=detections)
1354
1355        if inference_settings["do_eval"] is True:
1356            eval_key = re.sub(
1357                r"[\W-]+", "_", "eval_" + self.model_name + "_" + self.dataset_name
1358            )
1359
1360            if inference_settings["inference_on_test"] is True:
1361                dataset_view = self.dataset.match_tags(["test"])
1362            else:
1363                dataset_view = self.dataset
1364
1365            results = dataset_view.evaluate_detections(
1366                pred_key,
1367                gt_field=gt_field,
1368                eval_key=eval_key,
1369                compute_mAP=True,
1370            )
1371
1372            results.print_report()
1373
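# Minimal usage sketch (hypothetical instance name and settings; the keys mirror
# the lookups in inference() above):
#
#   trainer.inference(
#       inference_settings={
#           "model_hf": None,            # fall back to self.hf_hub_model_id
#           "inference_on_test": True,
#           "detection_threshold": 0.2,
#           "do_eval": True,
#       },
#       load_from_hf=True,
#   )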
1374
1375class CustomCoDETRObjectDetection:
1376    """Interface for running Co-DETR object detection model training and inference in containers"""
1377
1378    def __init__(self, dataset, dataset_info, run_config):
1379        """Initialize Co-DETR interface with dataset and configuration"""
1380        self.root_codetr = "./custom_models/CoDETR/Co-DETR"
1381        self.root_codetr_models = "output/models/codetr"
1382        self.dataset = dataset
1383        self.dataset_name = dataset_info["name"]
1384        self.export_dir_root = run_config["export_dataset_root"]
1385        self.config_key = os.path.splitext(os.path.basename(run_config["config"]))[0]
1386        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"
1387
1388    def convert_data(self):
1389        """Convert dataset to COCO format required by Co-DETR"""
1390
1391        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "coco")
1392
1393        # Check if folder already exists
1394        if not os.path.exists(export_dir):
1395            # Make directory
1396            os.makedirs(export_dir, exist_ok=True)
1397            logging.info(f"Exporting data to {export_dir}")
1398            splits = [
1399                "train",
1400                "val",
1401                "test",
1402            ]  # Co-DETR expects data in 'train2017', 'val2017', and 'test2017' folders
1403            for split in splits:
1404                split_view = self.dataset.match_tags(split)
1405                split_view.export(
1406                    dataset_type=fo.types.COCODetectionDataset,
1407                    data_path=os.path.join(export_dir, f"{split}2017"),
1408                    labels_path=os.path.join(
1409                        export_dir, "annotations", f"instances_{split}2017.json"
1410                    ),
1411                    label_field="ground_truth",
1412                )
1413        else:
1414            logging.warning(
1415                f"Folder {export_dir} already exists, skipping data export."
1416            )
1417
1418    def update_config_file(self, dataset_name, config_file, max_epochs):
1419        """Update Co-DETR config file with dataset-specific parameters"""
1420
1421        config_path = os.path.join(self.root_codetr, config_file)
1422
1423        # Get classes from exported data
1424        annotations_json = os.path.join(
1425            self.export_dir_root,
1426            dataset_name,
1427            "coco/annotations/instances_train2017.json",
1428        )
1429        # Read the JSON file
1430        with open(annotations_json, "r") as file:
1431            data = json.load(file)
1432
1433        # Extract the value associated with the key "categories"
1434        categories = data.get("categories")
1435        class_names = tuple(category["name"] for category in categories)
1436        num_classes = len(class_names)
1437
1438        # Update configuration file
1439        # This assumes that a tuple like classes = ('a', 'b', ...) is already defined and will be overwritten.
1440        with open(config_path, "r") as file:
1441            content = file.read()
1442
1443        # Update the classes tuple
1444        content = re.sub(r"classes\s*=\s*\(.*?\)", f"classes = {class_names}", content)
1445
1446        # Update all instances of num_classes
1447        content = re.sub(r"num_classes=\d+", f"num_classes={num_classes}", content)
1448
1449        # Update all instances of max_epochs
1450        content = re.sub(r"max_epochs=\d+", f"max_epochs={max_epochs}", content)
1451
1452        with open(config_path, "w") as file:
1453            file.write(content)
1454
1455        logging.warning(
1456            f"Updated {config_path} with classes={class_names} and num_classes={num_classes} and max_epochs={max_epochs}"
1457        )
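        # Illustrative before/after of the substitutions above (assumed config snippet):
        #   classes = ('car', 'truck')  ->  classes = ('car', 'truck', 'bus')
        #   num_classes=2               ->  num_classes=3
        #   max_epochs=12               ->  max_epochs=<max_epochs>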
1458
1459    def train(self, param_config, param_n_gpus, container_tool, param_function="train"):
1460        """Train Co-DETR model using containerized environment"""
1461
1462        # Check if model already exists
1463        output_folder_codetr = os.path.join(self.root_codetr, "output")
1464        os.makedirs(output_folder_codetr, exist_ok=True)
1465        param_config_name = os.path.splitext(os.path.basename(param_config))[0]
1466        best_models_dir = os.path.join(output_folder_codetr, "best")
1467        os.makedirs(best_models_dir, exist_ok=True)
1468        # Best model files follow the naming scheme "config_dataset.pth"
1469        pth_model_files = (
1470            [f for f in os.listdir(best_models_dir) if f.endswith(".pth")]
1471            if os.path.exists(best_models_dir) and os.path.isdir(best_models_dir)
1472            else []
1473        )
1474
1476        matching_files = [
1477            f
1478            for f in pth_model_files
1479            if f.startswith(param_config_name)
1480            and self.dataset_name in f
1481            and f.endswith(".pth")
1482        ]
1483        if len(matching_files) > 0:
1484            logging.warning(
1485                f"Model {param_config_name} already trained on dataset {self.dataset_name}. Skipping training."
1486            )
1487            if len(matching_files) > 1:
1488                logging.warning(f"Multiple weights found: {matching_files}")
1489        else:
1490            logging.info(
1491                f"Launching training for Co-DETR config {param_config} and dataset {self.dataset_name}."
1492            )
1493            volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1494
1495            # Train model, store checkpoints in 'output_folder_codetr'
1496            train_result = self._run_container(
1497                volume_data=volume_data,
1498                param_function=param_function,
1499                param_config=param_config,
1500                param_n_gpus=param_n_gpus,
1501                container_tool=container_tool,
1502            )
1503
1504            # Find the best_bbox checkpoint file
1505            checkpoint_files = [
1506                f
1507                for f in os.listdir(output_folder_codetr)
1508                if "best_bbox" in f and f.endswith(".pth")
1509            ]
1510            if not checkpoint_files:
1511                logging.error(
1512                    "Co-DETR was not trained, model pth file missing. No checkpoint file with 'best_bbox' found."
1513                )
1514            else:
1515                if len(checkpoint_files) > 1:
1516                    logging.warning(
1517                        f"Found {len(checkpoint_files)} checkpoint files. Selecting {checkpoint_files[0]}."
1518                    )
1519                checkpoint = checkpoint_files[0]
1520                checkpoint_path = os.path.join(output_folder_codetr, checkpoint)
1521                logging.info("Co-DETR was trained successfully.")
1522
1523                # Upload best model to Hugging Face
1524                if HF_DO_UPLOAD:
1525                    logging.info("Uploading Co-DETR model to Hugging Face.")
1526                    api = HfApi()
1527                    api.create_repo(
1528                        self.hf_repo_name,
1529                        private=True,
1530                        repo_type="model",
1531                        exist_ok=True,
1532                    )
1533                    api.upload_file(
1534                        path_or_fileobj=checkpoint_path,
1535                        path_in_repo="model.pth",
1536                        repo_id=self.hf_repo_name,
1537                        repo_type="model",
1538                    )
1539
1540                # Move best model file and clear output folder
1541                self._run_container(
1542                    volume_data=volume_data,
1543                    param_function="clear-output",
1544                    param_config=param_config,
1545                    param_dataset_name=self.dataset_name,
1546                    container_tool=container_tool,
1547                )
1548
1549    @staticmethod
1550    def _find_file_iteratively(start_path, filename):
1551        """Direct access or recursively search for a file in a directory structure."""
1552        # Convert start_path to a Path object
1553        start_path = Path(start_path)
1554
1555        # Check if the file exists in the start_path directly (very fast)
1556        file_path = start_path / filename
1557        if file_path.exists():
1558            return str(file_path)
1559
1560        # Start from the given directory and walk up the tree iteratively
1561        current_dir = start_path
1562        checked_dirs = set()
1563
1564        while current_dir != current_dir.parent:  # Path vs. str never compared equal before; stop at the filesystem root
1565            # Check if the file is in the current directory
1566            file_path = current_dir / filename
1567            if file_path.exists():
1568                return str(file_path)
1569
1570            # If we haven't checked the sibling directories, check them as well
1571            parent_dir = current_dir.parent
1572            if parent_dir not in checked_dirs:
1573                # Check sibling directories
1574                for sibling in parent_dir.iterdir():
1575                    if sibling != current_dir and sibling.is_dir():
1576                        sibling_file_path = sibling / filename
1577                        if sibling_file_path.exists():
1578                            return str(sibling_file_path)
1579                checked_dirs.add(parent_dir)
1580
1581            # Otherwise, go one level up
1582            current_dir = current_dir.parent
1583
1584        # If file is not found after traversing all levels, return None
1585        logging.error(f"File {filename} could not be found.")
1586        return None
1587
1588    def run_inference(
1589        self,
1590        dataset,
1591        param_config,
1592        param_n_gpus,
1593        container_tool,
1594        inference_settings,
1595        param_function="inference",
1596        inference_output_folder="custom_models/CoDETR/Co-DETR/output/inference/",
1597        gt_field="ground_truth",
1598    ):
1599        """Run inference using trained Co-DETR model and convert results to FiftyOne format"""
1600
1601        logging.info(f"Launching inference for Co-DETR config {param_config}.")
1602        volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1603
1604        if inference_settings["inference_on_test"] is True:
1605            folder_inference = os.path.join("coco", "test2017")
1606        else:
1607            folder_inference = os.path.join("coco")
1608
1609        # Get model from Hugging Face
1610        dataset_name = None
1611        config_key = None
1612        try:
1613            if inference_settings["model_hf"] is None:
1614                hf_path = self.hf_repo_name
1615            else:
1616                hf_path = inference_settings["model_hf"]
1617
1618            dataset_name, config_key = get_dataset_and_model_from_hf_id(hf_path)
1619
1620            download_folder = os.path.join(
1621                self.root_codetr_models, dataset_name, config_key
1622            )
1623
1624            logging.info(
1625                f"Downloading model {hf_path} from Hugging Face into {download_folder}"
1626            )
1627            os.makedirs(download_folder, exist_ok=True)
1628
1629            file_path = hf_hub_download(
1630                repo_id=hf_path,
1631                filename="model.pth",
1632                local_dir=download_folder,
1633            )
1634        except Exception as e:
1635            logging.error(f"An error occurred during model download: {e}")
1636            return  # dataset_name/config_key may be unset; cannot build model_path

1637        model_path = os.path.join(dataset_name, config_key, "model.pth")
1638        logging.info(f"Starting inference for model {model_path}")
1639
1640        inference_result = self._run_container(
1641            volume_data=volume_data,
1642            param_function=param_function,
1643            param_config=param_config,
1644            param_n_gpus=param_n_gpus,
1645            container_tool=container_tool,
1646            param_inference_dataset_folder=folder_inference,
1647            param_inference_model_checkpoint=model_path,
1648        )
1649
1650        # Convert results from JSON output into V51 dataset
1651        # Files follow format inference_results_{timestamp}.json (run_inference.py)
1652        os.makedirs(inference_output_folder, exist_ok=True)
1653        output_files = [
1654            f
1655            for f in os.listdir(inference_output_folder)
1656            if f.startswith("inference_results_") and f.endswith(".json")
1657        ]
1658        logging.debug(f"Found files with inference content: {output_files}")
1659
1660        if not output_files:
1661            logging.error(
1662                f"No inference result files found in {inference_output_folder}"
1663            )
1664            return  # Indexing into an empty list below would raise

1665        # Get full path for each file
1666        file_paths = [os.path.join(inference_output_folder, f) for f in output_files]
1667
1668        # Extract timestamp from the filename and sort based on the timestamp
1669        file_paths_sorted = sorted(
1670            file_paths,
1671            key=lambda f: datetime.datetime.strptime(
1672                f.split("_")[-2] + "_" + f.split("_")[-1].replace(".json", ""),
1673                "%Y%m%d_%H%M%S",
1674            ),
1675            reverse=True,
1676        )
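        # Example: ".../inference_results_20250101_120000.json" splits into
        # "20250101" + "120000" and parses as 2025-01-01 12:00:00, so the most
        # recent run sorts first.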
1677
1678        # Use the most recent file based on timestamp
1679        latest_file = file_paths_sorted[0]
1680        logging.info(f"Using inference results from: {latest_file}")
1681        with open(latest_file, "r") as file:
1682            data = json.load(file)
1683
1684        # Get conversion for annotated classes
1685        annotations_path = os.path.join(
1686            volume_data, "coco", "annotations", "instances_train2017.json"
1687        )
1688
1689        with open(annotations_path, "r") as file:
1690            data_annotations = json.load(file)
1691
1692        class_ids_and_names = [
1693            (category["id"], category["name"])
1694            for category in data_annotations["categories"]
1695        ]
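        # Illustrative result (ids depend on the export): [(1, "car"), (2, "truck")].
        # Below, entries are looked up by list position, not by COCO category id.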
1696
1697        # Match sample filepaths (from exported Co-DETR COCO format) to V51 filepaths
1698        sample = dataset.first()
1699        root_dir_samples = sample.filepath
1700
1701        # Convert results into V51 file format
1702        detection_threshold = inference_settings["detection_threshold"]
1703        pred_key = f"pred_od_{config_key}-{dataset_name}"
1704        for key, value in tqdm(data.items(), desc="Processing Co-DETR detection"):
1705            try:
1706                # Get filename
1707                filepath = CustomCoDETRObjectDetection._find_file_iteratively(
1708                    root_dir_samples, os.path.basename(key)
1709                )
1710                sample = dataset[filepath]
1711
1712                img_width = sample.metadata.width
1713                img_height = sample.metadata.height
1714
1715                detections_v51 = []
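                # Each per-class entry is assumed to be a list of
                # [x1, y1, x2, y2, score] detections in absolute pixel
                # coordinates, matching the indexing below.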
1716                for class_id, class_detections in enumerate(data[key]):  # Starts with 0
1717                    if len(class_detections) > 0:
1718                        objects_class = class_ids_and_names[class_id]
1719                        for detection in class_detections:
1720                            confidence = detection[4]
1721                            detection_v51 = fo.Detection(
1722                                label=objects_class[1],
1723                                bounding_box=[
1724                                    detection[0] / img_width,
1725                                    detection[1] / img_height,
1726                                    (detection[2] - detection[0]) / img_width,
1727                                    (detection[3] - detection[1]) / img_height,
1728                                ],
1729                                confidence=confidence,
1730                            )
1731                            if confidence >= detection_threshold:
1732                                detections_v51.append(detection_v51)
1733
1734                sample[pred_key] = fo.Detections(detections=detections_v51)
1735                sample.save()
1736            except Exception as e:
1737                logging.error(
1738                    f"An error occured during the conversion of Co-DETR inference results to the V51 dataset: {e}"
1739                )
1740
1741        # Run V51 evaluation
1742        if inference_settings["do_eval"] is True:
1743            eval_key = pred_key.replace("pred_", "eval_").replace("-", "_")
1744
1745            if inference_settings["inference_on_test"] is True:
1746                dataset_view = dataset.match_tags(["test"])
1747            else:
1748                dataset_view = dataset
1749
1750            logging.info(
1751                f"Starting evaluation for {pred_key} in evaluation key {eval_key}."
1752            )
1753
1754            results = dataset_view.evaluate_detections(
1755                pred_key,
1756                gt_field=gt_field,
1757                eval_key=eval_key,
1758                compute_mAP=True,
1759            )
1760
1761            results.print_report()
1762
1763    def _run_container(
1764        self,
1765        volume_data,
1766        param_function,
1767        param_config="",
1768        param_n_gpus="1",
1769        param_dataset_name="",
1770        param_inference_dataset_folder="",
1771        param_inference_model_checkpoint="",
1772        image="dbogdollresearch/codetr",
1773        workdir="/launch",
1774        container_tool="docker",
1775    ):
1776        """Execute Co-DETR container with specified parameters using Docker or Singularity"""
1777
1778        try:
1779            # Convert relative paths to absolute paths (necessary under WSL2)
1780            root_codetr_abs = os.path.abspath(self.root_codetr)
1781            volume_data_abs = os.path.abspath(volume_data)
1782            root_codetr_models_abs = os.path.abspath(self.root_codetr_models)
1783
1784            # Check if using Docker or Singularity and define the appropriate command
1785            if container_tool == "docker":
1786                command = [
1787                    "docker",
1788                    "run",
1789                    "--gpus",
1790                    "all",
1791                    "--workdir",
1792                    workdir,
1793                    "--volume",
1794                    f"{root_codetr_abs}:{workdir}",
1795                    "--volume",
1796                    f"{volume_data_abs}:{workdir}/data:ro",
1797                    "--volume",
1798                    f"{root_codetr_models_abs}:{workdir}/hf_models:ro",
1799                    "--shm-size=8g",
1800                    image,
1801                    param_function,
1802                    param_config,
1803                    param_n_gpus,
1804                    param_dataset_name,
1805                    param_inference_dataset_folder,
1806                    param_inference_model_checkpoint,
1807                ]
1808            elif container_tool == "singularity":
1809                command = [
1810                    "singularity",
1811                    "run",
1812                    "--nv",
1813                    "--pwd",
1814                    workdir,
1815                    "--bind",
1816                    f"{self.root_codetr}:{workdir}",
1817                    "--bind",
1818                    f"{volume_data}:{workdir}/data:ro",
1819                    "--bind",
1820                    f"{self.root_codetr_models}:{workdir}/hf_models:ro",
1821                    f"docker://{image}",
1822                    param_function,
1823                    param_config,
1824                    param_n_gpus,
1825                    param_dataset_name,
1826                    param_inference_dataset_folder,
1827                    param_inference_model_checkpoint,
1828                ]
1829            else:
1830                raise ValueError(
1831                    f"Invalid container tool specified: {container_tool}. Choose 'docker' or 'singularity'."
1832                )
1833
1834            # Start the process and stream outputs to the console
1835            logging.info(f"Launching terminal command {command}")
1836            with subprocess.Popen(
1837                command, stdout=sys.stdout, stderr=sys.stderr, text=True
1838            ) as proc:
1839                proc.wait()  # Wait for the process to complete
1840            return proc.returncode == 0  # Success only if the container exited cleanly
1841        except Exception as e:
1842            logging.error(f"Error during Co-DETR container run: {e}")
1843            return False
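    # For reference, the Docker branch above expands to roughly (paths assumed):
    #   docker run --gpus all --workdir /launch \
    #     --volume <abs Co-DETR>:/launch \
    #     --volume <abs data>:/launch/data:ro \
    #     --volume <abs models>:/launch/hf_models:ro \
    #     --shm-size=8g dbogdollresearch/codetr train <config> <n_gpus> ...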
1844
1845
1846class CustomRFDETRObjectDetection:
1847    """Interface for running RF-DETR object detection model training and inference"""
1848
1849    def __init__(self, dataset, dataset_info, run_config):
1850        """Initialize RF-DETR interface with dataset and configuration"""
1851        self.dataset = dataset
1852        self.dataset_name = dataset_info["name"]
1853        self.export_dir_root = run_config["export_dataset_root"]
1854        self.config_key = os.path.splitext(os.path.basename(run_config.get("config", "rfdetr")))[0]
1855        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"
1856
1857    def convert_data(self):
1858        """
1859        Convert dataset to RF-DETR COCO format with zero-indexed categories.
1860        Automatically creates missing splits by splitting val or test 50/50.
1861
1862        Expected output structure:
1863        dsname/rfdetr/
1864            test/
1865                _annotations.coco.json
1866                images...
1867            train/
1868                _annotations.coco.json
1869                images...
1870            valid/
1871                _annotations.coco.json
1872                images...
1873        """
1874        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
1875
1876        # Check if folder already exists
1877        if os.path.exists(export_dir):
1878            logging.warning(
1879                f"Folder {export_dir} already exists, skipping data export."
1880            )
1881            return
1882
1883        # Make directory
1884        os.makedirs(export_dir, exist_ok=True)
1885        logging.info(f"Exporting data to {export_dir}")
1886
1887        # Check what splits exist
1888        available_tags = self.dataset.distinct("tags")
1889        has_train = "train" in available_tags
1890        has_val = "val" in available_tags
1891        has_test = "test" in available_tags
1892
1893        logging.info(f"Available splits in dataset: {available_tags}")
1894
1895        # Handle missing splits - split val or test 50/50
1896        if has_train and has_val and not has_test:
1897            logging.info("No test split found. Splitting val 50/50 into valid and test...")
1898            self._split_50_50("val", "test")
1899        elif has_train and has_test and not has_val:
1900            logging.info("No val split found. Splitting test 50/50 into valid and test...")
1901            self._split_50_50("test", "val")
1902        elif not has_train or (not has_val and not has_test):
1903            logging.error(
1904                f"Dataset must have 'train' and at least one of 'val' or 'test'. "
1905                f"Found: {available_tags}"
1906            )
1907            raise ValueError("Insufficient splits in dataset")
1908
1909        # RF-DETR expects 'train', 'valid', 'test' splits
1910        split_mapping = {
1911            "train": "train",
1912            "val": "valid",  # Map 'val' to 'valid' for RF-DETR
1913            "test": "test"
1914        }
1915
1916        for v51_split, rfdetr_split in split_mapping.items():
1917            split_view = self.dataset.match_tags(v51_split)
1918
1919            if len(split_view) == 0:
1920                logging.warning(f"No samples found for split '{v51_split}', skipping.")
1921                continue
1922
1923            split_export_dir = os.path.join(export_dir, rfdetr_split)
1924            os.makedirs(split_export_dir, exist_ok=True)
1925
1926            # Export to COCO format
1927            annotation_path = os.path.join(split_export_dir, "_annotations.coco.json")
1928
1929            logging.info(f"Exporting {len(split_view)} samples to {rfdetr_split}/")
1930
1931            split_view.export(
1932                dataset_type=fo.types.COCODetectionDataset,
1933                data_path=split_export_dir,
1934                labels_path=annotation_path,
1935                label_field="ground_truth",
1936            )
1937
1938            # Fix category IDs: Convert from 1-indexed to 0-indexed
1939            self._fix_annotation_indices(annotation_path)
1940
1941        logging.info(f"Successfully exported dataset to RF-DETR format at {export_dir}")
1942
1943    def _split_50_50(self, source_split, target_split):
1944        """
1945        Split a dataset split 50/50 into two splits.
1946
1947        Args:
1948            source_split: The split to divide (e.g., "val" or "test")
1949            target_split: The new split to create (e.g., "test" or "val")
1950
1951        Example:
1952            - 1000 val samples → 500 val + 500 test
1953            - 1000 test samples → 500 val + 500 test
1954        """
1955        source_samples = self.dataset.match_tags(source_split)
1956        source_ids = source_samples.values("id")
1957
1958        if len(source_ids) < 2:
1959            logging.error(
1960                f"Not enough samples in '{source_split}' to split. "
1961                f"Need at least 2, found {len(source_ids)}"
1962            )
1963            raise ValueError(f"Insufficient samples in {source_split} split")
1964
1965        # Shuffle for random split
1966        random.seed(GLOBAL_SEED)  # Deterministic shuffle across runs
1967        random.shuffle(source_ids)
1968
1969        # Split 50/50
1970        split_point = len(source_ids) // 2
1971        keep_in_source = source_ids[:split_point]
1972        move_to_target = source_ids[split_point:]
1973
1974        logging.info(
1975            f"Splitting {len(source_ids)} '{source_split}' samples: "
1976            f"{len(keep_in_source)} remain in '{source_split}', "
1977            f"{len(move_to_target)} moved to '{target_split}'"
1978        )
1979
1980        # Move samples to target split
1981        for sample_id in move_to_target:
1982            sample = self.dataset[sample_id]
1983            sample.tags.remove(source_split)
1984            sample.tags.append(target_split)
1985            sample.save()
1986
1987        self.dataset.save()
1988        logging.info(f"Successfully created '{target_split}' split from '{source_split}'")
1989
1990    def _fix_annotation_indices(self, annotation_path):
1991        """
1992        Fix COCO annotation file to use zero-indexed category IDs.
1993
1994        Args:
1995            annotation_path: Path to the _annotations.coco.json file
1996        """
1997        if not os.path.exists(annotation_path):
1998            logging.error(f"Annotation file not found: {annotation_path}")
1999            return
2000
2001        try:
2002            # Create backup
2003            backup_path = f"{annotation_path}.backup"
2004            if not os.path.exists(backup_path):
2005                shutil.copy2(annotation_path, backup_path)
2006                logging.debug(f"Created backup: {backup_path}")
2007
2008            # Read annotation file
2009            with open(annotation_path, 'r') as f:
2010                data = json.load(f)
2011
2012            # Fix categories: 1-indexed → 0-indexed
2013            if 'categories' in data:
2014                for cat in data['categories']:
2015                    if cat['id'] > 0:
2016                        cat['id'] -= 1
2017                logging.debug(f"Fixed {len(data['categories'])} category IDs")
2018
2019            # Fix annotations: 1-indexed → 0-indexed
2020            if 'annotations' in data:
2021                for ann in data['annotations']:
2022                    if ann['category_id'] > 0:
2023                        ann['category_id'] -= 1
2024                logging.debug(f"Fixed {len(data['annotations'])} annotation category IDs")
2025
2026            # Save fixed file
2027            with open(annotation_path, 'w') as f:
2028                json.dump(data, f, indent=2)
2029
2030            logging.info(f"Successfully fixed indices in: {annotation_path}")
2031
2032        except Exception as e:
2033            logging.error(f"Error fixing annotation indices in {annotation_path}: {e}")
2034            # Restore from backup if something went wrong
2035            backup_path = f"{annotation_path}.backup"
2036            if os.path.exists(backup_path):
2037                shutil.copy2(backup_path, annotation_path)
2038                logging.info(f"Restored from backup due to error")
2039
2040    def train(self, run_config, shared_config):
2041        """Train RF-DETR model using shared and model-specific configuration"""
2042
2046        # Model selection mapping
2047        MODEL_REGISTRY = {
2048            "rfdetr_nano": RFDETRNano,
2049            "rfdetr_small": RFDETRSmall,
2050            "rfdetr_medium": RFDETRMedium,
2051            "rfdetr_large": RFDETRLarge,
2052        }
2053
2054        model_name = self.config_key.lower()
2055
2056        if model_name not in MODEL_REGISTRY:
2057            logging.error(
2058                f"Model '{model_name}' not supported. "
2059                f"Available models: {list(MODEL_REGISTRY.keys())}"
2060            )
2061            raise ValueError(f"Unsupported RF-DETR model: {model_name}")
2062
2063        # Initialize model
2064        logging.info(f"Initializing {model_name}...")
2065        ModelClass = MODEL_REGISTRY[model_name]
2066        model = ModelClass()
2067
2068        # Prepare dataset directory
2069        dataset_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
2070
2071        if not os.path.exists(dataset_dir):
2072            logging.error(f"Dataset directory not found: {dataset_dir}")
2073            logging.info("Please run convert_data() first to prepare the dataset.")
2074            raise FileNotFoundError(f"Dataset not found at {dataset_dir}")
2075
2076        # Output directory
2077        output_dir = os.path.join("output/models/rfdetr", self.dataset_name, model_name)
2078        os.makedirs(output_dir, exist_ok=True)
2079
2080        # Build training arguments
2081        train_kwargs = {
2082            "dataset_dir": dataset_dir,
2083            "output_dir": output_dir,
2084
2085            # === SHARED parameters from top-level config ===
2086            "epochs": shared_config.get("epochs", 50),
2087            "lr": shared_config.get("learning_rate", 1e-4),
2088            "weight_decay": shared_config.get("weight_decay", 0.0001),
2089
2090            # === RF-DETR specific parameters ===
2091            "batch_size": run_config.get("batch_size", 16),
2092            "grad_accum_steps": run_config.get("grad_accum_steps", 1),
2093            "lr_encoder": run_config.get("lr_encoder", None),
2094            "resolution": run_config.get("resolution", None),
2095            "use_ema": run_config.get("use_ema", True),
2096            "gradient_checkpointing": run_config.get("gradient_checkpointing", False),
2097
2098            # === Logging (use global settings) ===
2099            "tensorboard": True,
2100            "wandb": WANDB_ACTIVE,
2101            "project": f"MCityDataEngine-RFDETR",
2102            "run": f"{self.dataset_name}_{model_name}",
2103
2104            # === Early stopping ===
2105            "early_stopping": True,
2106            "early_stopping_patience": shared_config.get("early_stop_patience", 10),
2107            "early_stopping_min_delta": run_config.get(
2108                "early_stopping_min_delta",
2109                shared_config.get("early_stop_threshold", 0.001)
2110            ),
2111            "early_stopping_use_ema": run_config.get("early_stopping_use_ema", True),
2112        }
2113
2114        # Set device - use all available GPUs
2115        if torch.cuda.is_available():
2116            if torch.cuda.device_count() > 1:
2117                device = "cuda"
2118                logging.info(f"Using {torch.cuda.device_count()} GPUs for training")
2119            else:
2120                device = "cuda:0"
2121                logging.info("Using single GPU for training")
2122        else:
2123            device = "cpu"
2124            logging.warning("No GPU available, training on CPU")
2125
2126        train_kwargs["device"] = device
2127
2128        # Remove None values
2129        train_kwargs = {k: v for k, v in train_kwargs.items() if v is not None}
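        # Example: with lr_encoder and resolution left at None above, those keys
        # are dropped here and RF-DETR's library defaults apply instead.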
2130
2131        # Log configuration
2132        logging.info("="*70)
2133        logging.info("RF-DETR TRAINING CONFIGURATION")
2134        logging.info("="*70)
2135        logging.info(f"Model: {model_name}")
2136        logging.info(f"Dataset: {dataset_dir}")
2137        logging.info(f"Output: {output_dir}")
2138        logging.info(f"WandB Active: {WANDB_ACTIVE}")
2139        for key, value in train_kwargs.items():
2140            if key not in ["dataset_dir", "output_dir"]:
2141                logging.info(f"  {key}: {value}")
2142        logging.info("="*70)
2143
2144        # Train
2145        try:
2146            logging.info("Starting RF-DETR training...")
2147            model.train(**train_kwargs)
2148            logging.info("RF-DETR training completed successfully!")
2149
2150            # Model paths to check (RF-DETR can save in different locations)
2151            possible_model_paths = [
2152                os.path.join(output_dir, "checkpoints", "best.pt"),
2153                os.path.join(output_dir, "checkpoint_best_total.pth"),
2154                os.path.join(output_dir, "best.pt"),
2155            ]
2156
2157            self.model_path = None
2158            for path in possible_model_paths:
2159                if os.path.exists(path):
2160                    self.model_path = path
2161                    logging.info(f"Found trained model at: {path}")
2162                    break
2163
2164            if self.model_path is None:
2165                logging.warning("Could not find trained model file in expected locations")
2166                self.model_path = possible_model_paths[0]  # Default to first path
2167
2168            # Upload to Hugging Face if configured
2169            if HF_DO_UPLOAD:
2170                self._upload_to_hf()
2171
2172            return True
2173
2174        except Exception as e:
2175            logging.error(f"❌ Training failed: {e}")
2176            import traceback
2177            traceback.print_exc()
2178            return False
2179
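    # Minimal usage sketch (hypothetical config dicts; keys mirror the .get()
    # lookups above):
    #   rfdetr = CustomRFDETRObjectDetection(dataset, dataset_info, run_config)
    #   rfdetr.convert_data()
    #   rfdetr.train(run_config={"batch_size": 8}, shared_config={"epochs": 50})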
2180
2181    def _upload_to_hf(self):
2182        """Upload trained RF-DETR model to Hugging Face"""
2183
2184        if not os.path.exists(self.model_path):
2185            logging.warning(f"Model file not found at {self.model_path}, skipping upload.")
2186            return
2187
2188        try:
2189            logging.info(f"Uploading RF-DETR model to Hugging Face: {self.hf_repo_name}")
2190            api = HfApi()
2191
2192            # Create repository
2193            api.create_repo(
2194                self.hf_repo_name,
2195                private=True,
2196                repo_type="model",
2197                exist_ok=True
2198            )
2199
2200            # Upload model file
2201            api.upload_file(
2202                path_or_fileobj=self.model_path,
2203                path_in_repo="best.pt",
2204                repo_id=self.hf_repo_name,
2205                repo_type="model",
2206            )
2207
2208            logging.info(f"Model uploaded successfully to {self.hf_repo_name}")
2209
2210        except Exception as e:
2211            logging.error(f"Failed to upload model to Hugging Face: {e}")
2212            import traceback
2213            traceback.print_exc()
2214
2215
2216    def inference(self, inference_settings, gt_field="ground_truth"):
2217        """Performs inference using RF-DETR model on a dataset with optional evaluation"""
2218
2221        logging.info(f"Running RF-DETR inference on dataset {self.dataset_name}")
2222
2223        # Model selection mapping
2224        MODEL_REGISTRY = {
2225            "rfdetr_nano": RFDETRNano,
2226            "rfdetr_small": RFDETRSmall,
2227            "rfdetr_medium": RFDETRMedium,
2228            "rfdetr_large": RFDETRLarge,
2229        }
2230
2231        # Determine model and dataset names
2232        dataset_name = None
2233        model_name = self.config_key.lower()
2234
2235        model_hf = inference_settings.get("model_hf", None)
2236
2237        # Determine model path
2238        if model_hf is not None:
2239            # Use model from Hugging Face
2240            logging.info(f"Using model from Hugging Face: {model_hf}")
2241            dataset_name, model_name = get_dataset_and_model_from_hf_id(model_hf)
2242
2243            # Set up directories
2244            download_dir = os.path.join(
2245                "output/models/rfdetr", dataset_name, model_name
2246            )
2247            os.makedirs(download_dir, exist_ok=True)
2248
2249            # Download model from Hugging Face
2250            try:
2251                logging.info(f"Downloading model from Hugging Face: {model_hf}")
2252                model_path = hf_hub_download(
2253                    repo_id=model_hf,
2254                    filename="best.pt",
2255                    local_dir=download_dir,
2256                )
2257            except Exception as e:
2258                logging.error(f"Failed to download model from Hugging Face: {e}")
2259                return False
2260        else:
2261            # Use locally trained model
2262            dataset_name = self.dataset_name
2263
2264            # Check multiple possible locations
2265            possible_paths = [
2266                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoints", "best.pt"),
2267                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoint_best_total.pth"),
2268                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "best.pt"),
2269            ]
2270
2271            model_path = None
2272            for path in possible_paths:
2273                if os.path.exists(path):
2274                    model_path = path
2275                    logging.info(f"Found model at: {path}")
2276                    break
2277
2278            if model_path is None:
2279                # Try downloading from auto-generated HF repo
2280                logging.info(f"Local model not found. Attempting to download from {self.hf_repo_name}")
2281                download_dir = os.path.join(
2282                    "output/models/rfdetr", self.dataset_name, model_name
2283                )
2284                os.makedirs(download_dir, exist_ok=True)
2285
2286                try:
2287                    model_path = hf_hub_download(
2288                        repo_id=self.hf_repo_name,
2289                        filename="best.pt",
2290                        local_dir=download_dir,
2291                    )
2292                except Exception as e:
2293                    logging.error(f"Failed to load or download model: {e}")
2294                    return False
2295
2296        # Check if model exists
2297        if not os.path.exists(model_path):
2298            logging.error(f"Model file not found: {model_path}")
2299            return False
2300
2301        logging.info(f"Using model: {model_path}")
2302
2303        # Initialize model
2304        if model_name not in MODEL_REGISTRY:
2305            logging.error(f"Model '{model_name}' not supported.")
2306            return False
2307
2308        ModelClass = MODEL_REGISTRY[model_name]
2309
2310        # Get class names from dataset
2311        try:
2312            class_names = self.dataset.distinct(f"{gt_field}.detections.label")
2313            class_names = sorted(class_names)
2314            num_classes = len(class_names)
2315            logging.info(f"Found {num_classes} classes: {class_names}")
2316        except Exception as e:
2317            logging.warning(f"Could not extract class names from dataset: {e}")
2318            num_classes = 8  # Fallback; should match the number of classes the model was trained with
2319            class_names = None
2320
2321        # Load model with trained weights
2322        try:
2323            logging.info("Loading RF-DETR model...")
2324            model = ModelClass(
2325                pretrain_weights=model_path,
2326                num_classes=num_classes
2327            )
2328
2329            logging.info("RF-DETR model loaded successfully")
2330        except Exception as e:
2331            logging.error(f"Failed to load model: {e}")
2332            return False
2333
2334        # Prepare dataset view
2335        detection_threshold = inference_settings.get("detection_threshold", 0.2)
2336
2337        if inference_settings.get("inference_on_test", True):
2338            INFERENCE_SPLITS = ["test"]
2339            dataset_view = self.dataset.match_tags(INFERENCE_SPLITS)
2340            if len(dataset_view) == 0:
2341                logging.error(f"Dataset has no splits: {INFERENCE_SPLITS}")
2342                return False
2343        else:
2344            dataset_view = self.dataset
2345
2346        # Prediction key
2347        pred_key = f"pred_od_{model_name}-{dataset_name}"
2348
2349        logging.info(f"Running inference on {len(dataset_view)} samples...")
2350        logging.info(f"Detection threshold: {detection_threshold}")
2351
2352        # Run inference on each sample
2353        try:
2354            processed_count = 0
2355
2356            for sample in tqdm(dataset_view.iter_samples(progress=False, autosave=True),  # avoid double progress bars
2357                               total=len(dataset_view),
2358                               desc="RF-DETR Inference"):
2359
2360                try:
2361                    # Load image
2362                    image = Image.open(sample.filepath)
2363                    img_width, img_height = image.size
2364
2365                    # Run inference using RF-DETR's predict method
2366                    detections = model.predict(
2367                        image,
2368                        threshold=detection_threshold
2369                    )
2370
2371                    # Convert supervision detections to FiftyOne format
2372                    fo_detections = []
2373
2374                    if len(detections) > 0:
2375                        for i in range(len(detections)):
2376                            # Get detection data (RF-DETR returns supervision format)
2377                            bbox = detections.xyxy[i]  # [x1, y1, x2, y2] in pixel coordinates
2378                            confidence = detections.confidence[i] if detections.confidence is not None else 1.0
2379                            class_id = detections.class_id[i] if detections.class_id is not None else 0
2380
2381                            # Convert to relative coordinates [x, y, width, height]
2382                            x1, y1, x2, y2 = bbox
2383                            rel_x = x1 / img_width
2384                            rel_y = y1 / img_height
2385                            rel_w = (x2 - x1) / img_width
2386                            rel_h = (y2 - y1) / img_height
2387
2388                            # Get class name
2389                            if class_names and class_id < len(class_names):
2390                                class_name = class_names[class_id]
2391                            else:
2392                                class_name = f"class_{class_id}"
2393
2394                            # Create FiftyOne detection
2395                            fo_detection = fo.Detection(
2396                                label=class_name,
2397                                bounding_box=[rel_x, rel_y, rel_w, rel_h],
2398                                confidence=float(confidence)
2399                            )
2400                            fo_detections.append(fo_detection)
2401
2402                    # Save detections to sample
2403                    sample[pred_key] = fo.Detections(detections=fo_detections)
2404                    processed_count += 1
2405
2406                except Exception as e:
2407                    logging.error(f"Error processing sample {sample.id}: {e}")
2408                    continue
2409
2410            logging.info(f"Inference completed on {processed_count}/{len(dataset_view)} samples")
2411            logging.info(f"Predictions saved to field '{pred_key}'")
2412
2413        except Exception as e:
2414            logging.error(f"Error during inference: {e}")
2415            import traceback
2416            traceback.print_exc()
2417            return False
2418
2419        # Evaluate if requested
2420        if inference_settings.get("do_eval", True):
2421            eval_key = f"eval_{model_name}_{dataset_name}".replace("-", "_")
2422
2423            if inference_settings.get("inference_on_test", True):
2424                dataset_view = self.dataset.match_tags(["test"])
2425            else:
2426                dataset_view = self.dataset
2427
2428            # Filter samples that have both predictions and ground truth
2429            dataset_view = dataset_view.exists(pred_key).exists(gt_field)
2430
2431            if len(dataset_view) == 0:
2432                logging.warning("No samples found with both predictions and ground truth for evaluation")
2433            else:
2434                try:
2435                    logging.info(f"Evaluating predictions on {len(dataset_view)} samples...")
2436
2437                    results = dataset_view.evaluate_detections(
2438                        pred_key,
2439                        gt_field=gt_field,
2440                        eval_key=eval_key,
2441                        compute_mAP=True,
2442                        iou=0.5  # IoU threshold for matching
2443                    )
2444
2445                    # Print evaluation report
2446                    logging.info("="*70)
2447                    logging.info("EVALUATION RESULTS")
2448                    logging.info("="*70)
2449                    results.print_report()
2450                    logging.info("="*70)
2451
2452                    logging.info("Evaluation completed")
2453                except Exception as e:
2454                    logging.error(f"Evaluation failed: {e}")
2455                    import traceback
2456                    traceback.print_exc()
2457
2458        return True
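    # Minimal usage sketch (assumed settings dict; keys mirror the .get()
    # lookups above):
    #   rfdetr.inference({"model_hf": None, "inference_on_test": True,
    #                     "detection_threshold": 0.2, "do_eval": True})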
class ZeroShotInferenceCollateFn:
105class ZeroShotInferenceCollateFn:
106    """Collate function for zero-shot inference that prepares batches for model input."""
107
108    def __init__(
109        self,
110        hf_model_config_name,
111        hf_processor,
112        batch_size,
113        object_classes,
114        batch_classes,
115    ):
116        """Initialize the auto labeling model with the Hugging Face model config, processor, batch size, object classes, and batch classes."""
117        try:
118            self.hf_model_config_name = hf_model_config_name
119            self.processor = hf_processor
120            self.batch_size = batch_size
121            self.object_classes = object_classes
122            self.batch_classes = batch_classes
123        except Exception as e:
124            logging.error(f"Error in collate init of DataLoader: {e}")
125
126    def __call__(self, batch):
127        """Processes a batch of data by preparing images and labels for model input."""
128        try:
129            images, labels = zip(*batch)
130            target_sizes = [tuple(img.shape[1:]) for img in images]
131
132            # Adjustments for final batch
133            n_images = len(images)
134            if n_images < self.batch_size:
135                self.batch_classes = [self.object_classes] * n_images
136
137            # Apply PIL transformation for specific models
138            if self.hf_model_config_name == "OmDetTurboConfig":
139                images = [to_pil_image(image) for image in images]
140
141            inputs = self.processor(
142                text=self.batch_classes,
143                images=images,
144                return_tensors="pt",
145                padding=True,  # Allow for differently sized images
146            )
147
148            return inputs, labels, target_sizes, self.batch_classes
149        except Exception as e:
150            logging.error(f"Error in collate function of DataLoader: {e}")

Collate function for zero-shot inference that prepares batches for model input.

ZeroShotInferenceCollateFn( hf_model_config_name, hf_processor, batch_size, object_classes, batch_classes)
108    def __init__(
109        self,
110        hf_model_config_name,
111        hf_processor,
112        batch_size,
113        object_classes,
114        batch_classes,
115    ):
116        """Initialize the auto labeling model with the Hugging Face model config, processor, batch size, object classes, and batch classes."""
117        try:
118            self.hf_model_config_name = hf_model_config_name
119            self.processor = hf_processor
120            self.batch_size = batch_size
121            self.object_classes = object_classes
122            self.batch_classes = batch_classes
123        except Exception as e:
124            logging.error(f"Error in collate init of DataLoader: {e}")

Initialize the auto labeling model with the Hugging Face model config, processor, batch size, object classes, and batch classes.

class ZeroShotObjectDetection:
153class ZeroShotObjectDetection:
154    """Zero-shot object detection using various HuggingFace models with multi-GPU support."""
155
156    def __init__(
157        self,
158        dataset_torch: torch.utils.data.Dataset,
159        dataset_info,
160        config,
161        detections_path="./output/detections/",
162        log_root="./logs/",
163    ):
164        """Initialize the zero-shot object detection labeler with dataset, configuration, and path settings."""
165        self.dataset_torch = dataset_torch
166        self.dataset_info = dataset_info
167        self.dataset_name = dataset_info["name"]
168        self.object_classes = config["object_classes"]
169        self.detection_threshold = config["detection_threshold"]
170        self.detections_root = os.path.join(detections_path, self.dataset_name)
171        self.tensorboard_root = os.path.join(
172            log_root, "tensorboard/zeroshot_object_detection"
173        )
174
175        logging.info(f"Zero-shot models will look for {self.object_classes}")
176
177    def exclude_stored_predictions(
178        self, dataset_v51: fo.Dataset, config, do_exclude=False
179    ):
180        """Checks for existing predictions and loads them from disk if available."""
181        dataset_schema = dataset_v51.get_field_schema()
182        models_splits_dict = {}
183        for model_name, value in config["hf_models_zeroshot_objectdetection"].items():
184            model_name_key = re.sub(r"[\W-]+", "_", model_name)
185            pred_key = re.sub(
186                r"[\W-]+", "_", "pred_zsod_" + model_name
187            )  # od for Object Detection
188            # Check if data already stored in V51 dataset
189            if pred_key in dataset_schema and do_exclude is True:
190                logging.warning(
191                    f"Skipping model {model_name}. Predictions already stored in Voxel51 dataset."
192                )
193            # Check if data already stored on disk
194            elif (
195                os.path.isdir(os.path.join(self.detections_root, model_name_key))
196                and do_exclude is True
197            ):
198                try:
199                    logging.info(f"Loading {model_name} predictions from disk.")
200                    temp_dataset = fo.Dataset.from_dir(
201                        dataset_dir=os.path.join(self.detections_root, model_name_key),
202                        dataset_type=fo.types.COCODetectionDataset,
203                        name="temp_dataset",
204                        data_path="data.json",
205                    )
206
207                    # Copy all detections from stored dataset into our dataset
208                    detections = temp_dataset.values("detections.detections")
209                    add_sample_field(
210                        dataset_v51,
211                        pred_key,
212                        fo.EmbeddedDocumentField,
213                        embedded_doc_type=fo.Detections,
214                    )
215                    dataset_v51.set_values(f"{pred_key}.detections", detections)
216                except Exception as e:
217                    logging.error(
218                        f"Data in {os.path.join(self.detections_root, model_name_key)} could not be loaded. Error: {e}"
219                    )
220                finally:  # Clean up the temporary dataset if it was created
221                    if fo.dataset_exists("temp_dataset"): fo.delete_dataset("temp_dataset")
222            # Assign model to be run
223            else:
224                models_splits_dict[model_name] = value
225
226        logging.info(f"Models to be run: {models_splits_dict}")
227        return models_splits_dict
228
229    # Worker functions
230    def update_queue_sizes_worker(
231        self, queues, queue_sizes, largest_queue_index, max_queue_size
232    ):
233        """Monitor and manage multiple result queues for balanced processing."""
234        experiment_name = f"queue_size_monitor_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
235        log_directory = os.path.join(
236            self.tensorboard_root, self.dataset_name, experiment_name
237        )
238        wandb.tensorboard.patch(root_logdir=log_directory)
239        if WANDB_ACTIVE:
240            wandb.init(
241                name=f"queue_size_monitor_{os.getpid()}",
242                job_type="inference",
243                project="Zero Shot Object Detection",
244            )
245        writer = SummaryWriter(log_dir=log_directory)
246
247        step = 0
248
249        while True:
250            for i, q in enumerate(queues):  # 'q' avoids shadowing the 'queue' module
251                queue_sizes[i] = q.qsize()
252                writer.add_scalar(f"queue_size/items/{i}", queue_sizes[i], step)
253
254            step += 1
255
256            # Find the index of the largest queue
257            max_size = max(queue_sizes)
258            max_index = queue_sizes.index(max_size)
259
260            # Calculate the total size of all queues
261            total_size = sum(queue_sizes)
262
263            # If total_size is greater than 0, calculate the probabilities
264            if total_size > 0:
265                # Normalize the queue sizes by the max_queue_size
266                normalized_sizes = [size / max_queue_size for size in queue_sizes]
267
268                # Calculate probabilities based on normalized sizes
269                probabilities = [
270                    size / sum(normalized_sizes) for size in normalized_sizes
271                ]
272
273                # Use random.choices with weights (probabilities)
274                chosen_queue_index = random.choices(
275                    range(len(queues)), weights=probabilities, k=1
276                )[0]
277
278                largest_queue_index.value = chosen_queue_index
279            else:
280                largest_queue_index.value = max_index
281
282            time.sleep(0.1)
283
284    def process_outputs_worker(
285        self,
286        result_queues,
287        largest_queue_index,
288        inference_finished,
289        max_queue_size,
290        wandb_activate=False,
291    ):
292        """Process model outputs from result queues and save to dataset."""
293        configure_logging()
294        logging.info(f"Process ID: {os.getpid()}. Post-processing worker started")
295        dataset_v51 = fo.load_dataset(self.dataset_name)
296        processing_successful = None
297
298        # Logging
299        experiment_name = f"post_process_{os.getpid()}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
300        log_directory = os.path.join(
301            self.tensorboard_root, self.dataset_name, experiment_name
302        )
303        wandb.tensorboard.patch(root_logdir=log_directory)
304        if WANDB_ACTIVE and wandb_activate:
305            wandb.init(
306                name=f"post_process_{os.getpid()}",
307                job_type="inference",
308                project="Zero Shot Object Detection",
309            )
310        writer = SummaryWriter(log_dir=log_directory)
311        n_processed_images = 0
312
313        logging.info(f"Post-Processor {os.getpid()} starting loop.")
314
315        while True:
316            results_queue = result_queues[largest_queue_index.value]
317            writer.add_scalar(
318                "post_processing/selected_queue",
319                largest_queue_index.value,
320                n_processed_images,
321            )
322
323            if results_queue.qsize() == max_queue_size:
324                logging.warning(
325                    f"Queue full: {results_queue.qsize()}. Consider increasing number of post-processing workers."
326                )
327
328            # Exit only when inference is finished and the queue is empty
329            if inference_finished.value and results_queue.empty():
330                dataset_v51.save()
331                logging.info(
332                    f"Post-processing worker {os.getpid()} has finished all outputs."
333                )
334                break
335
336            # Process results from the queue if available
337            if not results_queue.empty():
338                try:
339                    time_start = time.time()
340
341                    result = results_queue.get_nowait()
342
343                    processing_successful = self.process_outputs(
344                        dataset_v51,
345                        result,
346                        self.object_classes,
347                        self.detection_threshold,
348                    )
349
350                    # Performance logging
351                    n_images = len(result["labels"])
352                    time_end = time.time()
353                    duration = time_end - time_start
354                    batches_per_second = 1 / duration
355                    frames_per_second = batches_per_second * n_images
356                    n_processed_images += n_images
357                    writer.add_scalar(
358                        "post_processing/frames_per_second",
359                        frames_per_second,
360                        n_processed_images,
361                    )
362
363                    del result  # Explicit removal from device
364
365                except Exception as e:
366                    logging.error(f"Error in post-processing: {e}")
367
368            else:
369                continue
370
371        writer.close()
372        wandb.finish(exit_code=0)
373        return processing_successful  # Return last processing status
374
375    def gpu_worker(
376        self,
377        gpu_id,
378        cpu_cores,
379        task_queue,
380        results_queue,
381        done_event,
382        post_processing_finished,
383        set_cpu_affinity=False,
384    ):
385        """Run model inference on specified GPU with dedicated CPU cores."""
386        dataset_v51 = fo.load_dataset(
387            self.dataset_name
388        )  # NOTE Only for the case of sequential processing
389        configure_logging()
390        # Set CPU
391        if set_cpu_affinity:
392            # Allow only certain CPU cores
393            psutil.Process().cpu_affinity(cpu_cores)
394        logging.info(f"Available CPU cores: {psutil.Process().cpu_affinity()}")
395        max_n_cpus = len(cpu_cores)
396        torch.set_num_threads(max_n_cpus)
397
398        # Set GPU
399        logging.info(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
400        device = torch.device(f"cuda:{gpu_id}")
401
402        run_successful = None
403        with torch.cuda.device(gpu_id):
404            while True:
405                if post_processing_finished.value and task_queue.empty():
406                    # Keep alive until post-processing is done
407                    break
408
409                if task_queue.empty():
410                    done_event.set()
411
412                if not task_queue.empty():
413                    try:
414                        task_metadata = task_queue.get(
415                            timeout=5
416                        )  # Timeout to prevent indefinite blocking
417                    except Exception as e:
418                        break  # Exit if no more tasks
419                    run_successful = self.model_inference(
420                        task_metadata,
421                        device,
422                        self.dataset_torch,
423                        dataset_v51,
424                        self.object_classes,
425                        results_queue,
426                        self.tensorboard_root,
427                    )
428                    logging.info(
429                        f"Worker for GPU {gpu_id} finished run successful: {run_successful}"
430                    )
431                else:
432                    continue
433        return run_successful  # Return last processing status
434
435    def eval_and_export_worker(self, models_ready_queue, n_models):
436        """Evaluate model performance and export results for completed models."""
437        configure_logging()
438        logging.info(f"Process ID: {os.getpid()}. Eval-and-export process started")
439
440        dataset = fo.load_dataset(self.dataset_name)
441        run_successful = None
442        models_done = 0
443
444        while True:
445            if not models_ready_queue.empty():
446                try:
447                    model_info = models_ready_queue.get(
448                        timeout=5
449                    )  # Timeout to prevent indefinite blocking
450                    model_name = model_info["model_name"]
451                    pred_key = re.sub(r"[\W-]+", "_", "pred_zsod_" + model_name)
452                    eval_key = re.sub(r"[\W-]+", "_", "eval_zsod_" + model_name)
453                    dataset.reload()
454                    run_successful = self.eval_and_export(
455                        dataset, model_name, pred_key, eval_key
456                    )
457                    models_done += 1
458                    logging.info(
459                        f"Evaluation and export of {models_done}/{n_models} models done."
460                    )
461                except Exception as e:
462                    logging.error(f"Error in eval-and-export worker: {e}")
463                    continue
464
465            if models_done == n_models:
466                break
467
468        return run_successful
469
470    # Functionality functions
471    def model_inference(
472        self,
473        metadata: dict,
474        device: str,
475        dataset: torch.utils.data.Dataset,
476        dataset_v51: fo.Dataset,
477        object_classes: list,
478        results_queue: Union[queue.Queue, mp.Queue],
479        root_log_dir: str,
480        persistent_workers: bool = False,
481    ):
482        """Model inference method running zero-shot object detection on provided dataset and device, returning success status."""
483        writer = None
484        run_successful = True
485        processor, model, inputs, outputs, result, dataloader = (
486            None,
487            None,
488            None,
489            None,
490            None,
491            None,
492        )  # For finally block
493
494        # Timeout handler
495        dataloader_timeout = 60
496        signal.signal(signal.SIGALRM, timeout_handler)
497
498        try:
499            # Metadata
500            run_id = metadata["run_id"]
501            model_name = metadata["model_name"]
502            dataset_name = metadata["dataset_name"]
503            is_subset = metadata["is_subset"]
504            batch_size = metadata["batch_size"]
505
506            logging.info(
507                f"Process ID: {os.getpid()}, Run ID: {run_id}, Device: {device}, Model: {model_name}"
508            )
509
510            # Load the model
511            logging.info(f"Loading model {model_name}")
512            processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
513            model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name)
514            model = model.to(device, non_blocking=True)
515            model.eval()
516            hf_model_config = AutoConfig.from_pretrained(model_name)
517            hf_model_config_name = type(hf_model_config).__name__
518            batch_classes = [object_classes] * batch_size
519            logging.info(f"Loaded model type {hf_model_config_name}")
520
521            # Dataloader
522            logging.info("Generating dataloader")
523            if is_subset:
524                chunk_index_start = metadata["chunk_index_start"]
525                chunk_index_end = metadata["chunk_index_end"]
526                logging.info(f"Length of dataset: {len(dataset)}")
527                logging.info(f"Subset start index: {chunk_index_start}")
528                logging.info(f"Subset stop index: {chunk_index_end}")
529                dataset = Subset(dataset, range(chunk_index_start, chunk_index_end))
530
531            zero_shot_inference_preprocessing = ZeroShotInferenceCollateFn(
532                hf_model_config_name=hf_model_config_name,
533                hf_processor=processor,
534                object_classes=object_classes,
535                batch_size=batch_size,
536                batch_classes=batch_classes,
537            )
538            num_workers = WORKFLOWS["auto_labeling_zero_shot"]["n_worker_dataloader"]
539            prefetch_factor = WORKFLOWS["auto_labeling_zero_shot"][
540                "prefetch_factor_dataloader"
541            ]
542            dataloader = DataLoader(
543                dataset,
544                batch_size=batch_size,
545                shuffle=False,
546                num_workers=num_workers,
547                persistent_workers=persistent_workers,
548                pin_memory=True,
549                prefetch_factor=prefetch_factor,
550                collate_fn=zero_shot_inference_preprocessing,
551            )
552
553            dataloader_length = len(dataloader)
554            if dataloader_length < 1:
555                logging.error(
556                    f"Dataloader has insufficient data: {dataloader_length} entries. Please check your dataset and DataLoader configuration."
557                )
558
559            # Logging
560            experiment_name = f"{model_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{device}"
561            log_directory = os.path.join(root_log_dir, dataset_name, experiment_name)
562            wandb.tensorboard.patch(root_logdir=log_directory)
563            if WANDB_ACTIVE:
564                wandb.init(
565                    name=f"{model_name}_{device}",
566                    job_type="inference",
567                    project="Zero Shot Object Detection",
568                    config=metadata,
569                )
570            writer = SummaryWriter(log_dir=log_directory)
571
572            # Inference Loop
573            logging.info(f"{os.getpid()}: Starting inference loop")
574            n_processed_images = 0
575            for inputs, labels, target_sizes, batch_classes in tqdm(
576                dataloader, desc="Inference Loop"
577            ):
578                signal.alarm(dataloader_timeout)
579                try:
580                    time_start = time.time()
581                    n_images = len(labels)
582                    inputs = inputs.to(device, non_blocking=True)
583
584                    with torch.amp.autocast("cuda"), torch.inference_mode():
585                        outputs = model(**inputs)
586
587                    result = {
588                        "inputs": inputs,
589                        "outputs": outputs,
590                        "processor": processor,
591                        "target_sizes": target_sizes,
592                        "labels": labels,
593                        "model_name": model_name,
594                        "hf_model_config_name": hf_model_config_name,
595                        "batch_classes": batch_classes,
596                    }
597
598                    logging.debug(f"{os.getpid()}: Putting result into queue")
599
600                    results_queue.put(
601                        result, timeout=60
602                    )  # Drop the batch only after waiting 60 seconds
603
604                    # Logging
605                    time_end = time.time()
606                    duration = time_end - time_start
607                    batches_per_second = 1 / duration
608                    frames_per_second = batches_per_second * n_images
609                    n_processed_images += n_images
610                    logging.debug(
611                        f"{os.getpid()}: Number of processed images: {n_processed_images}"
612                    )
613                    writer.add_scalar(
614                        "inference/frames_per_second",
615                        frames_per_second,
616                        n_processed_images,
617                    )
618
619                except TimeoutException:
620                    logging.warning(
621                        "Dataloader loop got stuck. Continuing with next batch."
622                    )
623                    continue
624
625                finally:
626                    signal.alarm(0)  # Cancel the alarm
627
628            # Flawless execution
629            wandb_exit_code = 0
630
631        except Exception as e:
632            wandb_exit_code = 1
633            run_successful = False
634            logging.error(f"Error in Process {os.getpid()}: {e}")
635        finally:
636            try:
637                wandb.finish(exit_code=wandb_exit_code)
638            except Exception:  # wandb may not have been initialized
639                pass
640
641            # Explicit removal from device
642            del (
643                processor,
644                model,
645                inputs,
646                outputs,
647                result,
648                dataloader,
649            )
650
651            torch.cuda.empty_cache()
652            wandb.tensorboard.unpatch()
653            if writer:
654                writer.close()
655            return run_successful
656
657    def process_outputs(self, dataset_v51, result, object_classes, detection_threshold):
658        """Process outputs from object detection models, extracting bounding boxes and labels to save to the dataset."""
659        try:
660            inputs = result["inputs"]
661            outputs = result["outputs"]
662            target_sizes = result["target_sizes"]
663            labels = result["labels"]
664            model_name = result["model_name"]
665            hf_model_config_name = result["hf_model_config_name"]
666            batch_classes = result["batch_classes"]
667            processor = result["processor"]
668
669            processing_successful = True  # Default status; set False if an output cannot be handled
670            if hf_model_config_name == "GroundingDinoConfig":
671                results = processor.post_process_grounded_object_detection(
672                    outputs,
673                    inputs.input_ids,
674                    box_threshold=detection_threshold,
675                    text_threshold=detection_threshold,
676                )
677            elif hf_model_config_name in ["Owlv2Config", "OwlViTConfig"]:
678                results = processor.post_process_grounded_object_detection(
679                    outputs=outputs,
680                    threshold=detection_threshold,
681                    target_sizes=target_sizes,
682                    text_labels=batch_classes,
683                )
684            elif hf_model_config_name == "OmDetTurboConfig":
685                results = processor.post_process_grounded_object_detection(
686                    outputs,
687                    text_labels=batch_classes,
688                    threshold=detection_threshold,
689                    nms_threshold=detection_threshold,
690                    target_sizes=target_sizes,
691                )
692            else:
693                raise ValueError(f"Unsupported model config: {hf_model_config_name}")
694
695            if not len(results) == len(target_sizes) == len(labels):
696                logging.error(
697                    f"Lengths of results, target_sizes, and labels do not match: {len(results)}, {len(target_sizes)}, {len(labels)}"
698                )
699            for processed_result, size, target in zip(results, target_sizes, labels):
700                boxes, scores, text_labels = (
701                    processed_result["boxes"],
702                    processed_result["scores"],
703                    processed_result["text_labels"],
704                )
705
706                img_height = size[0]
707                img_width = size[1]
708
709                detections = []
710                for box, score, label in zip(boxes, scores, text_labels):
711                    processing_successful = True
712                    if hf_model_config_name == "GroundingDinoConfig":
713                        # Outputs do not comply with given labels
714                        # Grounding DINO outputs multiple pairs of object boxes and noun phrases for a given (Image, Text) pair
715                        # There can be either multiple labels per output ("bike van"), incomplete ones ("motorcyc"), or broken ones ("##cic")
716                        processed_label = label.split()[
717                            0
718                        ]  # Assume first output is the best output
719                        if processed_label in object_classes:
720                            label = processed_label
721                            top_left_x = box[0].item()
722                            top_left_y = box[1].item()
723                            box_width = (box[2] - box[0]).item()
724                            box_height = (box[3] - box[1]).item()
725                        else:
726                            matches = get_close_matches(
727                                processed_label, object_classes, n=1, cutoff=0.6
728                            )
729                            selected_label = matches[0] if matches else None
730                            if selected_label:
731                                logging.debug(
732                                    f"Mapped output '{processed_label}' to class '{selected_label}'"
733                                )
734                                label = selected_label
735                                top_left_x = box[0].item()
736                                top_left_y = box[1].item()
737                                box_width = (box[2] - box[0]).item()
738                                box_height = (box[3] - box[1]).item()
739                            else:
740                                logging.debug(
741                                    f"Skipped detection with {hf_model_config_name} due to unclear output: {label}"
742                                )
743                                processing_successful = False
744
745                    elif hf_model_config_name in [
746                        "Owlv2Config",
747                        "OwlViTConfig",
748                        "OmDetTurboConfig",
749                    ]:
750                        top_left_x = box[0].item() / img_width
751                        top_left_y = box[1].item() / img_height
752                        box_width = (box[2].item() - box[0].item()) / img_width
753                        box_height = (box[3].item() - box[1].item()) / img_height
754
755                    if (
756                        processing_successful
757                    ):  # Skip GroundingDinoConfig labels that could not be processed
758                        detection = fo.Detection(
759                            label=label,
760                            bounding_box=[
761                                top_left_x,
762                                top_left_y,
763                                box_width,
764                                box_height,
765                            ],
766                            confidence=score.item(),
767                        )
768                        detection["bbox_area"] = (
769                            detection["bounding_box"][2] * detection["bounding_box"][3]
770                        )
771                        detections.append(detection)
772
773                # Attach label to V51 dataset
774                pred_key = re.sub(
775                    r"[\W-]+", "_", "pred_zsod_" + model_name
776            )  # zsod for Zero-Shot Object Detection
777                sample = dataset_v51[target["image_id"]]
778                sample[pred_key] = fo.Detections(detections=detections)
779                sample.save()
780
781        except Exception as e:
782            logging.error(f"Error in processing outputs: {e}")
783            processing_successful = False
784        finally:
785            return processing_successful
786
787    def eval_and_export(self, dataset_v51, model_name, pred_key, eval_key):
788        """Populate dataset with evaluation results (if ground_truth available)"""
789        try:
790            dataset_v51.evaluate_detections(
791                pred_key,
792                gt_field="ground_truth",
793                eval_key=eval_key,
794                compute_mAP=True,
795            )
796        except Exception as e:
797            logging.warning(f"Evaluation not possible: {e}")
798
799        # Store labels https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.export
800        model_name_key = re.sub(r"[\W-]+", "_", model_name)
801        dataset_v51.export(
802            export_dir=os.path.join(self.detections_root, model_name_key),
803            dataset_type=fo.types.COCODetectionDataset,
804            data_path="data.json",
805            export_media=None,  # "manifest",
806            label_field=pred_key,
807            progress=True,
808        )
809        return True

Zero-shot object detection using various HuggingFace models with multi-GPU support.

ZeroShotObjectDetection(dataset_torch: torch.utils.data.dataset.Dataset, dataset_info, config, detections_path='./output/detections/', log_root='./logs/')

Initialize the zero-shot object detection labeler with dataset, configuration, and path settings.

dataset_torch
dataset_info
dataset_name
object_classes
detection_threshold
detections_root
tensorboard_root
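
Putting the constructor and the attributes above together, a hedged construction example; `torch_dataset` and the dataset name are placeholders, and the config keys mirror those read in `__init__`:

config = {
    "object_classes": ["car", "truck", "pedestrian"],  # prompts for the zero-shot models
    "detection_threshold": 0.2,
}
labeler = ZeroShotObjectDetection(
    dataset_torch=torch_dataset,              # any torch.utils.data.Dataset
    dataset_info={"name": "my_dataset"},      # must match the FiftyOne dataset name
    config=config,
)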
def exclude_stored_predictions(self, dataset_v51: fiftyone.core.dataset.Dataset, config, do_exclude=False):

Checks for existing predictions and loads them from disk if available.
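
Typical use before launching inference, assuming a config whose `hf_models_zeroshot_objectdetection` maps model names to per-model settings (the entry below is illustrative). With `do_exclude=True`, models whose predictions already exist in the dataset or on disk are skipped:

import fiftyone as fo

dataset_v51 = fo.load_dataset("my_dataset")  # hypothetical dataset name
config = {
    "hf_models_zeroshot_objectdetection": {
        "IDEA-Research/grounding-dino-tiny": {"batch_size": 4},  # example entry
    }
}
models_to_run = labeler.exclude_stored_predictions(dataset_v51, config, do_exclude=True)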

def update_queue_sizes_worker(self, queues, queue_sizes, largest_queue_index, max_queue_size):

Monitor and manage multiple result queues for balanced processing.
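
The core of this worker is sampling a queue index with probability proportional to its fill level; a standalone sketch of that step with made-up sizes:

import random

queue_sizes = [12, 3, 1]   # items currently in each result queue
max_queue_size = 16        # shared capacity used for normalization

normalized = [size / max_queue_size for size in queue_sizes]
weights = [n / sum(normalized) for n in normalized]
chosen = random.choices(range(len(queue_sizes)), weights=weights, k=1)[0]
# Fuller queues are selected more often, so post-processing drains them first.

Note that the division by `max_queue_size` cancels when the weights are normalized, so the selection probabilities are effectively proportional to the raw queue sizes.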

def process_outputs_worker(self, result_queues, largest_queue_index, inference_finished, max_queue_size, wandb_activate=False):

Process model outputs from result queues and save to dataset.
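
A hedged sketch of the shared state this worker expects, built with `torch.multiprocessing` primitives (queue count and capacity are illustrative; `labeler` is the instance constructed earlier):

import torch.multiprocessing as mp

max_queue_size = 16
result_queues = [mp.Queue(maxsize=max_queue_size) for _ in range(2)]
largest_queue_index = mp.Value("i", 0)     # updated by update_queue_sizes_worker
inference_finished = mp.Value("b", False)  # flipped once all GPU workers are done

post_processor = mp.Process(
    target=labeler.process_outputs_worker,
    args=(result_queues, largest_queue_index, inference_finished, max_queue_size),
)
post_processor.start()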

def gpu_worker(self, gpu_id, cpu_cores, task_queue, results_queue, done_event, post_processing_finished, set_cpu_affinity=False):

Run model inference on specified GPU with dedicated CPU cores.
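
Correspondingly, one `gpu_worker` process can be launched per GPU against a shared task queue. A hedged sketch continuing the shared state above; the CPU core split is illustrative:

import torch
import torch.multiprocessing as mp

task_queue = mp.Queue()
done_event = mp.Event()
post_processing_finished = mp.Value("b", False)

workers = []
for gpu_id in range(torch.cuda.device_count()):
    cpu_cores = list(range(gpu_id * 4, (gpu_id + 1) * 4))  # illustrative core pinning
    worker = mp.Process(
        target=labeler.gpu_worker,
        args=(
            gpu_id,
            cpu_cores,
            task_queue,
            result_queues[gpu_id % len(result_queues)],
            done_event,
            post_processing_finished,
        ),
    )
    worker.start()
    workers.append(worker)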

def eval_and_export_worker(self, models_ready_queue, n_models):

Evaluate model performance and export results for completed models.
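
The worker consumes simple completion messages; a sketch of the producing side for a single model (queue and names are illustrative):

import torch.multiprocessing as mp

models_ready_queue = mp.Queue()
models_ready_queue.put({"model_name": "IDEA-Research/grounding-dino-tiny"})

evaluator = mp.Process(
    target=labeler.eval_and_export_worker,
    args=(models_ready_queue, 1),  # n_models=1: exit after one model is evaluated
)
evaluator.start()
evaluator.join()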

def model_inference(self, metadata: dict, device: str, dataset: torch.utils.data.dataset.Dataset, dataset_v51: fiftyone.core.dataset.Dataset, object_classes: list, results_queue: Union[queue.Queue, multiprocessing.Queue], root_log_dir: str, persistent_workers: bool = False):

Model inference method running zero-shot object detection on provided dataset and device, returning success status.
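
The `metadata` dict drives each run; the keys below are exactly those read in the method body, with placeholder values, reusing the `task_queue` from the sketch above:

task_metadata = {
    "run_id": 0,
    "model_name": "IDEA-Research/grounding-dino-tiny",  # example checkpoint
    "dataset_name": "my_dataset",
    "batch_size": 4,
    "is_subset": True,        # process only a chunk of the dataset
    "chunk_index_start": 0,   # required when is_subset is True
    "chunk_index_end": 1000,
}
task_queue.put(task_metadata)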

def process_outputs(self, dataset_v51, result, object_classes, detection_threshold):

Process outputs from object detection models, extracting bounding boxes and labels to save to the dataset.
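
FiftyOne stores bounding boxes as relative `[top-left-x, top-left-y, width, height]` coordinates in `[0, 1]`; the conversion applied above, as a standalone example with made-up numbers:

import fiftyone as fo

img_height, img_width = 1080, 1920   # from target_sizes (height, width)
box = [480.0, 270.0, 960.0, 810.0]   # absolute [x1, y1, x2, y2] from the model

detection = fo.Detection(
    label="car",
    bounding_box=[
        box[0] / img_width,              # top-left x, relative
        box[1] / img_height,             # top-left y, relative
        (box[2] - box[0]) / img_width,   # width, relative
        (box[3] - box[1]) / img_height,  # height, relative
    ],
    confidence=0.87,
)
print(detection.bounding_box)  # [0.25, 0.25, 0.25, 0.5]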

def eval_and_export(self, dataset_v51, model_name, pred_key, eval_key):

Populate dataset with evaluation results (if ground_truth available)
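For reference, the two FiftyOne calls this method wraps can be exercised standalone. A minimal sketch using the quickstart zoo dataset, which ships with 'ground_truth' and 'predictions' fields (the export directory is hypothetical):

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")
results = dataset.evaluate_detections(
    "predictions", gt_field="ground_truth", eval_key="eval_demo", compute_mAP=True
)
print(results.mAP())

dataset.export(
    export_dir="/tmp/detections_export",  # hypothetical output directory
    dataset_type=fo.types.COCODetectionDataset,
    label_field="predictions",
)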

class UltralyticsObjectDetection:
class UltralyticsObjectDetection:
    """Object detection using Ultralytics YOLO models with training and inference support."""

    def __init__(self, dataset, config):
        """Initialize with dataset, config, and set up paths for model and data."""
        self.dataset = dataset
        self.config = config
        self.ultralytics_data_path = os.path.join(
            config["export_dataset_root"], config["v51_dataset_name"]
        )

        self.hf_hub_model_id = (
            f"{HF_ROOT}/"
            + f"{config['v51_dataset_name']}_{config['model_name']}".replace("/", "_")
        )

        self.export_root = "output/models/ultralytics/"
        self.export_folder = os.path.join(
            self.export_root, self.config["v51_dataset_name"]
        )

        self.model_path = os.path.join(
            self.export_folder, self.config["model_name"], "weights", "best.pt"
        )

    @staticmethod
    def export_data(
        dataset, dataset_info, export_dataset_root, label_field="ground_truth"
    ):
        """Export dataset to YOLO format for Ultralytics training."""
        ultralytics_data_path = os.path.join(export_dataset_root, dataset_info["name"])
        # Delete export directory if it already exists
        if os.path.exists(ultralytics_data_path):
            shutil.rmtree(ultralytics_data_path)

        logging.info("Exporting data for training with Ultralytics")
        classes = dataset.distinct(f"{label_field}.detections.label")

        # Make directory
        os.makedirs(ultralytics_data_path, exist_ok=False)

        for split in ACCEPTED_SPLITS:
            split_view = dataset.match_tags(split)

            if split in ("train", "val"):  # YOLO expects 'train' and 'val' splits
                split_view.export(
                    export_dir=ultralytics_data_path,
                    dataset_type=fo.types.YOLOv5Dataset,
                    label_field=label_field,
                    classes=classes,
                    split=split,
                )

    def train(self):
        """Train the YOLO model for object detection using Ultralytics and optionally upload to Hugging Face."""
        model = YOLO(self.config["model_name"], task="detect")
        # https://docs.ultralytics.com/modes/train/#train-settings

        # Use all available GPUs
        device = "0"  # Default to GPU 0
        if torch.cuda.device_count() > 1:
            device = ",".join(map(str, range(torch.cuda.device_count())))

        results = model.train(
            data=f"{self.ultralytics_data_path}/dataset.yaml",
            epochs=self.config["epochs"],
            project=self.export_folder,
            name=self.config["model_name"],
            patience=self.config["patience"],
            batch=self.config["batch_size"],
            imgsz=self.config["img_size"],
            multi_scale=self.config["multi_scale"],
            cos_lr=self.config["cos_lr"],
            seed=GLOBAL_SEED,
            optimizer="AdamW",  # "auto" as default
            pretrained=True,
            exist_ok=True,
            amp=True,
            device=device,
        )
        metrics = model.val()
        logging.info(f"Model Performance: {metrics}")

        # Upload model to Hugging Face
        if HF_DO_UPLOAD:
            logging.info(f"Uploading model {self.model_path} to Hugging Face.")
            api = HfApi()
            api.create_repo(
                self.hf_hub_model_id, private=True, repo_type="model", exist_ok=True
            )
            api.upload_file(
                path_or_fileobj=self.model_path,
                path_in_repo="best.pt",
                repo_id=self.hf_hub_model_id,
                repo_type="model",
            )

    def inference(self, gt_field="ground_truth"):
        """Performs inference using a YOLO model on a dataset, with options to evaluate results."""
        logging.info(f"Running inference on dataset {self.config['v51_dataset_name']}")
        inference_settings = self.config["inference_settings"]

        dataset_name = None
        model_name = self.config["model_name"]

        model_hf = inference_settings["model_hf"]
        if model_hf is not None:
            # Use the model manually defined in the config. This way, models can be
            # used for inference that were trained on a different dataset.
            dataset_name, _ = get_dataset_and_model_from_hf_id(model_hf)

            # Set up directories
            download_dir = os.path.join(
                self.export_root, dataset_name, model_name, "weights"
            )
            os.makedirs(download_dir, exist_ok=True)

            self.model_path = os.path.join(download_dir, "best.pt")

            file_path = hf_hub_download(
                repo_id=model_hf,
                filename="best.pt",
                local_dir=download_dir,
            )
        else:
            # Automatically determine model based on dataset
            dataset_name = self.config["v51_dataset_name"]

            try:
                if os.path.exists(self.model_path):
                    file_path = self.model_path
                    logging.info(f"Loading model {model_name} from disk: {file_path}")
                else:
                    download_dir = self.model_path.replace("best.pt", "")
                    os.makedirs(download_dir, exist_ok=True)
                    logging.info(
                        f"Downloading model {self.hf_hub_model_id} from Hugging Face to {download_dir}"
                    )
                    file_path = hf_hub_download(
                        repo_id=self.hf_hub_model_id,
                        filename="best.pt",
                        local_dir=download_dir,
                    )
            except Exception as e:
                logging.error(f"Failed to load or download model: {str(e)}.")
                return False

        pred_key = f"pred_od_{model_name}-{dataset_name}"
        logging.info(f"Using model {self.model_path} for inference.")
        model = YOLO(self.model_path)

        detection_threshold = inference_settings["detection_threshold"]
        if inference_settings["inference_on_test"] is True:
            dataset_eval_view = self.dataset.match_tags("test")
            if len(dataset_eval_view) == 0:
                logging.error("Dataset is missing split 'test'")
            dataset_eval_view.apply_model(
                model, label_field=pred_key, confidence_thresh=detection_threshold
            )
        else:
            self.dataset.apply_model(
                model, label_field=pred_key, confidence_thresh=detection_threshold
            )

        if inference_settings["do_eval"]:
            eval_key = f"eval_{self.config['model_name']}_{dataset_name}"

            if inference_settings["inference_on_test"] is True:
                dataset_view = self.dataset.match_tags(["test"])
            else:
                dataset_view = self.dataset

            results = dataset_view.evaluate_detections(
                pred_key,
                gt_field=gt_field,
                eval_key=eval_key,
                compute_mAP=True,
            )

            results.print_report()

Object detection using Ultralytics YOLO models with training and inference support.
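Typical end-to-end usage, sketched with a hypothetical config whose keys mirror those the class and its methods read (actual values come from the project's WORKFLOWS configuration):

config = {
    "export_dataset_root": "output/datasets/",  # where YOLO-format data is exported
    "v51_dataset_name": "my_dataset",           # hypothetical dataset name
    "model_name": "yolo11n.pt",                 # hypothetical Ultralytics checkpoint
    "epochs": 10,
    "patience": 5,
    "batch_size": 16,
    "img_size": 640,
    "multi_scale": False,
    "cos_lr": True,
    "inference_settings": {
        "model_hf": None,             # or a HF repo id of a model trained elsewhere
        "detection_threshold": 0.2,
        "inference_on_test": True,
        "do_eval": True,
    },
}

detector = UltralyticsObjectDetection(dataset, config)  # dataset: fo.Dataset
UltralyticsObjectDetection.export_data(
    dataset, {"name": config["v51_dataset_name"]}, config["export_dataset_root"]
)
detector.train()
detector.inference()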

UltralyticsObjectDetection(dataset, config)

Initialize with dataset, config, and set up paths for model and data.

dataset
config
ultralytics_data_path
hf_hub_model_id
export_root
export_folder
model_path
@staticmethod
def export_data(dataset, dataset_info, export_dataset_root, label_field='ground_truth'):

Export dataset to YOLO format for Ultralytics training.
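The fo.types.YOLOv5Dataset export produces, per split, an images/ and labels/ tree plus a dataset.yaml that train() later hands to Ultralytics. A quick, hypothetical way to sanity-check the result (paths assumed to match the config sketched above):

from pathlib import Path

export_dir = Path("output/datasets/my_dataset")   # hypothetical export root
print((export_dir / "dataset.yaml").read_text())  # class names and split paths
n_train = len(list((export_dir / "labels" / "train").glob("*.txt")))
print(f"{n_train} YOLO label files exported for 'train'")
# Each label line has the form: <class-id> <cx> <cy> <w> <h>, relative to image size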

def train(self):

Train the YOLO model for object detection using Ultralytics and optionally upload to Hugging Face.
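After training, the best checkpoint lands at <export_folder>/<model_name>/weights/best.pt (self.model_path). A sketch of reloading it for standalone validation, with hypothetical paths:

from ultralytics import YOLO

model = YOLO("output/models/ultralytics/my_dataset/yolo11n.pt/weights/best.pt")
metrics = model.val(data="output/datasets/my_dataset/dataset.yaml")
print(metrics.box.map)  # mAP50-95 on the val split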

def inference(self, gt_field='ground_truth'):

Performs inference using a YOLO model on a dataset, with options to evaluate results.
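The inference path relies on FiftyOne's Ultralytics integration: apply_model() accepts a YOLO object directly and writes fo.Detections into the given label field. A minimal sketch on the quickstart dataset (model choice hypothetical):

import fiftyone.zoo as foz
from ultralytics import YOLO

dataset = foz.load_zoo_dataset("quickstart")
model = YOLO("yolo11n.pt")
view = dataset.match_tags("validation")
view.apply_model(model, label_field="pred_demo", confidence_thresh=0.2)
print(view.count("pred_demo.detections"))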

def transform_batch_standalone(batch, image_processor, do_convert_annotations=True, return_pixel_mask=False):
def transform_batch_standalone(
    batch,
    image_processor,
    do_convert_annotations=True,
    return_pixel_mask=False,
):
    """Apply format annotations in COCO format for object detection task. Outside of class so it can be pickled."""
    images = []
    annotations = []

    for image_path, annotation in zip(batch["image_path"], batch["objects"]):
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)
        images.append(image_np)

        coco_annotations = []
        for i, bbox in enumerate(annotation["bbox"]):

            # Conversion from HF dataset bounding boxes to DETR:
            # Input: HF dataset bbox is COCO (top_left_x, top_left_y, width, height) in absolute coordinates
            # Output:
            # DETR expects COCO (top_left_x, top_left_y, width, height) in absolute coordinates if 'do_convert_annotations == True'
            # DETR expects YOLO (center_x, center_y, width, height) in relative coordinates between [0,1] if 'do_convert_annotations == False'

            if not do_convert_annotations:
                x, y, w, h = bbox
                img_height, img_width = image_np.shape[:2]
                center_x = (x + w / 2) / img_width
                center_y = (y + h / 2) / img_height
                width = w / img_width
                height = h / img_height
                bbox = [center_x, center_y, width, height]

                # Ensure bbox values are within the expected range
                assert all(0 <= coord <= 1 for coord in bbox), f"Invalid bbox: {bbox}"

                logging.debug(
                    f"Converted {[x, y, w, h]} to {[center_x, center_y, width, height]} with 'do_convert_annotations' = {do_convert_annotations}"
                )

            coco_annotation = {
                "image_id": annotation["image_id"],
                "bbox": bbox,
                "category_id": annotation["category_id"][i],
                "area": annotation["area"][i],
                "iscrowd": 0,
            }
            coco_annotations.append(coco_annotation)
        detr_annotation = {
            "image_id": annotation["image_id"],
            "annotations": coco_annotations,
        }
        annotations.append(detr_annotation)

    # Apply the image processor transformations (resizing, rescaling, normalization)
    # once for the full batch, instead of re-processing the growing lists inside the loop
    result = image_processor(
        images=images, annotations=annotations, return_tensors="pt"
    )

    if not return_pixel_mask:
        result.pop("pixel_mask", None)

    return result

Apply format annotations in COCO format for object detection task. Outside of class so it can be pickled.
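The COCO-to-YOLO branch (do_convert_annotations=False) is plain arithmetic; worked through with hypothetical numbers:

x, y, w, h = 100.0, 50.0, 200.0, 100.0   # COCO: top-left corner + size, absolute pixels
img_width, img_height = 800, 400

center_x = (x + w / 2) / img_width        # (100 + 100) / 800 = 0.25
center_y = (y + h / 2) / img_height       # (50 + 50) / 400   = 0.25
width = w / img_width                     # 200 / 800         = 0.25
height = h / img_height                   # 100 / 400         = 0.25
assert all(0 <= c <= 1 for c in (center_x, center_y, width, height))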

class HuggingFaceObjectDetection:
class HuggingFaceObjectDetection:
    """Object detection using HuggingFace models with support for training and inference."""

    def __init__(
        self,
        dataset,
        config,
        output_model_path="./output/models/object_detection_hf",
        output_detections_path="./output/detections/",
        gt_field="ground_truth",
    ):
        """Initialize with dataset, config, and optional output paths."""
        self.dataset = dataset
        self.config = config
        self.model_name = config["model_name"]
        self.model_name_key = re.sub(r"[\W-]+", "_", self.model_name)
        self.dataset_name = config["v51_dataset_name"]
        self.do_convert_annotations = True  # HF can convert (top_left_x, top_left_y, bottom_right_x, bottom_right_y) in abs. coordinates to (x_min, y_min, width, height) in rel. coordinates https://github.com/huggingface/transformers/blob/v4.48.2/src/transformers/models/conditional_detr/image_processing_conditional_detr.py#L1497

        self.detections_root = os.path.join(
            output_detections_path, self.dataset_name, self.model_name_key
        )

        self.model_root = os.path.join(
            output_model_path, self.dataset_name, self.model_name_key
        )

        self.hf_hub_model_id = (
            f"{HF_ROOT}/" + f"{self.dataset_name}_{self.model_name}".replace("/", "_")
        )

        self.categories = dataset.distinct(f"{gt_field}.detections.label")
        self.id2label = {index: x for index, x in enumerate(self.categories)}
        self.label2id = {v: k for k, v in self.id2label.items()}

    def collate_fn(self, batch):
        """Collate function for batching data during training and inference."""
        data = {}
        data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
        data["labels"] = [x["labels"] for x in batch]
        if "pixel_mask" in batch[0]:
            data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
        return data

    def train(self, hf_dataset, overwrite_output=True):
        """Train models for object detection tasks with support for custom image sizes and transformations."""
        torch.cuda.empty_cache()
        img_size_target = self.config.get("image_size", None)
        if img_size_target is None:
            image_processor = AutoProcessor.from_pretrained(
                self.model_name,
                do_resize=False,
                do_pad=True,
                use_fast=True,
                do_convert_annotations=self.do_convert_annotations,
            )
        else:
            logging.warning(f"Resizing images to target size {img_size_target}.")
            image_processor = AutoProcessor.from_pretrained(
                self.model_name,
                do_resize=True,
                size={
                    "max_height": img_size_target[1],
                    "max_width": img_size_target[0],
                },
                do_pad=True,
                pad_size={"height": img_size_target[1], "width": img_size_target[0]},
                use_fast=True,
                do_convert_annotations=self.do_convert_annotations,
            )

        train_transform_batch = partial(
            transform_batch_standalone,
            image_processor=image_processor,
            do_convert_annotations=self.do_convert_annotations,
        )
        val_test_transform_batch = partial(
            transform_batch_standalone,
            image_processor=image_processor,
            do_convert_annotations=self.do_convert_annotations,
        )

        hf_dataset[Split.TRAIN] = hf_dataset[Split.TRAIN].with_transform(
            train_transform_batch
        )
        hf_dataset[Split.VALIDATION] = hf_dataset[Split.VALIDATION].with_transform(
            val_test_transform_batch
        )
        hf_dataset[Split.TEST] = hf_dataset[Split.TEST].with_transform(
            val_test_transform_batch
        )

        hf_model_config = AutoConfig.from_pretrained(self.model_name)
        hf_model_config_name = type(hf_model_config).__name__

        if type(hf_model_config) in AutoModelForObjectDetection._model_mapping:
            model = AutoModelForObjectDetection.from_pretrained(
                self.model_name,
                id2label=self.id2label,
                label2id=self.label2id,
                ignore_mismatched_sizes=True,
            )
        else:
            model = None
            logging.error(
                "Hugging Face AutoModel does not support " + str(type(hf_model_config))
            )

        if (
            overwrite_output
            and os.path.exists(self.model_root)
            and os.listdir(self.model_root)
        ):
            logging.warning(
                f"Training will overwrite existing results in {self.model_root}"
            )

        training_args = TrainingArguments(
            run_name=self.model_name,
            output_dir=self.model_root,
            overwrite_output_dir=overwrite_output,
            num_train_epochs=self.config["epochs"],
            fp16=True,
            per_device_train_batch_size=self.config["batch_size"],
            auto_find_batch_size=True,
            dataloader_num_workers=min(self.config["n_worker_dataloader"], NUM_WORKERS),
            learning_rate=self.config["learning_rate"],
            lr_scheduler_type="cosine",
            weight_decay=self.config["weight_decay"],
            max_grad_norm=self.config["max_grad_norm"],
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            load_best_model_at_end=True,
            eval_strategy="epoch",
            save_strategy="best",
            save_total_limit=1,
            remove_unused_columns=False,
            eval_do_concat_batches=False,
            save_safetensors=False,  # Does not work with all models
            hub_model_id=self.hf_hub_model_id,
            hub_private_repo=True,
            push_to_hub=HF_DO_UPLOAD,
            seed=GLOBAL_SEED,
            data_seed=GLOBAL_SEED,
        )

        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=self.config["early_stop_patience"],
            early_stopping_threshold=self.config["early_stop_threshold"],
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=hf_dataset[Split.TRAIN],
            eval_dataset=hf_dataset[Split.VALIDATION],
            tokenizer=image_processor,
            data_collator=self.collate_fn,
            callbacks=[early_stopping_callback],
            # compute_metrics=eval_compute_metrics_fn,
        )

        logging.info(f"Starting training of model {self.model_name}.")
        trainer.train()
        if HF_DO_UPLOAD:
            trainer.push_to_hub()

        metrics = trainer.evaluate(eval_dataset=hf_dataset[Split.TEST])
        logging.info(f"Model training completed. Evaluation results: {metrics}")

    def inference(self, inference_settings, load_from_hf=True, gt_field="ground_truth"):
        """Performs model inference on a dataset, loading from Hugging Face or disk, and optionally evaluates detection results."""

        model_hf = inference_settings["model_hf"]
        dataset_name = None
        if model_hf is not None:
            self.hf_hub_model_id = model_hf
            dataset_name, model_name = get_dataset_and_model_from_hf_id(model_hf)
        else:
            dataset_name = self.dataset_name
        torch.cuda.empty_cache()
        # Load trained model from Hugging Face
        load_from_hf_successful = None
        if load_from_hf:
            try:
                logging.info(f"Loading model from Hugging Face: {self.hf_hub_model_id}")
                image_processor = AutoProcessor.from_pretrained(self.hf_hub_model_id)
                model = AutoModelForObjectDetection.from_pretrained(
                    self.hf_hub_model_id
                )
                load_from_hf_successful = True
            except Exception as e:
                load_from_hf_successful = False
                logging.warning(
                    f"Model {self.model_name} could not be loaded from Hugging Face {self.hf_hub_model_id}. Attempting loading from disk."
                )
        if not load_from_hf or load_from_hf_successful is False:
            try:
                # Select folders in self.model_root that include 'checkpoint-'
                checkpoint_dirs = [
                    d
                    for d in os.listdir(self.model_root)
                    if "checkpoint-" in d
                    and os.path.isdir(os.path.join(self.model_root, d))
                ]

                if not checkpoint_dirs:
                    logging.error(
                        f"No checkpoint directory found in {self.model_root}!"
                    )
                    model_path = None
                else:
                    # Sort by modification time (latest first)
                    checkpoint_dirs.sort(
                        key=lambda d: os.path.getmtime(
                            os.path.join(self.model_root, d)
                        ),
                        reverse=True,
                    )

                    if len(checkpoint_dirs) > 1:
                        logging.warning(
                            f"Multiple checkpoint directories found: {checkpoint_dirs}. Selecting the latest one: {checkpoint_dirs[0]}."
                        )

                    selected_checkpoint = checkpoint_dirs[0]
                    logging.info(
                        f"Loading model from disk: {self.model_root}/{selected_checkpoint}"
                    )
                    model_path = os.path.join(self.model_root, selected_checkpoint)

                image_processor = AutoProcessor.from_pretrained(model_path)
                model = AutoModelForObjectDetection.from_pretrained(model_path)
            except Exception as e:
                logging.error(
                    f"Model {self.model_name} could not be loaded from folder {self.model_root}. Inference not possible."
                )

        device, _, _ = get_backend()
        logging.info(f"Using device {device} for inference.")
        model = model.to(device)
        model.eval()

        pred_key = f"pred_od_{self.model_name_key}-{dataset_name}"

        if inference_settings["inference_on_test"] is True:
            INFERENCE_SPLITS = ["test"]
            dataset_eval_view = self.dataset.match_tags(INFERENCE_SPLITS)
        else:
            dataset_eval_view = self.dataset

        detection_threshold = inference_settings["detection_threshold"]

        with torch.amp.autocast("cuda"), torch.inference_mode():
            for sample in dataset_eval_view.iter_samples(progress=True, autosave=True):
                image_width = sample.metadata.width
                image_height = sample.metadata.height
                img_filepath = sample.filepath

                image = Image.open(img_filepath)
                inputs = image_processor(images=[image], return_tensors="pt")
                outputs = model(**inputs.to(device))
                target_sizes = torch.tensor([[image.size[1], image.size[0]]])

                results = image_processor.post_process_object_detection(
                    outputs, threshold=detection_threshold, target_sizes=target_sizes
                )[0]

                detections = []
                for score, label, box in zip(
                    results["scores"], results["labels"], results["boxes"]
                ):
                    # Bbox is in absolute coordinates x, y, x2, y2
                    box = box.tolist()
                    text_label = model.config.id2label[label.item()]

                    # Voxel51 requires relative coordinates between 0 and 1
                    top_left_x = box[0] / image_width
                    top_left_y = box[1] / image_height
                    box_width = (box[2] - box[0]) / image_width
                    box_height = (box[3] - box[1]) / image_height
                    detection = fo.Detection(
                        label=text_label,
                        bounding_box=[
                            top_left_x,
                            top_left_y,
                            box_width,
                            box_height,
                        ],
                        confidence=score.item(),
                    )
                    detections.append(detection)

                sample[pred_key] = fo.Detections(detections=detections)

        if inference_settings["do_eval"] is True:
            eval_key = re.sub(
                r"[\W-]+", "_", "eval_" + self.model_name + "_" + self.dataset_name
            )

            if inference_settings["inference_on_test"] is True:
                dataset_view = self.dataset.match_tags(["test"])
            else:
                dataset_view = self.dataset

            results = dataset_view.evaluate_detections(
                pred_key,
                gt_field=gt_field,
                eval_key=eval_key,
                compute_mAP=True,
            )

            results.print_report()

Object detection using HuggingFace models with support for training and inference.
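Typical usage, sketched with a hypothetical config whose keys mirror those read by __init__() and train():

config = {
    "model_name": "facebook/detr-resnet-50",  # any AutoModelForObjectDetection model
    "v51_dataset_name": "my_dataset",         # hypothetical dataset name
    "epochs": 12,
    "batch_size": 4,
    "n_worker_dataloader": 4,
    "learning_rate": 5e-5,
    "weight_decay": 1e-4,
    "max_grad_norm": 0.1,
    "early_stop_patience": 3,
    "early_stop_threshold": 0.0,
    "image_size": None,                       # keep original resolution
}

detector = HuggingFaceObjectDetection(dataset, config)  # dataset: fo.Dataset
detector.train(hf_dataset)  # hf_dataset: DatasetDict with train/validation/test
detector.inference(
    inference_settings={
        "model_hf": None,
        "detection_threshold": 0.2,
        "inference_on_test": True,
        "do_eval": True,
    }
)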

HuggingFaceObjectDetection(dataset, config, output_model_path='./output/models/object_detection_hf', output_detections_path='./output/detections/', gt_field='ground_truth')

Initialize with dataset, config, and optional output paths.
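The label mappings are simple index/name dictionaries derived from the distinct ground-truth classes; with hypothetical classes:

categories = ["car", "pedestrian", "truck"]                        # dataset.distinct(...)
id2label = {index: name for index, name in enumerate(categories)}  # {0: 'car', ...}
label2id = {name: index for index, name in id2label.items()}       # {'car': 0, ...}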

dataset
config
model_name
model_name_key
dataset_name
do_convert_annotations
detections_root
model_root
hf_hub_model_id
categories
id2label
label2id
def collate_fn(self, batch):

Collate function for batching data during training and inference.
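What the collate function produces, shown on two hypothetical processed samples: pixel tensors are stacked into one batch tensor, while labels stay a Python list because the number of boxes varies per image.

import torch

batch = [
    {"pixel_values": torch.rand(3, 480, 640), "labels": {"class_labels": torch.tensor([0])}},
    {"pixel_values": torch.rand(3, 480, 640), "labels": {"class_labels": torch.tensor([1, 2])}},
]
data = {
    "pixel_values": torch.stack([x["pixel_values"] for x in batch]),  # shape (2, 3, 480, 640)
    "labels": [x["labels"] for x in batch],
}
print(data["pixel_values"].shape)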

def train(self, hf_dataset, overwrite_output=True):

Train models for object detection tasks with support for custom image sizes and transformations.
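Because push_to_hub() uploads the final checkpoint, the fine-tuned model can later be reloaded directly from the Hub. A sketch with a hypothetical repo id following the f"{HF_ROOT}/{dataset}_{model}" naming scheme used above:

from transformers import AutoModelForObjectDetection, AutoProcessor

repo_id = "my-org/my_dataset_facebook_detr-resnet-50"  # hypothetical
processor = AutoProcessor.from_pretrained(repo_id)
model = AutoModelForObjectDetection.from_pretrained(repo_id)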

def inference(self, inference_settings, load_from_hf=True, gt_field='ground_truth'):

Performs model inference on a dataset, loading from Hugging Face or disk, and optionally evaluates detection results.
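The per-sample loop above follows the standard transformers object-detection inference pattern; a standalone sketch for a single image (file name hypothetical):

import torch
from PIL import Image
from transformers import AutoModelForObjectDetection, AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open("example.jpg")
inputs = processor(images=[image], return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)

# target_sizes is (height, width); post-processing returns absolute (x1, y1, x2, y2)
target_sizes = torch.tensor([[image.size[1], image.size[0]]])
results = processor.post_process_object_detection(
    outputs, threshold=0.2, target_sizes=target_sizes
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())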

class CustomCoDETRObjectDetection:
1376class CustomCoDETRObjectDetection:
1377    """Interface for running Co-DETR object detection model training and inference in containers"""
1378
1379    def __init__(self, dataset, dataset_info, run_config):
1380        """Initialize Co-DETR interface with dataset and configuration"""
1381        self.root_codetr = "./custom_models/CoDETR/Co-DETR"
1382        self.root_codetr_models = "output/models/codetr"
1383        self.dataset = dataset
1384        self.dataset_name = dataset_info["name"]
1385        self.export_dir_root = run_config["export_dataset_root"]
1386        self.config_key = os.path.splitext(os.path.basename(run_config["config"]))[0]
1387        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"
1388
1389    def convert_data(self):
1390        """Convert dataset to COCO format required by Co-DETR"""
1391
1392        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "coco")
1393
1394        # Check if folder already exists
1395        if not os.path.exists(export_dir):
1396            # Make directory
1397            os.makedirs(export_dir, exist_ok=True)
1398            logging.info(f"Exporting data to {export_dir}")
1399            splits = [
1400                "train",
1401                "val",
1402                "test",
1403            ]  # CoDETR expects data in 'train' and 'val' folder
1404            for split in splits:
1405                split_view = self.dataset.match_tags(split)
1406                split_view.export(
1407                    dataset_type=fo.types.COCODetectionDataset,
1408                    data_path=os.path.join(export_dir, f"{split}2017"),
1409                    labels_path=os.path.join(
1410                        export_dir, "annotations", f"instances_{split}2017.json"
1411                    ),
1412                    label_field="ground_truth",
1413                )
1414        else:
1415            logging.warning(
1416                f"Folder {export_dir} already exists, skipping data export."
1417            )
1418
1419    def update_config_file(self, dataset_name, config_file, max_epochs):
1420        """Update Co-DETR config file with dataset-specific parameters"""
1421
1422        config_path = os.path.join(self.root_codetr, config_file)
1423
1424        # Get classes from exported data
1425        annotations_json = os.path.join(
1426            self.export_dir_root,
1427            dataset_name,
1428            "coco/annotations/instances_train2017.json",
1429        )
1430        # Read the JSON file
1431        with open(annotations_json, "r") as file:
1432            data = json.load(file)
1433
1434        # Extract the value associated with the key "categories"
1435        categories = data.get("categories")
1436        class_names = tuple(category["name"] for category in categories)
1437        num_classes = len(class_names)
1438
1439        # Update configuration file
1440        # This assumes that 'classes = '('a','b',...)' are already defined and will be overwritten.
1441        with open(config_path, "r") as file:
1442            content = file.read()
1443
1444        # Update the classes tuple
1445        content = re.sub(r"classes\s*=\s*\(.*?\)", f"classes = {class_names}", content)
1446
1447        # Update all instances of num_classes
1448        content = re.sub(r"num_classes=\d+", f"num_classes={num_classes}", content)
1449
1450        # Update all instances of max_epochs
1451        content = re.sub(r"max_epochs=\d+", f"max_epochs={max_epochs}", content)
1452
1453        with open(config_path, "w") as file:
1454            file.write(content)
1455
1456        logging.warning(
1457            f"Updated {config_path} with classes={class_names} and num_classes={num_classes} and max_epochs={max_epochs}"
1458        )
1459
1460    def train(self, param_config, param_n_gpus, container_tool, param_function="train"):
1461        """Train Co-DETR model using containerized environment"""
1462
1463        # Check if model already exists
1464        output_folder_codetr = os.path.join(self.root_codetr, "output")
1465        os.makedirs(output_folder_codetr, exist_ok=True)
1466        param_config_name = os.path.splitext(os.path.basename(param_config))[0]
1467        best_models_dir = os.path.join(output_folder_codetr, "best")
1468        os.makedirs(best_models_dir, exist_ok=True)
1469        # Best model files follow the naming scheme "config_dataset.pth"
1470        pth_model_files = (
1471            [f for f in os.listdir(best_models_dir) if f.endswith(".pth")]
1472            if os.path.exists(best_models_dir) and os.path.isdir(best_models_dir)
1473            else []
1474        )
1475
1476        # Best model files are stored in the format "config_dataset.pth"
1477        matching_files = [
1478            f
1479            for f in pth_model_files
1480            if f.startswith(param_config_name)
1481            and self.dataset_name in f
1482            and f.endswith(".pth")
1483        ]
1484        if len(matching_files) > 0:
1485            logging.warning(
1486                f"Model {param_config_name} already trained on dataset {self.dataset_name}. Skipping training."
1487            )
1488            if len(matching_files) > 1:
1489                logging.warning(f"Multiple weights found: {matching_files}")
1490        else:
1491            logging.info(
1492                f"Launching training for Co-DETR config {param_config} and dataset {self.dataset_name}."
1493            )
1494            volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1495
1496            # Train model, store checkpoints in 'output_folder_codetr'
1497            train_result = self._run_container(
1498                volume_data=volume_data,
1499                param_function=param_function,
1500                param_config=param_config,
1501                param_n_gpus=param_n_gpus,
1502                container_tool=container_tool,
1503            )
1504
1505            # Find the best_bbox checkpoint file
1506            checkpoint_files = [
1507                f
1508                for f in os.listdir(output_folder_codetr)
1509                if "best_bbox" in f and f.endswith(".pth")
1510            ]
1511            if not checkpoint_files:
1512                logging.error(
1513                    "Co-DETR was not trained, model pth file missing. No checkpoint file with 'best_bbox' found."
1514                )
1515            else:
1516                if len(checkpoint_files) > 1:
1517                    logging.warning(
1518                        f"Found {len(checkpoint_files)} checkpoint files. Selecting {checkpoint_files[0]}."
1519                    )
1520                checkpoint = checkpoint_files[0]
1521                checkpoint_path = os.path.join(output_folder_codetr, checkpoint)
1522                logging.info("Co-DETR was trained successfully.")
1523
1524                # Upload best model to Hugging Face
1525                if HF_DO_UPLOAD:
1526                    logging.info("Uploading Co-DETR model to Hugging Face.")
1527                    api = HfApi()
1528                    api.create_repo(
1529                        self.hf_repo_name,
1530                        private=True,
1531                        repo_type="model",
1532                        exist_ok=True,
1533                    )
1534                    api.upload_file(
1535                        path_or_fileobj=checkpoint_path,
1536                        path_in_repo="model.pth",
1537                        repo_id=self.hf_repo_name,
1538                        repo_type="model",
1539                    )
1540
1541                # Move best model file and clear output folder
1542                self._run_container(
1543                    volume_data=volume_data,
1544                    param_function="clear-output",
1545                    param_config=param_config,
1546                    param_dataset_name=self.dataset_name,
1547                    container_tool=container_tool,
1548                )
1549
1550    @staticmethod
1551    def _find_file_iteratively(start_path, filename):
1552        """Direct access or recursively search for a file in a directory structure."""
1553        # Convert start_path to a Path object
1554        start_path = Path(start_path)
1555
1556        # Check if the file exists in the start_path directly (very fast)
1557        file_path = start_path / filename
1558        if file_path.exists():
1559            return str(file_path)
1560
1561        # Start at the given path and walk up toward the filesystem root
1562        current_dir = start_path
1563        checked_dirs = set()
1564
1565        while current_dir != current_dir.parent:  # Path.parent equals itself only at the root
1566            # Check if the file is in the current directory
1567            file_path = current_dir / filename
1568            if file_path.exists():
1569                return str(file_path)
1570
1571            # If we haven't checked the sibling directories, check them as well
1572            parent_dir = current_dir.parent
1573            if parent_dir not in checked_dirs:
1574                # Check sibling directories
1575                for sibling in parent_dir.iterdir():
1576                    if sibling != current_dir and sibling.is_dir():
1577                        sibling_file_path = sibling / filename
1578                        if sibling_file_path.exists():
1579                            return str(sibling_file_path)
1580                checked_dirs.add(parent_dir)
1581
1582            # Otherwise, go one level up
1583            current_dir = current_dir.parent
1584
1585        # If file is not found after traversing all levels, return None
1586        logging.error(f"File {filename} could not be found.")
1587        return None
1588
1589    def run_inference(
1590        self,
1591        dataset,
1592        param_config,
1593        param_n_gpus,
1594        container_tool,
1595        inference_settings,
1596        param_function="inference",
1597        inference_output_folder="custom_models/CoDETR/Co-DETR/output/inference/",
1598        gt_field="ground_truth",
1599    ):
1600        """Run inference using trained Co-DETR model and convert results to FiftyOne format"""
1601
1602        logging.info(f"Launching inference for Co-DETR config {param_config}.")
1603        volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1604
1605        if inference_settings["inference_on_test"] is True:
1606            folder_inference = os.path.join("coco", "test2017")
1607        else:
1608            folder_inference = os.path.join("coco")
1609
1610        # Get model from Hugging Face
1611        dataset_name = None
1612        config_key = None
1613        try:
1614            if inference_settings["model_hf"] is None:
1615                hf_path = self.hf_repo_name
1616            else:
1617                hf_path = inference_settings["model_hf"]
1618
1619            dataset_name, config_key = get_dataset_and_model_from_hf_id(hf_path)
1620
1621            download_folder = os.path.join(
1622                self.root_codetr_models, dataset_name, config_key
1623            )
1624
1625            logging.info(
1626                f"Downloading model {hf_path} from Hugging Face into {download_folder}"
1627            )
1628            os.makedirs(download_folder, exist_ok=True)
1629
1630            file_path = hf_hub_download(
1631                repo_id=hf_path,
1632                filename="model.pth",
1633                local_dir=download_folder,
1634            )
1635        except Exception as e:
1636            logging.error(f"An error occured during model download: {e}")
1637
1638        model_path = os.path.join(dataset_name, config_key, "model.pth")
1639        logging.info(f"Starting inference for model {model_path}")
1640
1641        inference_result = self._run_container(
1642            volume_data=volume_data,
1643            param_function=param_function,
1644            param_config=param_config,
1645            param_n_gpus=param_n_gpus,
1646            container_tool=container_tool,
1647            param_inference_dataset_folder=folder_inference,
1648            param_inference_model_checkpoint=model_path,
1649        )
1650
1651        # Convert results from JSON output into V51 dataset
1652        # Files follow format inference_results_{timestamp}.json (run_inference.py)
1653        os.makedirs(inference_output_folder, exist_ok=True)
1654        output_files = [
1655            f
1656            for f in os.listdir(inference_output_folder)
1657            if f.startswith("inference_results_") and f.endswith(".json")
1658        ]
1659        logging.debug(f"Found files with inference content: {output_files}")
1660
1661        if not output_files:
1662            raise FileNotFoundError(
1663                f"No inference result files found in {inference_output_folder}"
1664            )
1665
1666        # Get full path for each file
1667        file_paths = [os.path.join(inference_output_folder, f) for f in output_files]
1668
1669        # Extract timestamp from the filename and sort based on the timestamp
1670        file_paths_sorted = sorted(
1671            file_paths,
1672            key=lambda f: datetime.datetime.strptime(
1673                f.split("_")[-2] + "_" + f.split("_")[-1].replace(".json", ""),
1674                "%Y%m%d_%H%M%S",
1675            ),
1676            reverse=True,
1677        )
1678
1679        # Use the most recent file based on timestamp
1680        latest_file = file_paths_sorted[0]
1681        logging.info(f"Using inference results from: {latest_file}")
1682        with open(latest_file, "r") as file:
1683            data = json.load(file)
1684
1685        # Get conversion for annotated classes
1686        annotations_path = os.path.join(
1687            volume_data, "coco", "annotations", "instances_train2017.json"
1688        )
1689
1690        with open(annotations_path, "r") as file:
1691            data_annotations = json.load(file)
1692
1693        class_ids_and_names = [
1694            (category["id"], category["name"])
1695            for category in data_annotations["categories"]
1696        ]
1697
1698        # Match sample filepaths (from exported Co-DETR COCO format) to V51 filepaths
1699        sample = dataset.first()
1700        root_dir_samples = os.path.dirname(sample.filepath)
1701
1702        # Convert results into V51 file format
1703        detection_threshold = inference_settings["detection_threshold"]
1704        pred_key = f"pred_od_{config_key}-{dataset_name}"
1705        for key, value in tqdm(data.items(), desc="Processing Co-DETR detection"):
1706            try:
1707                # Get filename
1708                filepath = CustomCoDETRObjectDetection._find_file_iteratively(
1709                    root_dir_samples, os.path.basename(key)
1710                )
1711                sample = dataset[filepath]
1712
1713                img_width = sample.metadata.width
1714                img_height = sample.metadata.height
1715
1716                detections_v51 = []
1717                for class_id, class_detections in enumerate(value):  # Class IDs start at 0
1718                    if len(class_detections) > 0:
1719                        objects_class = class_ids_and_names[class_id]
1720                        for detection in class_detections:
1721                            confidence = detection[4]
1722                            detection_v51 = fo.Detection(
1723                                label=objects_class[1],
1724                                bounding_box=[
1725                                    detection[0] / img_width,
1726                                    detection[1] / img_height,
1727                                    (detection[2] - detection[0]) / img_width,
1728                                    (detection[3] - detection[1]) / img_height,
1729                                ],
1730                                confidence=confidence,
1731                            )
1732                            if confidence >= detection_threshold:
1733                                detections_v51.append(detection_v51)
1734
1735                sample[pred_key] = fo.Detections(detections=detections_v51)
1736                sample.save()
1737            except Exception as e:
1738                logging.error(
1739                    f"An error occured during the conversion of Co-DETR inference results to the V51 dataset: {e}"
1740                )
1741
1742        # Run V51 evaluation
1743        if inference_settings["do_eval"] is True:
1744            eval_key = pred_key.replace("pred_", "eval_").replace("-", "_")
1745
1746            if inference_settings["inference_on_test"] is True:
1747                dataset_view = dataset.match_tags(["test"])
1748            else:
1749                dataset_view = dataset
1750
1751            logging.info(
1752                f"Starting evaluation for {pred_key} in evaluation key {eval_key}."
1753            )
1754
1755            results = dataset_view.evaluate_detections(
1756                pred_key,
1757                gt_field=gt_field,
1758                eval_key=eval_key,
1759                compute_mAP=True,
1760            )
1761
1762            results.print_report()
1763
1764    def _run_container(
1765        self,
1766        volume_data,
1767        param_function,
1768        param_config="",
1769        param_n_gpus="1",
1770        param_dataset_name="",
1771        param_inference_dataset_folder="",
1772        param_inference_model_checkpoint="",
1773        image="dbogdollresearch/codetr",
1774        workdir="/launch",
1775        container_tool="docker",
1776    ):
1777        """Execute Co-DETR container with specified parameters using Docker or Singularity"""
1778
1779        try:
1780            # Convert relative paths to absolute paths (necessary under WSL2)
1781            root_codetr_abs = os.path.abspath(self.root_codetr)
1782            volume_data_abs = os.path.abspath(volume_data)
1783            root_codetr_models_abs = os.path.abspath(self.root_codetr_models)
1784
1785            # Check if using Docker or Singularity and define the appropriate command
1786            if container_tool == "docker":
1787                command = [
1788                    "docker",
1789                    "run",
1790                    "--gpus",
1791                    "all",
1792                    "--workdir",
1793                    workdir,
1794                    "--volume",
1795                    f"{root_codetr_abs}:{workdir}",
1796                    "--volume",
1797                    f"{volume_data_abs}:{workdir}/data:ro",
1798                    "--volume",
1799                    f"{root_codetr_models_abs}:{workdir}/hf_models:ro",
1800                    "--shm-size=8g",
1801                    image,
1802                    param_function,
1803                    param_config,
1804                    param_n_gpus,
1805                    param_dataset_name,
1806                    param_inference_dataset_folder,
1807                    param_inference_model_checkpoint,
1808                ]
1809            elif container_tool == "singularity":
1810                command = [
1811                    "singularity",
1812                    "run",
1813                    "--nv",
1814                    "--pwd",
1815                    workdir,
1816                    "--bind",
1817                    f"{self.root_codetr}:{workdir}",
1818                    "--bind",
1819                    f"{volume_data}:{workdir}/data:ro",
1820                    "--bind",
1821                    f"{self.root_codetr_models}:{workdir}/hf_models:ro",
1822                    f"docker://{image}",
1823                    param_function,
1824                    param_config,
1825                    param_n_gpus,
1826                    param_dataset_name,
1827                    param_inference_dataset_folder,
1828                    param_inference_model_checkpoint,
1829                ]
1830            else:
1831                raise ValueError(
1832                    f"Invalid container tool specified: {container_tool}. Choose 'docker' or 'singularity'."
1833                )
1834
1835            # Start the process and stream outputs to the console
1836            logging.info(f"Launching terminal command {command}")
1837            with subprocess.Popen(
1838                command, stdout=sys.stdout, stderr=sys.stderr, text=True
1839            ) as proc:
1840                proc.wait()  # Wait for the process to complete
1841            return proc.returncode == 0  # Succeed only if the container exited cleanly
1842        except Exception as e:
1843            logging.error(f"Error during Co-DETR container run: {e}")
1844            return False

Interface for running Co-DETR object detection model training and inference in containers

CustomCoDETRObjectDetection(dataset, dataset_info, run_config)
1379    def __init__(self, dataset, dataset_info, run_config):
1380        """Initialize Co-DETR interface with dataset and configuration"""
1381        self.root_codetr = "./custom_models/CoDETR/Co-DETR"
1382        self.root_codetr_models = "output/models/codetr"
1383        self.dataset = dataset
1384        self.dataset_name = dataset_info["name"]
1385        self.export_dir_root = run_config["export_dataset_root"]
1386        self.config_key = os.path.splitext(os.path.basename(run_config["config"]))[0]
1387        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"

Initialize Co-DETR interface with dataset and configuration

root_codetr
root_codetr_models
dataset
dataset_name
export_dir_root
config_key
hf_repo_name
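
A minimal construction sketch (the dataset_info and run_config keys mirror what __init__ reads above; the concrete values are illustrative assumptions):

    import fiftyone as fo

    dataset = fo.load_dataset("mcity_fisheye")  # hypothetical dataset name
    dataset_info = {"name": "mcity_fisheye"}
    run_config = {
        "export_dataset_root": "output/datasets/",     # assumed export root
        "config": "projects/configs/co_dino_swin.py",  # assumed Co-DETR config file
    }

    detector = CustomCoDETRObjectDetection(dataset, dataset_info, run_config)
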
def convert_data(self):
1389    def convert_data(self):
1390        """Convert dataset to COCO format required by Co-DETR"""
1391
1392        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "coco")
1393
1394        # Check if folder already exists
1395        if not os.path.exists(export_dir):
1396            # Make directory
1397            os.makedirs(export_dir, exist_ok=True)
1398            logging.info(f"Exporting data to {export_dir}")
1399            splits = [
1400                "train",
1401                "val",
1402                "test",
1403            ]  # Co-DETR expects COCO-style 'train2017', 'val2017', and 'test2017' folders
1404            for split in splits:
1405                split_view = self.dataset.match_tags(split)
1406                split_view.export(
1407                    dataset_type=fo.types.COCODetectionDataset,
1408                    data_path=os.path.join(export_dir, f"{split}2017"),
1409                    labels_path=os.path.join(
1410                        export_dir, "annotations", f"instances_{split}2017.json"
1411                    ),
1412                    label_field="ground_truth",
1413                )
1414        else:
1415            logging.warning(
1416                f"Folder {export_dir} already exists, skipping data export."
1417            )

Convert dataset to COCO format required by Co-DETR
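
A usage sketch: convert_data() is idempotent per dataset and writes the COCO layout Co-DETR expects (the tree below is derived from the export calls above):

    detector.convert_data()

    # Resulting layout under <export_dataset_root>/<dataset_name>/coco/:
    #     train2017/, val2017/, test2017/                    # images per split
    #     annotations/instances_{train,val,test}2017.json    # COCO labels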

def update_config_file(self, dataset_name, config_file, max_epochs):
1419    def update_config_file(self, dataset_name, config_file, max_epochs):
1420        """Update Co-DETR config file with dataset-specific parameters"""
1421
1422        config_path = os.path.join(self.root_codetr, config_file)
1423
1424        # Get classes from exported data
1425        annotations_json = os.path.join(
1426            self.export_dir_root,
1427            dataset_name,
1428            "coco/annotations/instances_train2017.json",
1429        )
1430        # Read the JSON file
1431        with open(annotations_json, "r") as file:
1432            data = json.load(file)
1433
1434        # Extract the value associated with the key "categories"
1435        categories = data.get("categories")
1436        class_names = tuple(category["name"] for category in categories)
1437        num_classes = len(class_names)
1438
1439        # Update configuration file
1440        # This assumes that classes = ('a', 'b', ...) is already defined and will be overwritten.
1441        with open(config_path, "r") as file:
1442            content = file.read()
1443
1444        # Update the classes tuple
1445        content = re.sub(r"classes\s*=\s*\(.*?\)", f"classes = {class_names}", content)
1446
1447        # Update all instances of num_classes
1448        content = re.sub(r"num_classes=\d+", f"num_classes={num_classes}", content)
1449
1450        # Update all instances of max_epochs
1451        content = re.sub(r"max_epochs=\d+", f"max_epochs={max_epochs}", content)
1452
1453        with open(config_path, "w") as file:
1454            file.write(content)
1455
1456        logging.warning(
1457            f"Updated {config_path} with classes={class_names}, num_classes={num_classes}, and max_epochs={max_epochs}"
1458        )

Update Co-DETR config file with dataset-specific parameters
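
The update is purely textual: re.sub rewrites assignments that must already exist in the config file. A self-contained sketch of the three substitutions on an illustrative config snippet:

    import re

    content = "classes = ('car', 'truck')\nbbox_head = dict(num_classes=2)\nmax_epochs=12\n"
    content = re.sub(r"classes\s*=\s*\(.*?\)", "classes = ('car', 'bus', 'person')", content)
    content = re.sub(r"num_classes=\d+", "num_classes=3", content)
    content = re.sub(r"max_epochs=\d+", "max_epochs=36", content)
    print(content)
    # classes = ('car', 'bus', 'person')
    # bbox_head = dict(num_classes=3)
    # max_epochs=36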

def train( self, param_config, param_n_gpus, container_tool, param_function='train'):
1460    def train(self, param_config, param_n_gpus, container_tool, param_function="train"):
1461        """Train Co-DETR model using containerized environment"""
1462
1463        # Check if model already exists
1464        output_folder_codetr = os.path.join(self.root_codetr, "output")
1465        os.makedirs(output_folder_codetr, exist_ok=True)
1466        param_config_name = os.path.splitext(os.path.basename(param_config))[0]
1467        best_models_dir = os.path.join(output_folder_codetr, "best")
1468        os.makedirs(best_models_dir, exist_ok=True)
1469        # Best model files follow the naming scheme "config_dataset.pth"
1470        pth_model_files = (
1471            [f for f in os.listdir(best_models_dir) if f.endswith(".pth")]
1472            if os.path.exists(best_models_dir) and os.path.isdir(best_models_dir)
1473            else []
1474        )
1475
1476        # Keep only weights that match this config and dataset
1477        matching_files = [
1478            f
1479            for f in pth_model_files
1480            if f.startswith(param_config_name)
1481            and self.dataset_name in f
1482            and f.endswith(".pth")
1483        ]
1484        if len(matching_files) > 0:
1485            logging.warning(
1486                f"Model {param_config_name} already trained on dataset {self.dataset_name}. Skipping training."
1487            )
1488            if len(matching_files) > 1:
1489                logging.warning(f"Multiple weights found: {matching_files}")
1490        else:
1491            logging.info(
1492                f"Launching training for Co-DETR config {param_config} and dataset {self.dataset_name}."
1493            )
1494            volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1495
1496            # Train model, store checkpoints in 'output_folder_codetr'
1497            train_result = self._run_container(
1498                volume_data=volume_data,
1499                param_function=param_function,
1500                param_config=param_config,
1501                param_n_gpus=param_n_gpus,
1502                container_tool=container_tool,
1503            )
1504
1505            # Find the best_bbox checkpoint file
1506            checkpoint_files = [
1507                f
1508                for f in os.listdir(output_folder_codetr)
1509                if "best_bbox" in f and f.endswith(".pth")
1510            ]
1511            if not checkpoint_files:
1512                logging.error(
1513                    "Co-DETR was not trained, model pth file missing. No checkpoint file with 'best_bbox' found."
1514                )
1515            else:
1516                if len(checkpoint_files) > 1:
1517                    logging.warning(
1518                        f"Found {len(checkpoint_files)} checkpoint files. Selecting {checkpoint_files[0]}."
1519                    )
1520                checkpoint = checkpoint_files[0]
1521                checkpoint_path = os.path.join(output_folder_codetr, checkpoint)
1522                logging.info("Co-DETR was trained successfully.")
1523
1524                # Upload best model to Hugging Face
1525                if HF_DO_UPLOAD:
1526                    logging.info("Uploading Co-DETR model to Hugging Face.")
1527                    api = HfApi()
1528                    api.create_repo(
1529                        self.hf_repo_name,
1530                        private=True,
1531                        repo_type="model",
1532                        exist_ok=True,
1533                    )
1534                    api.upload_file(
1535                        path_or_fileobj=checkpoint_path,
1536                        path_in_repo="model.pth",
1537                        repo_id=self.hf_repo_name,
1538                        repo_type="model",
1539                    )
1540
1541                # Move best model file and clear output folder
1542                self._run_container(
1543                    volume_data=volume_data,
1544                    param_function="clear-output",
1545                    param_config=param_config,
1546                    param_dataset_name=self.dataset_name,
1547                    container_tool=container_tool,
1548                )

Train Co-DETR model using containerized environment
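
A usage sketch, assuming convert_data() has already exported the dataset and Docker is available (config path, epoch count, and GPU count are illustrative):

    detector.update_config_file(
        dataset_name="mcity_fisheye",                    # hypothetical
        config_file="projects/configs/co_dino_swin.py",  # hypothetical
        max_epochs=36,
    )
    detector.train(
        param_config="projects/configs/co_dino_swin.py",
        param_n_gpus="2",        # passed through as a container argument, hence a string
        container_tool="docker",
    )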

def run_inference( self, dataset, param_config, param_n_gpus, container_tool, inference_settings, param_function='inference', inference_output_folder='custom_models/CoDETR/Co-DETR/output/inference/', gt_field='ground_truth'):
1589    def run_inference(
1590        self,
1591        dataset,
1592        param_config,
1593        param_n_gpus,
1594        container_tool,
1595        inference_settings,
1596        param_function="inference",
1597        inference_output_folder="custom_models/CoDETR/Co-DETR/output/inference/",
1598        gt_field="ground_truth",
1599    ):
1600        """Run inference using trained Co-DETR model and convert results to FiftyOne format"""
1601
1602        logging.info(f"Launching inference for Co-DETR config {param_config}.")
1603        volume_data = os.path.join(self.export_dir_root, self.dataset_name)
1604
1605        if inference_settings["inference_on_test"] is True:
1606            folder_inference = os.path.join("coco", "test2017")
1607        else:
1608            folder_inference = "coco"
1609
1610        # Get model from Hugging Face
1611        dataset_name = None
1612        config_key = None
1613        try:
1614            if inference_settings["model_hf"] is None:
1615                hf_path = self.hf_repo_name
1616            else:
1617                hf_path = inference_settings["model_hf"]
1618
1619            dataset_name, config_key = get_dataset_and_model_from_hf_id(hf_path)
1620
1621            download_folder = os.path.join(
1622                self.root_codetr_models, dataset_name, config_key
1623            )
1624
1625            logging.info(
1626                f"Downloading model {hf_path} from Hugging Face into {download_folder}"
1627            )
1628            os.makedirs(download_folder, exist_ok=True)
1629
1630            file_path = hf_hub_download(
1631                repo_id=hf_path,
1632                filename="model.pth",
1633                local_dir=download_folder,
1634            )
1635        except Exception as e:
1636            logging.error(f"An error occured during model download: {e}")
1637
1638        model_path = os.path.join(dataset_name, config_key, "model.pth")
1639        logging.info(f"Starting inference for model {model_path}")
1640
1641        inference_result = self._run_container(
1642            volume_data=volume_data,
1643            param_function=param_function,
1644            param_config=param_config,
1645            param_n_gpus=param_n_gpus,
1646            container_tool=container_tool,
1647            param_inference_dataset_folder=folder_inference,
1648            param_inference_model_checkpoint=model_path,
1649        )
1650
1651        # Convert results from JSON output into V51 dataset
1652        # Files follow format inference_results_{timestamp}.json (run_inference.py)
1653        os.makedirs(inference_output_folder, exist_ok=True)
1654        output_files = [
1655            f
1656            for f in os.listdir(inference_output_folder)
1657            if f.startswith("inference_results_") and f.endswith(".json")
1658        ]
1659        logging.debug(f"Found files with inference content: {output_files}")
1660
1661        if not output_files:
1662            raise FileNotFoundError(
1663                f"No inference result files found in {inference_output_folder}"
1664            )
1665
1666        # Get full path for each file
1667        file_paths = [os.path.join(inference_output_folder, f) for f in output_files]
1668
1669        # Extract timestamp from the filename and sort based on the timestamp
1670        file_paths_sorted = sorted(
1671            file_paths,
1672            key=lambda f: datetime.datetime.strptime(
1673                f.split("_")[-2] + "_" + f.split("_")[-1].replace(".json", ""),
1674                "%Y%m%d_%H%M%S",
1675            ),
1676            reverse=True,
1677        )
1678
1679        # Use the most recent file based on timestamp
1680        latest_file = file_paths_sorted[0]
1681        logging.info(f"Using inference results from: {latest_file}")
1682        with open(latest_file, "r") as file:
1683            data = json.load(file)
1684
1685        # Get conversion for annotated classes
1686        annotations_path = os.path.join(
1687            volume_data, "coco", "annotations", "instances_train2017.json"
1688        )
1689
1690        with open(annotations_path, "r") as file:
1691            data_annotations = json.load(file)
1692
1693        class_ids_and_names = [
1694            (category["id"], category["name"])
1695            for category in data_annotations["categories"]
1696        ]
1697
1698        # Match sample filepaths (from exported Co-DETR COCO format) to V51 filepaths
1699        sample = dataset.first()
1700        root_dir_samples = os.path.dirname(sample.filepath)
1701
1702        # Convert results into V51 file format
1703        detection_threshold = inference_settings["detection_threshold"]
1704        pred_key = f"pred_od_{config_key}-{dataset_name}"
1705        for key, value in tqdm(data.items(), desc="Processing Co-DETR detection"):
1706            try:
1707                # Get filename
1708                filepath = CustomCoDETRObjectDetection._find_file_iteratively(
1709                    root_dir_samples, os.path.basename(key)
1710                )
1711                sample = dataset[filepath]
1712
1713                img_width = sample.metadata.width
1714                img_height = sample.metadata.height
1715
1716                detections_v51 = []
1717                for class_id, class_detections in enumerate(value):  # Class IDs start at 0
1718                    if len(class_detections) > 0:
1719                        objects_class = class_ids_and_names[class_id]
1720                        for detection in class_detections:
1721                            confidence = detection[4]
1722                            detection_v51 = fo.Detection(
1723                                label=objects_class[1],
1724                                bounding_box=[
1725                                    detection[0] / img_width,
1726                                    detection[1] / img_height,
1727                                    (detection[2] - detection[0]) / img_width,
1728                                    (detection[3] - detection[1]) / img_height,
1729                                ],
1730                                confidence=confidence,
1731                            )
1732                            if confidence >= detection_threshold:
1733                                detections_v51.append(detection_v51)
1734
1735                sample[pred_key] = fo.Detections(detections=detections_v51)
1736                sample.save()
1737            except Exception as e:
1738                logging.error(
1739                    f"An error occured during the conversion of Co-DETR inference results to the V51 dataset: {e}"
1740                )
1741
1742        # Run V51 evaluation
1743        if inference_settings["do_eval"] is True:
1744            eval_key = pred_key.replace("pred_", "eval_").replace("-", "_")
1745
1746            if inference_settings["inference_on_test"] is True:
1747                dataset_view = dataset.match_tags(["test"])
1748            else:
1749                dataset_view = dataset
1750
1751            logging.info(
1752                f"Starting evaluation for {pred_key} in evaluation key {eval_key}."
1753            )
1754
1755            results = dataset_view.evaluate_detections(
1756                pred_key,
1757                gt_field=gt_field,
1758                eval_key=eval_key,
1759                compute_mAP=True,
1760            )
1761
1762            results.print_report()

Run inference using trained Co-DETR model and convert results to FiftyOne format
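
A sketch of the inference_settings dictionary this method reads (keys taken from the accesses above; values are illustrative). With model_hf=None the weights are pulled from the repo derived from this training run:

    inference_settings = {
        "inference_on_test": True,   # restrict inference to the 'test' split
        "model_hf": None,            # or "org/dataset_config" for specific HF weights
        "detection_threshold": 0.2,  # drop detections below this confidence
        "do_eval": True,             # run FiftyOne detection evaluation afterwards
    }

    detector.run_inference(
        dataset,
        param_config="projects/configs/co_dino_swin.py",  # hypothetical
        param_n_gpus="1",
        container_tool="docker",
        inference_settings=inference_settings,
    )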

class CustomRFDETRObjectDetection:
1847class CustomRFDETRObjectDetection:
1848    """Interface for running RF-DETR object detection model training and inference"""
1849
1850    def __init__(self, dataset, dataset_info, run_config):
1851        """Initialize RF-DETR interface with dataset and configuration"""
1852        self.dataset = dataset
1853        self.dataset_name = dataset_info["name"]
1854        self.export_dir_root = run_config["export_dataset_root"]
1855        self.config_key = os.path.splitext(os.path.basename(run_config.get("config", "rfdetr")))[0]
1856        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"
1857
1858    def convert_data(self):
1859        """
1860        Convert dataset to RF-DETR COCO format with zero-indexed categories.
1861        Automatically creates missing splits by splitting val or test 50/50.
1862
1863        Expected output structure:
1864        dsname/rfdetr/
1865            test/
1866                _annotations.coco.json
1867                images...
1868            train/
1869                _annotations.coco.json
1870                images...
1871            valid/
1872                _annotations.coco.json
1873                images...
1874        """
1875        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
1876
1877        # Check if folder already exists
1878        if os.path.exists(export_dir):
1879            logging.warning(
1880                f"Folder {export_dir} already exists, skipping data export."
1881            )
1882            return
1883
1884        # Make directory
1885        os.makedirs(export_dir, exist_ok=True)
1886        logging.info(f"Exporting data to {export_dir}")
1887
1888        # Check what splits exist
1889        available_tags = self.dataset.distinct("tags")
1890        has_train = "train" in available_tags
1891        has_val = "val" in available_tags
1892        has_test = "test" in available_tags
1893
1894        logging.info(f"Available splits in dataset: {available_tags}")
1895
1896        # Handle missing splits - split val or test 50/50
1897        if has_train and has_val and not has_test:
1898            logging.info("No test split found. Splitting val 50/50 into valid and test...")
1899            self._split_50_50("val", "test")
1900        elif has_train and has_test and not has_val:
1901            logging.info("No val split found. Splitting test 50/50 into valid and test...")
1902            self._split_50_50("test", "val")
1903        elif not has_train or (not has_val and not has_test):
1904            logging.error(
1905                f"Dataset must have 'train' and at least one of 'val' or 'test'. "
1906                f"Found: {available_tags}"
1907            )
1908            raise ValueError("Insufficient splits in dataset")
1909
1910        # RF-DETR expects 'train', 'valid', 'test' splits
1911        split_mapping = {
1912            "train": "train",
1913            "val": "valid",  # Map 'val' to 'valid' for RF-DETR
1914            "test": "test"
1915        }
1916
1917        for v51_split, rfdetr_split in split_mapping.items():
1918            split_view = self.dataset.match_tags(v51_split)
1919
1920            if len(split_view) == 0:
1921                logging.warning(f"No samples found for split '{v51_split}', skipping.")
1922                continue
1923
1924            split_export_dir = os.path.join(export_dir, rfdetr_split)
1925            os.makedirs(split_export_dir, exist_ok=True)
1926
1927            # Export to COCO format
1928            annotation_path = os.path.join(split_export_dir, "_annotations.coco.json")
1929
1930            logging.info(f"Exporting {len(split_view)} samples to {rfdetr_split}/")
1931
1932            split_view.export(
1933                dataset_type=fo.types.COCODetectionDataset,
1934                data_path=split_export_dir,
1935                labels_path=annotation_path,
1936                label_field="ground_truth",
1937            )
1938
1939            # Fix category IDs: Convert from 1-indexed to 0-indexed
1940            self._fix_annotation_indices(annotation_path)
1941
1942        logging.info(f"Successfully exported dataset to RF-DETR format at {export_dir}")
1943
1944    def _split_50_50(self, source_split, target_split):
1945        """
1946        Split a dataset split 50/50 into two splits.
1947
1948        Args:
1949            source_split: The split to divide (e.g., "val" or "test")
1950            target_split: The new split to create (e.g., "test" or "val")
1951
1952        Example:
1953            - 1000 val samples → 500 val + 500 test
1954            - 1000 test samples → 500 val + 500 test
1955        """
1956        source_samples = self.dataset.match_tags(source_split)
1957        source_ids = source_samples.values("id")
1958
1959        if len(source_ids) < 2:
1960            logging.error(
1961                f"Not enough samples in '{source_split}' to split. "
1962                f"Need at least 2, found {len(source_ids)}"
1963            )
1964            raise ValueError(f"Insufficient samples in {source_split} split")
1965
1966        # Shuffle for random split
1967        random.seed(GLOBAL_SEED)  # Seed for a reproducible 50/50 split
1968        random.shuffle(source_ids)
1969
1970        # Split 50/50
1971        split_point = len(source_ids) // 2
1972        keep_in_source = source_ids[:split_point]
1973        move_to_target = source_ids[split_point:]
1974
1975        logging.info(
1976            f"Splitting {len(source_ids)} '{source_split}' samples: "
1977            f"{len(keep_in_source)} remain in '{source_split}', "
1978            f"{len(move_to_target)} moved to '{target_split}'"
1979        )
1980
1981        # Move samples to target split
1982        for sample_id in move_to_target:
1983            sample = self.dataset[sample_id]
1984            sample.tags.remove(source_split)
1985            sample.tags.append(target_split)
1986            sample.save()
1987
1988        self.dataset.save()
1989        logging.info(f"Successfully created '{target_split}' split from '{source_split}'")
1990
1991    def _fix_annotation_indices(self, annotation_path):
1992        """
1993        Fix COCO annotation file to use zero-indexed category IDs.
1994
1995        Args:
1996            annotation_path: Path to the _annotations.coco.json file
1997        """
1998        if not os.path.exists(annotation_path):
1999            logging.error(f"Annotation file not found: {annotation_path}")
2000            return
2001
2002        try:
2003            # Create backup
2004            backup_path = f"{annotation_path}.backup"
2005            if not os.path.exists(backup_path):
2006                shutil.copy2(annotation_path, backup_path)
2007                logging.debug(f"Created backup: {backup_path}")
2008
2009            # Read annotation file
2010            with open(annotation_path, 'r') as f:
2011                data = json.load(f)
2012
2013            # Fix categories: 1-indexed → 0-indexed
2014            if 'categories' in data:
2015                for cat in data['categories']:
2016                    if cat['id'] > 0:
2017                        cat['id'] -= 1
2018                logging.debug(f"Fixed {len(data['categories'])} category IDs")
2019
2020            # Fix annotations: 1-indexed → 0-indexed
2021            if 'annotations' in data:
2022                for ann in data['annotations']:
2023                    if ann['category_id'] > 0:
2024                        ann['category_id'] -= 1
2025                logging.debug(f"Fixed {len(data['annotations'])} annotation category IDs")
2026
2027            # Save fixed file
2028            with open(annotation_path, 'w') as f:
2029                json.dump(data, f, indent=2)
2030
2031            logging.info(f"Successfully fixed indices in: {annotation_path}")
2032
2033        except Exception as e:
2034            logging.error(f"Error fixing annotation indices in {annotation_path}: {e}")
2035            # Restore from backup if something went wrong
2036            backup_path = f"{annotation_path}.backup"
2037            if os.path.exists(backup_path):
2038                shutil.copy2(backup_path, annotation_path)
2039                logging.info(f"Restored from backup due to error")
2040
2041    def train(self, run_config, shared_config):
2042        """Train RF-DETR model using shared and model-specific configuration"""
2043
2044
2045
2046
2047        # Model selection mapping
2048        MODEL_REGISTRY = {
2049            "rfdetr_nano": RFDETRNano,
2050            "rfdetr_small": RFDETRSmall,
2051            "rfdetr_medium": RFDETRMedium,
2052            "rfdetr_large": RFDETRLarge,
2053        }
2054
2055        model_name = self.config_key.lower()
2056
2057        if model_name not in MODEL_REGISTRY:
2058            logging.error(
2059                f"Model '{model_name}' not supported. "
2060                f"Available models: {list(MODEL_REGISTRY.keys())}"
2061            )
2062            raise ValueError(f"Unsupported RF-DETR model: {model_name}")
2063
2064        # Initialize model
2065        logging.info(f"Initializing {model_name}...")
2066        ModelClass = MODEL_REGISTRY[model_name]
2067        model = ModelClass()
2068
2069        # Prepare dataset directory
2070        dataset_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
2071
2072        if not os.path.exists(dataset_dir):
2073            logging.error(f"Dataset directory not found: {dataset_dir}")
2074            logging.info("Please run convert_data() first to prepare the dataset.")
2075            raise FileNotFoundError(f"Dataset not found at {dataset_dir}")
2076
2077        # Output directory
2078        output_dir = os.path.join("output/models/rfdetr", self.dataset_name, model_name)
2079        os.makedirs(output_dir, exist_ok=True)
2080
2081        # Build training arguments
2082        train_kwargs = {
2083            "dataset_dir": dataset_dir,
2084            "output_dir": output_dir,
2085
2086            # === SHARED parameters from top-level config ===
2087            "epochs": shared_config.get("epochs", 50),
2088            "lr": shared_config.get("learning_rate", 1e-4),
2089            "weight_decay": shared_config.get("weight_decay", 0.0001),
2090
2091            # === RF-DETR specific parameters ===
2092            "batch_size": run_config.get("batch_size", 16),
2093            "grad_accum_steps": run_config.get("grad_accum_steps", 1),
2094            "lr_encoder": run_config.get("lr_encoder", None),
2095            "resolution": run_config.get("resolution", None),
2096            "use_ema": run_config.get("use_ema", True),
2097            "gradient_checkpointing": run_config.get("gradient_checkpointing", False),
2098
2099            # === Logging (use global settings) ===
2100            "tensorboard": True,
2101            "wandb": WANDB_ACTIVE,
2102            "project": f"MCityDataEngine-RFDETR",
2103            "run": f"{self.dataset_name}_{model_name}",
2104
2105            # === Early stopping ===
2106            "early_stopping": True,
2107            "early_stopping_patience": shared_config.get("early_stop_patience", 10),
2108            "early_stopping_min_delta": run_config.get(
2109                "early_stopping_min_delta",
2110                shared_config.get("early_stop_threshold", 0.001)
2111            ),
2112            "early_stopping_use_ema": run_config.get("early_stopping_use_ema", True),
2113        }
2114
2115        # Set device - use all available GPUs
2116        if torch.cuda.is_available():
2117            if torch.cuda.device_count() > 1:
2118                device = "cuda"
2119                logging.info(f"Using {torch.cuda.device_count()} GPUs for training")
2120            else:
2121                device = "cuda:0"
2122                logging.info("Using single GPU for training")
2123        else:
2124            device = "cpu"
2125            logging.warning("No GPU available, training on CPU")
2126
2127        train_kwargs["device"] = device
2128
2129        # Remove None values
2130        train_kwargs = {k: v for k, v in train_kwargs.items() if v is not None}
2131
2132        # Log configuration
2133        logging.info("="*70)
2134        logging.info("RF-DETR TRAINING CONFIGURATION")
2135        logging.info("="*70)
2136        logging.info(f"Model: {model_name}")
2137        logging.info(f"Dataset: {dataset_dir}")
2138        logging.info(f"Output: {output_dir}")
2139        logging.info(f"WandB Active: {WANDB_ACTIVE}")
2140        for key, value in train_kwargs.items():
2141            if key not in ["dataset_dir", "output_dir"]:
2142                logging.info(f"  {key}: {value}")
2143        logging.info("="*70)
2144
2145        # Train
2146        try:
2147            logging.info("Starting RF-DETR training...")
2148            model.train(**train_kwargs)
2149            logging.info("RF-DETR training completed successfully!")
2150
2151            # Model paths to check (RF-DETR can save in different locations)
2152            possible_model_paths = [
2153                os.path.join(output_dir, "checkpoints", "best.pt"),
2154                os.path.join(output_dir, "checkpoint_best_total.pth"),
2155                os.path.join(output_dir, "best.pt"),
2156            ]
2157
2158            self.model_path = None
2159            for path in possible_model_paths:
2160                if os.path.exists(path):
2161                    self.model_path = path
2162                    logging.info(f"Found trained model at: {path}")
2163                    break
2164
2165            if self.model_path is None:
2166                logging.warning("Could not find trained model file in expected locations")
2167                self.model_path = possible_model_paths[0]  # Default to first path
2168
2169            # Upload to Hugging Face if configured
2170            if HF_DO_UPLOAD:
2171                self._upload_to_hf()
2172
2173            return True
2174
2175        except Exception as e:
2176            logging.error(f"❌ Training failed: {e}")
2177            import traceback
2178            traceback.print_exc()
2179            return False
2180
2181
2182    def _upload_to_hf(self):
2183        """Upload trained RF-DETR model to Hugging Face"""
2184
2185        if not os.path.exists(self.model_path):
2186            logging.warning(f"Model file not found at {self.model_path}, skipping upload.")
2187            return
2188
2189        try:
2190            logging.info(f"Uploading RF-DETR model to Hugging Face: {self.hf_repo_name}")
2191            api = HfApi()
2192
2193            # Create repository
2194            api.create_repo(
2195                self.hf_repo_name,
2196                private=True,
2197                repo_type="model",
2198                exist_ok=True
2199            )
2200
2201            # Upload model file
2202            api.upload_file(
2203                path_or_fileobj=self.model_path,
2204                path_in_repo="best.pt",
2205                repo_id=self.hf_repo_name,
2206                repo_type="model",
2207            )
2208
2209            logging.info(f"Model uploaded successfully to {self.hf_repo_name}")
2210
2211        except Exception as e:
2212            logging.error(f"Failed to upload model to Hugging Face: {e}")
2213            import traceback
2214            traceback.print_exc()
2215
2216
2217    def inference(self, inference_settings, gt_field="ground_truth"):
2218        """Performs inference using RF-DETR model on a dataset with optional evaluation"""
2219
2220
2221
2222        logging.info(f"Running RF-DETR inference on dataset {self.dataset_name}")
2223
2224        # Model selection mapping
2225        MODEL_REGISTRY = {
2226            "rfdetr_nano": RFDETRNano,
2227            "rfdetr_small": RFDETRSmall,
2228            "rfdetr_medium": RFDETRMedium,
2229            "rfdetr_large": RFDETRLarge,
2230        }
2231
2232        # Determine model and dataset names
2233        dataset_name = None
2234        model_name = self.config_key.lower()
2235
2236        model_hf = inference_settings.get("model_hf", None)
2237
2238        # Determine model path
2239        if model_hf is not None:
2240            # Use model from Hugging Face
2241            logging.info(f"Using model from Hugging Face: {model_hf}")
2242            dataset_name, model_name = get_dataset_and_model_from_hf_id(model_hf)
2243
2244            # Set up directories
2245            download_dir = os.path.join(
2246                "output/models/rfdetr", dataset_name, model_name
2247            )
2248            os.makedirs(download_dir, exist_ok=True)
2249
2250            # Download model from Hugging Face
2251            try:
2252                logging.info(f"Downloading model from Hugging Face: {model_hf}")
2253                model_path = hf_hub_download(
2254                    repo_id=model_hf,
2255                    filename="best.pt",
2256                    local_dir=download_dir,
2257                )
2258            except Exception as e:
2259                logging.error(f"Failed to download model from Hugging Face: {e}")
2260                return False
2261        else:
2262            # Use locally trained model
2263            dataset_name = self.dataset_name
2264
2265            # Check multiple possible locations
2266            possible_paths = [
2267                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoints", "best.pt"),
2268                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoint_best_total.pth"),
2269                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "best.pt"),
2270            ]
2271
2272            model_path = None
2273            for path in possible_paths:
2274                if os.path.exists(path):
2275                    model_path = path
2276                    logging.info(f"Found model at: {path}")
2277                    break
2278
2279            if model_path is None:
2280                # Try downloading from auto-generated HF repo
2281                logging.info(f"Local model not found. Attempting to download from {self.hf_repo_name}")
2282                download_dir = os.path.join(
2283                    "output/models/rfdetr", self.dataset_name, model_name
2284                )
2285                os.makedirs(download_dir, exist_ok=True)
2286
2287                try:
2288                    model_path = hf_hub_download(
2289                        repo_id=self.hf_repo_name,
2290                        filename="best.pt",
2291                        local_dir=download_dir,
2292                    )
2293                except Exception as e:
2294                    logging.error(f"Failed to load or download model: {e}")
2295                    return False
2296
2297        # Check if model exists
2298        if not os.path.exists(model_path):
2299            logging.error(f"Model file not found: {model_path}")
2300            return False
2301
2302        logging.info(f"Using model: {model_path}")
2303
2304        # Initialize model
2305        if model_name not in MODEL_REGISTRY:
2306            logging.error(f"Model '{model_name}' not supported.")
2307            return False
2308
2309        ModelClass = MODEL_REGISTRY[model_name]
2310
2311        # Get class names from dataset
2312        try:
2313            class_names = self.dataset.distinct(f"{gt_field}.detections.label")
2314            class_names = sorted(class_names)
2315            num_classes = len(class_names)
2316            logging.info(f"Found {num_classes} classes: {class_names}")
2317        except Exception as e:
2318            logging.warning(f"Could not extract class names from dataset: {e}")
2319            num_classes = 8  # Default fallback
2320            class_names = None
2321
2322        # Load model with trained weights
2323        try:
2324            logging.info("Loading RF-DETR model...")
2325            model = ModelClass(
2326                pretrain_weights=model_path,
2327                num_classes=num_classes
2328            )
2329
2330            logging.info("RF-DETR model loaded successfully")
2331        except Exception as e:
2332            logging.error(f"Failed to load model: {e}")
2333            return False
2334
2335        # Prepare dataset view
2336        detection_threshold = inference_settings.get("detection_threshold", 0.2)
2337
2338        if inference_settings.get("inference_on_test", True):
2339            INFERENCE_SPLITS = ["test"]
2340            dataset_view = self.dataset.match_tags(INFERENCE_SPLITS)
2341            if len(dataset_view) == 0:
2342                logging.error(f"Dataset has no splits: {INFERENCE_SPLITS}")
2343                return False
2344        else:
2345            dataset_view = self.dataset
2346
2347        # Prediction key
2348        pred_key = f"pred_od_{model_name}-{dataset_name}"
2349
2350        logging.info(f"Running inference on {len(dataset_view)} samples...")
2351        logging.info(f"Detection threshold: {detection_threshold}")
2352
2353        # Run inference on each sample
2354        try:
2355            processed_count = 0
2356
2357            for sample in tqdm(dataset_view.iter_samples(autosave=True),
2358                            total=len(dataset_view),
2359                            desc="RF-DETR Inference"):
2360
2361                try:
2362                    # Load image
2363                    image = Image.open(sample.filepath)
2364                    img_width, img_height = image.size
2365
2366                    # Run inference using RF-DETR's predict method
2367                    detections = model.predict(
2368                        image,
2369                        threshold=detection_threshold
2370                    )
2371
2372                    # Convert supervision detections to FiftyOne format
2373                    fo_detections = []
2374
2375                    if len(detections) > 0:
2376                        for i in range(len(detections)):
2377                            # Get detection data (RF-DETR returns supervision format)
2378                            bbox = detections.xyxy[i]  # [x1, y1, x2, y2] in pixel coordinates
2379                            confidence = detections.confidence[i] if detections.confidence is not None else 1.0
2380                            class_id = detections.class_id[i] if detections.class_id is not None else 0
2381
2382                            # Convert to relative coordinates [x, y, width, height]
2383                            x1, y1, x2, y2 = bbox
2384                            rel_x = x1 / img_width
2385                            rel_y = y1 / img_height
2386                            rel_w = (x2 - x1) / img_width
2387                            rel_h = (y2 - y1) / img_height
2388
2389                            # Get class name
2390                            if class_names and class_id < len(class_names):
2391                                class_name = class_names[class_id]
2392                            else:
2393                                class_name = f"class_{class_id}"
2394
2395                            # Create FiftyOne detection
2396                            fo_detection = fo.Detection(
2397                                label=class_name,
2398                                bounding_box=[rel_x, rel_y, rel_w, rel_h],
2399                                confidence=float(confidence)
2400                            )
2401                            fo_detections.append(fo_detection)
2402
2403                    # Save detections to sample
2404                    sample[pred_key] = fo.Detections(detections=fo_detections)
2405                    processed_count += 1
2406
2407                except Exception as e:
2408                    logging.error(f"Error processing sample {sample.id}: {e}")
2409                    continue
2410
2411            logging.info(f"Inference completed on {processed_count}/{len(dataset_view)} samples")
2412            logging.info(f"Predictions saved to field '{pred_key}'")
2413
2414        except Exception as e:
2415            logging.error(f"Error during inference: {e}")
2416            import traceback
2417            traceback.print_exc()
2418            return False
2419
2420        # Evaluate if requested
2421        if inference_settings.get("do_eval", True):
2422            eval_key = f"eval_{model_name}_{dataset_name}".replace("-", "_")
2423
2424            if inference_settings.get("inference_on_test", True):
2425                dataset_view = self.dataset.match_tags(["test"])
2426            else:
2427                dataset_view = self.dataset
2428
2429            # Filter samples that have both predictions and ground truth
2430            dataset_view = dataset_view.exists(pred_key).exists(gt_field)
2431
2432            if len(dataset_view) == 0:
2433                logging.warning("No samples found with both predictions and ground truth for evaluation")
2434            else:
2435                try:
2436                    logging.info(f"Evaluating predictions on {len(dataset_view)} samples...")
2437
2438                    results = dataset_view.evaluate_detections(
2439                        pred_key,
2440                        gt_field=gt_field,
2441                        eval_key=eval_key,
2442                        compute_mAP=True,
2443                        iou=0.5  # IoU threshold for matching
2444                    )
2445
2446                    # Print evaluation report
2447                    logging.info("="*70)
2448                    logging.info("EVALUATION RESULTS")
2449                    logging.info("="*70)
2450                    results.print_report()
2451                    logging.info("="*70)
2452
2453                    logging.info("Evaluation completed")
2454                except Exception as e:
2455                    logging.error(f"Evaluation failed: {e}")
2456                    import traceback
2457                    traceback.print_exc()
2458
2459        return True

Interface for running RF-DETR object detection model training and inference

CustomRFDETRObjectDetection(dataset, dataset_info, run_config)
1850    def __init__(self, dataset, dataset_info, run_config):
1851        """Initialize RF-DETR interface with dataset and configuration"""
1852        self.dataset = dataset
1853        self.dataset_name = dataset_info["name"]
1854        self.export_dir_root = run_config["export_dataset_root"]
1855        self.config_key = os.path.splitext(os.path.basename(run_config.get("config", "rfdetr")))[0]
1856        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_{self.config_key}"

Initialize RF-DETR interface with dataset and configuration

dataset
dataset_name
export_dir_root
config_key
hf_repo_name
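To make the derived attributes concrete, a small sketch (the config path, organization, and dataset name are hypothetical):

    import os

    run_config = {"config": "config/rfdetr_nano.yaml"}  # hypothetical config path
    config_key = os.path.splitext(os.path.basename(run_config.get("config", "rfdetr")))[0]
    # config_key == "rfdetr_nano", the stem of the config file name

    HF_ROOT = "my-org"          # hypothetical organization
    dataset_name = "fisheye8k"  # hypothetical dataset name
    hf_repo_name = f"{HF_ROOT}/{dataset_name}_{config_key}"
    # hf_repo_name == "my-org/fisheye8k_rfdetr_nano"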
def convert_data(self):
1858    def convert_data(self):
1859        """
1860        Convert dataset to RF-DETR COCO format with zero-indexed categories.
1861        Automatically creates missing splits by splitting val or test 50/50.
1862
1863        Expected output structure:
1864        dsname/rfdetr/
1865            test/
1866                _annotations.coco.json
1867                images...
1868            train/
1869                _annotations.coco.json
1870                images...
1871            valid/
1872                _annotations.coco.json
1873                images...
1874        """
1875        export_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
1876
1877        # Check if folder already exists
1878        if os.path.exists(export_dir):
1879            logging.warning(
1880                f"Folder {export_dir} already exists, skipping data export."
1881            )
1882            return
1883
1884        # Make directory
1885        os.makedirs(export_dir, exist_ok=True)
1886        logging.info(f"Exporting data to {export_dir}")
1887
1888        # Check what splits exist
1889        available_tags = self.dataset.distinct("tags")
1890        has_train = "train" in available_tags
1891        has_val = "val" in available_tags
1892        has_test = "test" in available_tags
1893
1894        logging.info(f"Available splits in dataset: {available_tags}")
1895
1896        # Handle missing splits - split val or test 50/50
1897        if has_train and has_val and not has_test:
1898            logging.info("No test split found. Splitting 'val' 50/50 into 'val' and 'test'...")
1899            self._split_50_50("val", "test")
1900        elif has_train and has_test and not has_val:
1901            logging.info("No val split found. Splitting 'test' 50/50 into 'val' and 'test'...")
1902            self._split_50_50("test", "val")
1903        elif not has_train or (not has_val and not has_test):
1904            logging.error(
1905                f"Dataset must have 'train' and at least one of 'val' or 'test'. "
1906                f"Found: {available_tags}"
1907            )
1908            raise ValueError("Insufficient splits in dataset")
1909
1910        # RF-DETR expects 'train', 'valid', 'test' splits
1911        split_mapping = {
1912            "train": "train",
1913            "val": "valid",  # Map 'val' to 'valid' for RF-DETR
1914            "test": "test"
1915        }
1916
1917        for v51_split, rfdetr_split in split_mapping.items():
1918            split_view = self.dataset.match_tags(v51_split)
1919
1920            if len(split_view) == 0:
1921                logging.warning(f"No samples found for split '{v51_split}', skipping.")
1922                continue
1923
1924            split_export_dir = os.path.join(export_dir, rfdetr_split)
1925            os.makedirs(split_export_dir, exist_ok=True)
1926
1927            # Export to COCO format
1928            annotation_path = os.path.join(split_export_dir, "_annotations.coco.json")
1929
1930            logging.info(f"Exporting {len(split_view)} samples to {rfdetr_split}/")
1931
1932            split_view.export(
1933                dataset_type=fo.types.COCODetectionDataset,
1934                data_path=split_export_dir,
1935                labels_path=annotation_path,
1936                label_field="ground_truth",
1937            )
1938
1939            # Fix category IDs: Convert from 1-indexed to 0-indexed
1940            self._fix_annotation_indices(annotation_path)
1941
1942        logging.info(f"Successfully exported dataset to RF-DETR format at {export_dir}")

Convert dataset to RF-DETR COCO format with zero-indexed categories. Automatically creates missing splits by splitting val or test 50/50.

Expected output structure:

    dsname/rfdetr/
        test/
            _annotations.coco.json
            images...
        train/
            _annotations.coco.json
            images...
        valid/
            _annotations.coco.json
            images...
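The _fix_annotation_indices helper called after each export is not shown in this listing; a minimal sketch of what such a zero-indexing pass could look like, assuming the standard COCO JSON layout:

    import json

    def _fix_annotation_indices(annotation_path):
        # Hypothetical sketch: shift COCO category IDs from 1-indexed to 0-indexed,
        # both in the category definitions and in every annotation that references them.
        with open(annotation_path, "r") as f:
            coco = json.load(f)

        for category in coco.get("categories", []):
            category["id"] -= 1
        for annotation in coco.get("annotations", []):
            annotation["category_id"] -= 1

        with open(annotation_path, "w") as f:
            json.dump(coco, f)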

def train(self, run_config, shared_config):
2041    def train(self, run_config, shared_config):
2042        """Train RF-DETR model using shared and model-specific configuration"""
2043
2044
2045
2046
2047        # Model selection mapping
2048        MODEL_REGISTRY = {
2049            "rfdetr_nano": RFDETRNano,
2050            "rfdetr_small": RFDETRSmall,
2051            "rfdetr_medium": RFDETRMedium,
2052            "rfdetr_large": RFDETRLarge,
2053        }
2054
2055        model_name = self.config_key.lower()
2056
2057        if model_name not in MODEL_REGISTRY:
2058            logging.error(
2059                f"Model '{model_name}' not supported. "
2060                f"Available models: {list(MODEL_REGISTRY.keys())}"
2061            )
2062            raise ValueError(f"Unsupported RF-DETR model: {model_name}")
2063
2064        # Initialize model
2065        logging.info(f"Initializing {model_name}...")
2066        ModelClass = MODEL_REGISTRY[model_name]
2067        model = ModelClass()
2068
2069        # Prepare dataset directory
2070        dataset_dir = os.path.join(self.export_dir_root, self.dataset_name, "rfdetr")
2071
2072        if not os.path.exists(dataset_dir):
2073            logging.error(f"Dataset directory not found: {dataset_dir}")
2074            logging.info("Please run convert_data() first to prepare the dataset.")
2075            raise FileNotFoundError(f"Dataset not found at {dataset_dir}")
2076
2077        # Output directory
2078        output_dir = os.path.join("output/models/rfdetr", self.dataset_name, model_name)
2079        os.makedirs(output_dir, exist_ok=True)
2080
2081        # Build training arguments
2082        train_kwargs = {
2083            "dataset_dir": dataset_dir,
2084            "output_dir": output_dir,
2085
2086            # === SHARED parameters from top-level config ===
2087            "epochs": shared_config.get("epochs", 50),
2088            "lr": shared_config.get("learning_rate", 1e-4),
2089            "weight_decay": shared_config.get("weight_decay", 0.0001),
2090
2091            # === RF-DETR specific parameters ===
2092            "batch_size": run_config.get("batch_size", 16),
2093            "grad_accum_steps": run_config.get("grad_accum_steps", 1),
2094            "lr_encoder": run_config.get("lr_encoder", None),
2095            "resolution": run_config.get("resolution", None),
2096            "use_ema": run_config.get("use_ema", True),
2097            "gradient_checkpointing": run_config.get("gradient_checkpointing", False),
2098
2099            # === Logging (use global settings) ===
2100            "tensorboard": True,
2101            "wandb": WANDB_ACTIVE,
2102            "project": "MCityDataEngine-RFDETR",
2103            "run": f"{self.dataset_name}_{model_name}",
2104
2105            # === Early stopping ===
2106            "early_stopping": True,
2107            "early_stopping_patience": shared_config.get("early_stop_patience", 10),
2108            "early_stopping_min_delta": run_config.get(
2109                "early_stopping_min_delta",
2110                shared_config.get("early_stop_threshold", 0.001)
2111            ),
2112            "early_stopping_use_ema": run_config.get("early_stopping_use_ema", True),
2113        }
2114
2115        # Set device - use all available GPUs
2116        if torch.cuda.is_available():
2117            if torch.cuda.device_count() > 1:
2118                device = "cuda"
2119                logging.info(f"Using {torch.cuda.device_count()} GPUs for training")
2120            else:
2121                device = "cuda:0"
2122                logging.info("Using single GPU for training")
2123        else:
2124            device = "cpu"
2125            logging.warning("No GPU available, training on CPU")
2126
2127        train_kwargs["device"] = device
2128
2129        # Remove None values
2130        train_kwargs = {k: v for k, v in train_kwargs.items() if v is not None}
2131
2132        # Log configuration
2133        logging.info("="*70)
2134        logging.info("RF-DETR TRAINING CONFIGURATION")
2135        logging.info("="*70)
2136        logging.info(f"Model: {model_name}")
2137        logging.info(f"Dataset: {dataset_dir}")
2138        logging.info(f"Output: {output_dir}")
2139        logging.info(f"WandB Active: {WANDB_ACTIVE}")
2140        for key, value in train_kwargs.items():
2141            if key not in ["dataset_dir", "output_dir"]:
2142                logging.info(f"  {key}: {value}")
2143        logging.info("="*70)
2144
2145        # Train
2146        try:
2147            logging.info("Starting RF-DETR training...")
2148            model.train(**train_kwargs)
2149            logging.info("RF-DETR training completed successfully!")
2150
2151            # Model paths to check (RF-DETR can save in different locations)
2152            possible_model_paths = [
2153                os.path.join(output_dir, "checkpoints", "best.pt"),
2154                os.path.join(output_dir, "checkpoint_best_total.pth"),
2155                os.path.join(output_dir, "best.pt"),
2156            ]
2157
2158            self.model_path = None
2159            for path in possible_model_paths:
2160                if os.path.exists(path):
2161                    self.model_path = path
2162                    logging.info(f"Found trained model at: {path}")
2163                    break
2164
2165            if self.model_path is None:
2166                logging.warning(f"Could not find trained model file in any of: {possible_model_paths}")
2167                self.model_path = possible_model_paths[0]  # Fall back to the first expected path; it may not exist yet
2168
2169            # Upload to Hugging Face if configured
2170            if HF_DO_UPLOAD:
2171                self._upload_to_hf()
2172
2173            return True
2174
2175        except Exception as e:
2176            logging.error(f"Training failed: {e}")
2177            import traceback
2178            traceback.print_exc()
2179            return False

Train RF-DETR model using shared and model-specific configuration
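As a usage sketch, a hypothetical pair of configs wired through the keys read above (values are illustrative; dataset is assumed to be an already-loaded FiftyOne dataset):

    shared_config = {
        "epochs": 50,
        "learning_rate": 1e-4,
        "weight_decay": 0.0001,
        "early_stop_patience": 10,
        "early_stop_threshold": 0.001,
    }
    run_config = {
        "config": "config/rfdetr_small.yaml",       # hypothetical; the stem selects the model
        "export_dataset_root": "output/datasets",   # hypothetical export root
        "batch_size": 16,
        "grad_accum_steps": 2,
        "resolution": 560,  # illustrative; RF-DETR expects resolutions divisible by 56
        "use_ema": True,
    }
    dataset_info = {"name": "fisheye8k"}  # hypothetical dataset name

    trainer = CustomRFDETRObjectDetection(dataset, dataset_info, run_config)
    trainer.convert_data()
    trainer.train(run_config, shared_config)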

def inference(self, inference_settings, gt_field='ground_truth'):
2217    def inference(self, inference_settings, gt_field="ground_truth"):
2218        """Performs inference using RF-DETR model on a dataset with optional evaluation"""
2219
2220
2221
2222        logging.info(f"Running RF-DETR inference on dataset {self.dataset_name}")
2223
2224        # Model selection mapping
2225        MODEL_REGISTRY = {
2226            "rfdetr_nano": RFDETRNano,
2227            "rfdetr_small": RFDETRSmall,
2228            "rfdetr_medium": RFDETRMedium,
2229            "rfdetr_large": RFDETRLarge,
2230        }
2231
2232        # Determine model and dataset names
2233        dataset_name = None
2234        model_name = self.config_key.lower()
2235
2236        model_hf = inference_settings.get("model_hf", None)
2237
2238        # Determine model path
2239        if model_hf is not None:
2240            # Use model from Hugging Face
2241            logging.info(f"Using model from Hugging Face: {model_hf}")
2242            dataset_name, model_name = get_dataset_and_model_from_hf_id(model_hf)
2243
2244            # Set up directories
2245            download_dir = os.path.join(
2246                "output/models/rfdetr", dataset_name, model_name
2247            )
2248            os.makedirs(download_dir, exist_ok=True)
2249
2250            # Download model from Hugging Face
2251            try:
2252                logging.info(f"Downloading model from Hugging Face: {model_hf}")
2253                model_path = hf_hub_download(
2254                    repo_id=model_hf,
2255                    filename="best.pt",
2256                    local_dir=download_dir,
2257                )
2258            except Exception as e:
2259                logging.error(f"Failed to download model from Hugging Face: {e}")
2260                return False
2261        else:
2262            # Use locally trained model
2263            dataset_name = self.dataset_name
2264
2265            # Check multiple possible locations
2266            possible_paths = [
2267                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoints", "best.pt"),
2268                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "checkpoint_best_total.pth"),
2269                os.path.join("output/models/rfdetr", self.dataset_name, model_name, "best.pt"),
2270            ]
2271
2272            model_path = None
2273            for path in possible_paths:
2274                if os.path.exists(path):
2275                    model_path = path
2276                    logging.info(f"Found model at: {path}")
2277                    break
2278
2279            if model_path is None:
2280                # Try downloading from auto-generated HF repo
2281                logging.info(f"Local model not found. Attempting to download from {self.hf_repo_name}")
2282                download_dir = os.path.join(
2283                    "output/models/rfdetr", self.dataset_name, model_name
2284                )
2285                os.makedirs(download_dir, exist_ok=True)
2286
2287                try:
2288                    model_path = hf_hub_download(
2289                        repo_id=self.hf_repo_name,
2290                        filename="best.pt",
2291                        local_dir=download_dir,
2292                    )
2293                except Exception as e:
2294                    logging.error(f"Failed to load or download model: {e}")
2295                    return False
2296
2297        # Check if model exists
2298        if not os.path.exists(model_path):
2299            logging.error(f"Model file not found: {model_path}")
2300            return False
2301
2302        logging.info(f"Using model: {model_path}")
2303
2304        # Initialize model
2305        if model_name not in MODEL_REGISTRY:
2306            logging.error(f"Model '{model_name}' not supported.")
2307            return False
2308
2309        ModelClass = MODEL_REGISTRY[model_name]
2310
2311        # Get class names from dataset
2312        try:
2313            class_names = self.dataset.distinct(f"{gt_field}.detections.label")
2314            class_names = sorted(class_names)
2315            num_classes = len(class_names)
2316            logging.info(f"Found {num_classes} classes: {class_names}")
2317        except Exception as e:
2318            logging.warning(f"Could not extract class names from dataset: {e}")
2319            num_classes = 8  # Fallback; must match the class count the checkpoint was trained with
2320            class_names = None
2321
2322        # Load model with trained weights
2323        try:
2324            logging.info("Loading RF-DETR model...")
2325            model = ModelClass(
2326                pretrain_weights=model_path,
2327                num_classes=num_classes
2328            )
2329
2330            logging.info("RF-DETR model loaded successfully")
2331        except Exception as e:
2332            logging.error(f"Failed to load model: {e}")
2333            return False
2334
2335        # Prepare dataset view
2336        detection_threshold = inference_settings.get("detection_threshold", 0.2)
2337
2338        if inference_settings.get("inference_on_test", True):
2339            INFERENCE_SPLITS = ["test"]
2340            dataset_view = self.dataset.match_tags(INFERENCE_SPLITS)
2341            if len(dataset_view) == 0:
2342                logging.error(f"Dataset has no samples tagged with: {INFERENCE_SPLITS}")
2343                return False
2344        else:
2345            dataset_view = self.dataset
2346
2347        # Prediction key
2348        pred_key = f"pred_od_{model_name}-{dataset_name}"
2349
2350        logging.info(f"Running inference on {len(dataset_view)} samples...")
2351        logging.info(f"Detection threshold: {detection_threshold}")
2352
2353        # Run inference on each sample
2354        try:
2355            processed_count = 0
2356
2357            for sample in tqdm(dataset_view.iter_samples(autosave=True),
2358                            total=len(dataset_view),
2359                            desc="RF-DETR Inference"):
2360
2361                try:
2362                    # Load image
2363                    image = Image.open(sample.filepath).convert("RGB")  # ensure 3-channel input
2364                    img_width, img_height = image.size
2365
2366                    # Run inference using RF-DETR's predict method
2367                    detections = model.predict(
2368                        image,
2369                        threshold=detection_threshold
2370                    )
2371
2372                    # Convert supervision detections to FiftyOne format
2373                    fo_detections = []
2374
2375                    if len(detections) > 0:
2376                        for i in range(len(detections)):
2377                            # Get detection data (RF-DETR returns supervision format)
2378                            bbox = detections.xyxy[i]  # [x1, y1, x2, y2] in pixel coordinates
2379                            confidence = detections.confidence[i] if detections.confidence is not None else 1.0
2380                            class_id = detections.class_id[i] if detections.class_id is not None else 0
2381
2382                            # Convert to relative coordinates [x, y, width, height]
2383                            x1, y1, x2, y2 = bbox
2384                            rel_x = x1 / img_width
2385                            rel_y = y1 / img_height
2386                            rel_w = (x2 - x1) / img_width
2387                            rel_h = (y2 - y1) / img_height
2388
2389                            # Get class name
2390                            if class_names and class_id < len(class_names):
2391                                class_name = class_names[class_id]
2392                            else:
2393                                class_name = f"class_{class_id}"
2394
2395                            # Create FiftyOne detection
2396                            fo_detection = fo.Detection(
2397                                label=class_name,
2398                                bounding_box=[rel_x, rel_y, rel_w, rel_h],
2399                                confidence=float(confidence)
2400                            )
2401                            fo_detections.append(fo_detection)
2402
2403                    # Save detections to sample
2404                    sample[pred_key] = fo.Detections(detections=fo_detections)
2405                    processed_count += 1
2406
2407                except Exception as e:
2408                    logging.error(f"Error processing sample {sample.id}: {e}")
2409                    continue
2410
2411            logging.info(f"Inference completed on {processed_count}/{len(dataset_view)} samples")
2412            logging.info(f"Predictions saved to field '{pred_key}'")
2413
2414        except Exception as e:
2415            logging.error(f"Error during inference: {e}")
2416            import traceback
2417            traceback.print_exc()
2418            return False
2419
2420        # Evaluate if requested
2421        if inference_settings.get("do_eval", True):
2422            eval_key = f"eval_{model_name}_{dataset_name}".replace("-", "_")
2423
2424            if inference_settings.get("inference_on_test", True):
2425                dataset_view = self.dataset.match_tags(["test"])
2426            else:
2427                dataset_view = self.dataset
2428
2429            # Filter samples that have both predictions and ground truth
2430            dataset_view = dataset_view.exists(pred_key).exists(gt_field)
2431
2432            if len(dataset_view) == 0:
2433                logging.warning("No samples found with both predictions and ground truth for evaluation")
2434            else:
2435                try:
2436                    logging.info(f"Evaluating predictions on {len(dataset_view)} samples...")
2437
2438                    results = dataset_view.evaluate_detections(
2439                        pred_key,
2440                        gt_field=gt_field,
2441                        eval_key=eval_key,
2442                        compute_mAP=True,
2443                        iou=0.5  # IoU threshold for matching
2444                    )
2445
2446                    # Print evaluation report
2447                    logging.info("="*70)
2448                    logging.info("EVALUATION RESULTS")
2449                    logging.info("="*70)
2450                    results.print_report()
2451                    logging.info("="*70)
2452
2453                    logging.info("Evaluation completed")
2454                except Exception as e:
2455                    logging.error(f"Evaluation failed: {e}")
2456                    import traceback
2457                    traceback.print_exc()
2458
2459        return True

Performs inference using RF-DETR model on a dataset with optional evaluation
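A hypothetical call, using only the settings keys that the method reads above:

    inference_settings = {
        "model_hf": None,            # or e.g. "my-org/fisheye8k_rfdetr_small" (hypothetical)
        "detection_threshold": 0.2,
        "inference_on_test": True,   # restrict inference to samples tagged "test"
        "do_eval": True,             # run evaluate_detections afterwards
    }
    trainer.inference(inference_settings, gt_field="ground_truth")
    # Predictions are stored in the field "pred_od_<model>-<dataset>";
    # evaluation results are stored under the eval key "eval_<model>_<dataset>"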