workflows.auto_label_mask

  1import logging
  2import fiftyone as fo
  3import fiftyone.zoo as foz
  4import numpy as np
  5import torch
  6from PIL import Image
  7from transformers import AutoImageProcessor, AutoModelForDepthEstimation
  8
  9class AutoLabelMask:
 10    def __init__(self, dataset, dataset_info, model_name, task_type, model_config):
 11        self.dataset = dataset
 12        self.dataset_info = dataset_info
 13        self.model_name = model_name
 14        self.task_type = task_type
 15        self.model_config = model_config
 16
 17    def _sanitize_model_name(self, model_name: str) -> str:
 18        """Replacing special characters with underscores"""
 19
 20        return (
 21            model_name
 22            .replace("/", "_")
 23            .replace("-", "_")
 24            .replace(".", "_")
 25        )
 26
 27    def run_inference(self):
 28        if self.task_type == "semantic_segmentation":
 29            self._run_semantic_segmentation()
 30        elif self.task_type == "depth_estimation":
 31            self._run_depth_estimation()
 32        else:
 33            logging.error(f"Task type '{self.task_type}' is not supported")
 34
 35
 36    def _run_semantic_segmentation(self):
 37        """Handles semantic segmentation inference using SAM2 models"""
 38
 39        if self.model_name == "sam2":
 40            self._handle_sam2_segmentation()
 41        else:
 42            logging.error(
 43                f"Semantic segmentation with model '{self.model_name}' is not supported"
 44            )
 45
 46    def _handle_sam2_segmentation(self):
 47        """Runs SAM2 segmentation inference for all specified models"""
 48
 49        sam_config = self.model_config
 50        prompt_field = sam_config.get("prompt_field", None)
 51        sam_models = sam_config["models"]
 52
 53        for sam_model in sam_models:
 54            model_name_clear = self._sanitize_model_name(sam_model)
 55            logging.info(f"Running SAM2 segmentation for model: {sam_model}")
 56
 57            # `pred_ss_` for semantic segmentation
 58            if prompt_field:
 59                label_field = f"pred_ss_{model_name_clear}_prompt_{prompt_field}"
 60            else:
 61                label_field = f"pred_ss_{model_name_clear}_noprompt"
 62
 63            self._inference_sam2(self.dataset, sam_model, label_field, prompt_field)
 64
 65    def _inference_sam2(self, dataset, sam_model, label_field, prompt_field=None):
 66        """Applies SAM2 model to the dataset, optionally using prompt field"""
 67
 68        logging.info(f"Starting SAM2 inference for model '{sam_model}'")
 69        model = foz.load_zoo_model(sam_model)
 70
 71        if prompt_field:
 72            logging.info(f"Running SAM2 with model '{sam_model}' and prompt_field '{prompt_field}'")
 73            dataset.apply_model(model, label_field=label_field, prompt_field=prompt_field, progress=True)
 74        else:
 75            logging.info(f"Running SAM2 with model '{sam_model}' without prompt_field")
 76            dataset.apply_model(model, label_field=label_field, progress=True)
 77
 78    def _run_depth_estimation(self):
 79        """Handles depth estimation inference for supported models"""
 80
 81        if self.model_name not in ["dpt", "depth_anything", "depth_pro", "glpn", "zoe_depth"]:
 82            logging.error(f"Depth estimation model '{self.model_name}' not supported")
 83            return
 84
 85        depth_models = self.model_config["models"]
 86
 87        for depth_model in depth_models:
 88            depth_model_clear = self._sanitize_model_name(depth_model)
 89            logging.info(f"Running depth estimation for model: {depth_model}")
 90
 91            # `pred_de_` for depth estimation
 92            label_field = f"pred_de_{depth_model_clear}"
 93
 94            self._inference_depth_estimation(self.dataset, depth_model, label_field)
 95
 96        logging.info(f"Depth estimation completed for all '{self.model_name}' models.")
 97
 98    def _inference_depth_estimation(self, dataset, model_name, label_field):
 99        """Applies depth estimation model to each sample in dataset"""
100
101        logging.info(f"Starting depth estimation for HF model '{model_name}'")
102        image_processor = AutoImageProcessor.from_pretrained(model_name)
103        model = AutoModelForDepthEstimation.from_pretrained(
104            model_name,
105            ignore_mismatched_sizes=True
106        )
107
108        def apply_depth_model(sample, depth_model, processor, out_field):
109
110            pil_image = Image.open(sample.filepath).convert("RGB")
111            depth_inputs = processor(images=pil_image, return_tensors="pt")
112
113            with torch.no_grad():
114                depth_outputs = depth_model(**depth_inputs)
115                predicted_depth = depth_outputs.predicted_depth
116
117            resized_depth = torch.nn.functional.interpolate(
118                predicted_depth.unsqueeze(1),
119                size=pil_image.size[::-1],
120                mode="bicubic",
121                align_corners=False,
122            )
123
124            depth_map = resized_depth.squeeze().cpu().numpy()
125
126            # different depth estimation models may output depth maps with different scaling, thus specific post-processing to normalize them properly
127            if self.model_name in ["dpt", "depth_anything", "depth_pro"]:
128                if np.max(depth_map) > 0: # avoid division by zero
129                    depth_map = (255 - depth_map * 255 / np.max(depth_map)).astype("uint8")
130
131                elif self.model_name in ["zoe", "glpn"]:
132                    depth_map = (depth_map * 255).astype("uint8")
133
134                else:
135                    logging.error(f"Unsupported model: {self.model_name}")
136                    raise ValueError(f"Unsupported model: {self.model_name}")
137            logging.info(f"Saving depth estimation result to field '{out_field}' for sample ID {sample.id}")
138
139            sample[out_field] = fo.Heatmap(map=depth_map)
140            sample.save()
141
142        for sample in dataset.iter_samples(autosave=True, progress=True):
143            apply_depth_model(sample, model, image_processor, label_field)
144
145        logging.info(f"Depth estimation inference finished for '{model_name}'")
class AutoLabelMask:
 10class AutoLabelMask:
 11    def __init__(self, dataset, dataset_info, model_name, task_type, model_config):
 12        self.dataset = dataset
 13        self.dataset_info = dataset_info
 14        self.model_name = model_name
 15        self.task_type = task_type
 16        self.model_config = model_config
 17
 18    def _sanitize_model_name(self, model_name: str) -> str:
 19        """Replacing special characters with underscores"""
 20
 21        return (
 22            model_name
 23            .replace("/", "_")
 24            .replace("-", "_")
 25            .replace(".", "_")
 26        )
 27
 28    def run_inference(self):
 29        if self.task_type == "semantic_segmentation":
 30            self._run_semantic_segmentation()
 31        elif self.task_type == "depth_estimation":
 32            self._run_depth_estimation()
 33        else:
 34            logging.error(f"Task type '{self.task_type}' is not supported")
 35
 36
 37    def _run_semantic_segmentation(self):
 38        """Handles semantic segmentation inference using SAM2 models"""
 39
 40        if self.model_name == "sam2":
 41            self._handle_sam2_segmentation()
 42        else:
 43            logging.error(
 44                f"Semantic segmentation with model '{self.model_name}' is not supported"
 45            )
 46
 47    def _handle_sam2_segmentation(self):
 48        """Runs SAM2 segmentation inference for all specified models"""
 49
 50        sam_config = self.model_config
 51        prompt_field = sam_config.get("prompt_field", None)
 52        sam_models = sam_config["models"]
 53
 54        for sam_model in sam_models:
 55            model_name_clear = self._sanitize_model_name(sam_model)
 56            logging.info(f"Running SAM2 segmentation for model: {sam_model}")
 57
 58            # `pred_ss_` for semantic segmentation
 59            if prompt_field:
 60                label_field = f"pred_ss_{model_name_clear}_prompt_{prompt_field}"
 61            else:
 62                label_field = f"pred_ss_{model_name_clear}_noprompt"
 63
 64            self._inference_sam2(self.dataset, sam_model, label_field, prompt_field)
 65
 66    def _inference_sam2(self, dataset, sam_model, label_field, prompt_field=None):
 67        """Applies SAM2 model to the dataset, optionally using prompt field"""
 68
 69        logging.info(f"Starting SAM2 inference for model '{sam_model}'")
 70        model = foz.load_zoo_model(sam_model)
 71
 72        if prompt_field:
 73            logging.info(f"Running SAM2 with model '{sam_model}' and prompt_field '{prompt_field}'")
 74            dataset.apply_model(model, label_field=label_field, prompt_field=prompt_field, progress=True)
 75        else:
 76            logging.info(f"Running SAM2 with model '{sam_model}' without prompt_field")
 77            dataset.apply_model(model, label_field=label_field, progress=True)
 78
 79    def _run_depth_estimation(self):
 80        """Handles depth estimation inference for supported models"""
 81
 82        if self.model_name not in ["dpt", "depth_anything", "depth_pro", "glpn", "zoe_depth"]:
 83            logging.error(f"Depth estimation model '{self.model_name}' not supported")
 84            return
 85
 86        depth_models = self.model_config["models"]
 87
 88        for depth_model in depth_models:
 89            depth_model_clear = self._sanitize_model_name(depth_model)
 90            logging.info(f"Running depth estimation for model: {depth_model}")
 91
 92            # `pred_de_` for depth estimation
 93            label_field = f"pred_de_{depth_model_clear}"
 94
 95            self._inference_depth_estimation(self.dataset, depth_model, label_field)
 96
 97        logging.info(f"Depth estimation completed for all '{self.model_name}' models.")
 98
 99    def _inference_depth_estimation(self, dataset, model_name, label_field):
100        """Applies depth estimation model to each sample in dataset"""
101
102        logging.info(f"Starting depth estimation for HF model '{model_name}'")
103        image_processor = AutoImageProcessor.from_pretrained(model_name)
104        model = AutoModelForDepthEstimation.from_pretrained(
105            model_name,
106            ignore_mismatched_sizes=True
107        )
108
109        def apply_depth_model(sample, depth_model, processor, out_field):
110
111            pil_image = Image.open(sample.filepath).convert("RGB")
112            depth_inputs = processor(images=pil_image, return_tensors="pt")
113
114            with torch.no_grad():
115                depth_outputs = depth_model(**depth_inputs)
116                predicted_depth = depth_outputs.predicted_depth
117
118            resized_depth = torch.nn.functional.interpolate(
119                predicted_depth.unsqueeze(1),
120                size=pil_image.size[::-1],
121                mode="bicubic",
122                align_corners=False,
123            )
124
125            depth_map = resized_depth.squeeze().cpu().numpy()
126
127            # different depth estimation models may output depth maps with different scaling, thus specific post-processing to normalize them properly
128            if self.model_name in ["dpt", "depth_anything", "depth_pro"]:
129                if np.max(depth_map) > 0: # avoid division by zero
130                    depth_map = (255 - depth_map * 255 / np.max(depth_map)).astype("uint8")
131
132                elif self.model_name in ["zoe", "glpn"]:
133                    depth_map = (depth_map * 255).astype("uint8")
134
135                else:
136                    logging.error(f"Unsupported model: {self.model_name}")
137                    raise ValueError(f"Unsupported model: {self.model_name}")
138            logging.info(f"Saving depth estimation result to field '{out_field}' for sample ID {sample.id}")
139
140            sample[out_field] = fo.Heatmap(map=depth_map)
141            sample.save()
142
143        for sample in dataset.iter_samples(autosave=True, progress=True):
144            apply_depth_model(sample, model, image_processor, label_field)
145
146        logging.info(f"Depth estimation inference finished for '{model_name}'")
AutoLabelMask(dataset, dataset_info, model_name, task_type, model_config)
11    def __init__(self, dataset, dataset_info, model_name, task_type, model_config):
12        self.dataset = dataset
13        self.dataset_info = dataset_info
14        self.model_name = model_name
15        self.task_type = task_type
16        self.model_config = model_config
dataset
dataset_info
model_name
task_type
model_config
def run_inference(self):
28    def run_inference(self):
29        if self.task_type == "semantic_segmentation":
30            self._run_semantic_segmentation()
31        elif self.task_type == "depth_estimation":
32            self._run_depth_estimation()
33        else:
34            logging.error(f"Task type '{self.task_type}' is not supported")