workflows.auto_label_mask
import logging

import fiftyone as fo
import fiftyone.zoo as foz
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation


class AutoLabelMask:
    """Run auto-labeling inference (SAM2 masks or depth heatmaps) on a dataset.

    Args:
        dataset: Dataset to label. It must provide ``apply_model`` and
            ``iter_samples`` (presumably a FiftyOne dataset; confirm with
            the caller).
        dataset_info: Stored as-is; not read by this class.
        model_name: Model family, e.g. ``"sam2"`` or one of
            ``_SUPPORTED_DEPTH_MODELS``.
        task_type: ``"semantic_segmentation"`` or ``"depth_estimation"``.
        model_config: Mapping with a ``"models"`` entry (iterable of model
            identifiers) and, for SAM2, an optional ``"prompt_field"``.
    """

    # Model families accepted by `_run_depth_estimation`.
    _SUPPORTED_DEPTH_MODELS = ("dpt", "depth_anything", "depth_pro", "glpn", "zoe_depth")
    # Families whose output is max-normalised and inverted before the uint8 cast.
    _INVERTED_DEPTH_MODELS = ("dpt", "depth_anything", "depth_pro")
    # Families whose output is multiplied by 255 directly. "zoe_depth" is the
    # name accepted by `_run_depth_estimation`; the legacy spelling "zoe" is
    # kept so existing direct callers keep working.
    _SCALED_DEPTH_MODELS = ("zoe_depth", "zoe", "glpn")

    def __init__(self, dataset, dataset_info, model_name, task_type, model_config):
        self.dataset = dataset
        self.dataset_info = dataset_info
        self.model_name = model_name
        self.task_type = task_type
        self.model_config = model_config

    def _sanitize_model_name(self, model_name: str) -> str:
        """Return *model_name* with ``/``, ``-`` and ``.`` replaced by ``_``."""
        return model_name.replace("/", "_").replace("-", "_").replace(".", "_")

    def run_inference(self):
        """Dispatch to the handler for ``self.task_type``.

        Unsupported task types are logged as errors; nothing is raised.
        """
        if self.task_type == "semantic_segmentation":
            self._run_semantic_segmentation()
        elif self.task_type == "depth_estimation":
            self._run_depth_estimation()
        else:
            logging.error(f"Task type '{self.task_type}' is not supported")

    def _run_semantic_segmentation(self):
        """Run semantic segmentation; only the ``"sam2"`` family is handled."""
        if self.model_name == "sam2":
            self._handle_sam2_segmentation()
        else:
            logging.error(
                f"Semantic segmentation with model '{self.model_name}' is not supported"
            )

    def _handle_sam2_segmentation(self):
        """Run SAM2 inference for every model listed in ``model_config["models"]``.

        Raises:
            KeyError: If ``model_config`` has no ``"models"`` entry.
        """
        sam_config = self.model_config
        prompt_field = sam_config.get("prompt_field", None)
        sam_models = sam_config["models"]

        for sam_model in sam_models:
            model_name_clear = self._sanitize_model_name(sam_model)
            logging.info(f"Running SAM2 segmentation for model: {sam_model}")

            # `pred_ss_` prefix marks semantic segmentation predictions.
            if prompt_field:
                label_field = f"pred_ss_{model_name_clear}_prompt_{prompt_field}"
            else:
                label_field = f"pred_ss_{model_name_clear}_noprompt"

            self._inference_sam2(self.dataset, sam_model, label_field, prompt_field)

    def _inference_sam2(self, dataset, sam_model, label_field, prompt_field=None):
        """Load zoo model *sam_model* and apply it to *dataset*.

        Args:
            dataset: Dataset exposing ``apply_model``.
            sam_model: Zoo model name passed to ``foz.load_zoo_model``.
            label_field: Field that receives the predictions.
            prompt_field: Optional field forwarded as ``prompt_field``; when
                falsy the keyword is not passed at all.
        """
        logging.info(f"Starting SAM2 inference for model '{sam_model}'")
        model = foz.load_zoo_model(sam_model)

        if prompt_field:
            logging.info(
                f"Running SAM2 with model '{sam_model}' and prompt_field '{prompt_field}'"
            )
            dataset.apply_model(
                model, label_field=label_field, prompt_field=prompt_field, progress=True
            )
        else:
            logging.info(f"Running SAM2 with model '{sam_model}' without prompt_field")
            dataset.apply_model(model, label_field=label_field, progress=True)

    def _run_depth_estimation(self):
        """Run depth estimation for every model in ``model_config["models"]``.

        Logs an error and returns without doing anything when
        ``self.model_name`` is not in ``_SUPPORTED_DEPTH_MODELS``.

        Raises:
            KeyError: If ``model_config`` has no ``"models"`` entry.
        """
        if self.model_name not in self._SUPPORTED_DEPTH_MODELS:
            logging.error(f"Depth estimation model '{self.model_name}' not supported")
            return

        depth_models = self.model_config["models"]

        for depth_model in depth_models:
            depth_model_clear = self._sanitize_model_name(depth_model)
            logging.info(f"Running depth estimation for model: {depth_model}")

            # `pred_de_` prefix marks depth estimation predictions.
            label_field = f"pred_de_{depth_model_clear}"

            self._inference_depth_estimation(self.dataset, depth_model, label_field)

        logging.info(f"Depth estimation completed for all '{self.model_name}' models.")

    def _inference_depth_estimation(self, dataset, model_name, label_field):
        """Predict a depth heatmap for each sample and store it in *label_field*.

        Args:
            dataset: Dataset exposing ``iter_samples``.
            model_name: Identifier passed to the ``from_pretrained`` loaders.
            label_field: Sample field that receives the ``fo.Heatmap``.

        Raises:
            ValueError: If ``self.model_name`` has no normalisation rule.
        """
        logging.info(f"Starting depth estimation for HF model '{model_name}'")
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForDepthEstimation.from_pretrained(
            model_name,
            ignore_mismatched_sizes=True,
        )

        # autosave=True persists the field assignments, so no explicit
        # per-sample save() call is needed here.
        for sample in dataset.iter_samples(autosave=True, progress=True):
            depth_map = self._predict_depth_map(sample.filepath, model, image_processor)
            logging.info(
                f"Saving depth estimation result to field '{label_field}' for sample ID {sample.id}"
            )
            sample[label_field] = fo.Heatmap(map=depth_map)

        logging.info(f"Depth estimation inference finished for '{model_name}'")

    def _predict_depth_map(self, filepath, depth_model, processor):
        """Return the normalised depth map for the image at *filepath*."""
        # Close the file handle as soon as the RGB copy exists.
        with Image.open(filepath) as raw_image:
            pil_image = raw_image.convert("RGB")

        depth_inputs = processor(images=pil_image, return_tensors="pt")

        with torch.no_grad():
            predicted_depth = depth_model(**depth_inputs).predicted_depth
            # PIL reports (width, height); interpolate expects (height, width).
            resized_depth = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=pil_image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )

        return self._normalize_depth_map(resized_depth.squeeze().cpu().numpy())

    def _normalize_depth_map(self, depth_map):
        """Convert a raw depth array to the uint8 map stored in the heatmap.

        Different model families use different output scaling, so the rule is
        selected by ``self.model_name``.

        Raises:
            ValueError: If ``self.model_name`` has no normalisation rule.
        """
        if self.model_name in self._INVERTED_DEPTH_MODELS:
            peak = np.max(depth_map)
            if peak > 0:  # avoid division by zero; map is returned unchanged otherwise
                depth_map = (255 - depth_map * 255 / peak).astype("uint8")
            return depth_map

        if self.model_name in self._SCALED_DEPTH_MODELS:
            # NOTE(review): values above 1.0 would wrap around in the uint8
            # cast -- confirm the expected output range of these models.
            return (depth_map * 255).astype("uint8")

        logging.error(f"Unsupported model: {self.model_name}")
        raise ValueError(f"Unsupported model: {self.model_name}")
class
AutoLabelMask:
class AutoLabelMask:
    """Run auto-labeling inference (SAM2 masks or depth heatmaps) on a dataset.

    Args:
        dataset: Dataset to label. It must provide ``apply_model`` and
            ``iter_samples`` (presumably a FiftyOne dataset; confirm with
            the caller).
        dataset_info: Stored as-is; not read by this class.
        model_name: Model family, e.g. ``"sam2"`` or one of
            ``_SUPPORTED_DEPTH_MODELS``.
        task_type: ``"semantic_segmentation"`` or ``"depth_estimation"``.
        model_config: Mapping with a ``"models"`` entry (iterable of model
            identifiers) and, for SAM2, an optional ``"prompt_field"``.
    """

    # Model families accepted by `_run_depth_estimation`.
    _SUPPORTED_DEPTH_MODELS = ("dpt", "depth_anything", "depth_pro", "glpn", "zoe_depth")
    # Families whose output is max-normalised and inverted before the uint8 cast.
    _INVERTED_DEPTH_MODELS = ("dpt", "depth_anything", "depth_pro")
    # Families whose output is multiplied by 255 directly. "zoe_depth" is the
    # name accepted by `_run_depth_estimation`; the legacy spelling "zoe" is
    # kept so existing direct callers keep working.
    _SCALED_DEPTH_MODELS = ("zoe_depth", "zoe", "glpn")

    def __init__(self, dataset, dataset_info, model_name, task_type, model_config):
        self.dataset = dataset
        self.dataset_info = dataset_info
        self.model_name = model_name
        self.task_type = task_type
        self.model_config = model_config

    def _sanitize_model_name(self, model_name: str) -> str:
        """Return *model_name* with ``/``, ``-`` and ``.`` replaced by ``_``."""
        return model_name.replace("/", "_").replace("-", "_").replace(".", "_")

    def run_inference(self):
        """Dispatch to the handler for ``self.task_type``.

        Unsupported task types are logged as errors; nothing is raised.
        """
        if self.task_type == "semantic_segmentation":
            self._run_semantic_segmentation()
        elif self.task_type == "depth_estimation":
            self._run_depth_estimation()
        else:
            logging.error(f"Task type '{self.task_type}' is not supported")

    def _run_semantic_segmentation(self):
        """Run semantic segmentation; only the ``"sam2"`` family is handled."""
        if self.model_name == "sam2":
            self._handle_sam2_segmentation()
        else:
            logging.error(
                f"Semantic segmentation with model '{self.model_name}' is not supported"
            )

    def _handle_sam2_segmentation(self):
        """Run SAM2 inference for every model listed in ``model_config["models"]``.

        Raises:
            KeyError: If ``model_config`` has no ``"models"`` entry.
        """
        sam_config = self.model_config
        prompt_field = sam_config.get("prompt_field", None)
        sam_models = sam_config["models"]

        for sam_model in sam_models:
            model_name_clear = self._sanitize_model_name(sam_model)
            logging.info(f"Running SAM2 segmentation for model: {sam_model}")

            # `pred_ss_` prefix marks semantic segmentation predictions.
            if prompt_field:
                label_field = f"pred_ss_{model_name_clear}_prompt_{prompt_field}"
            else:
                label_field = f"pred_ss_{model_name_clear}_noprompt"

            self._inference_sam2(self.dataset, sam_model, label_field, prompt_field)

    def _inference_sam2(self, dataset, sam_model, label_field, prompt_field=None):
        """Load zoo model *sam_model* and apply it to *dataset*.

        Args:
            dataset: Dataset exposing ``apply_model``.
            sam_model: Zoo model name passed to ``foz.load_zoo_model``.
            label_field: Field that receives the predictions.
            prompt_field: Optional field forwarded as ``prompt_field``; when
                falsy the keyword is not passed at all.
        """
        logging.info(f"Starting SAM2 inference for model '{sam_model}'")
        model = foz.load_zoo_model(sam_model)

        if prompt_field:
            logging.info(
                f"Running SAM2 with model '{sam_model}' and prompt_field '{prompt_field}'"
            )
            dataset.apply_model(
                model, label_field=label_field, prompt_field=prompt_field, progress=True
            )
        else:
            logging.info(f"Running SAM2 with model '{sam_model}' without prompt_field")
            dataset.apply_model(model, label_field=label_field, progress=True)

    def _run_depth_estimation(self):
        """Run depth estimation for every model in ``model_config["models"]``.

        Logs an error and returns without doing anything when
        ``self.model_name`` is not in ``_SUPPORTED_DEPTH_MODELS``.

        Raises:
            KeyError: If ``model_config`` has no ``"models"`` entry.
        """
        if self.model_name not in self._SUPPORTED_DEPTH_MODELS:
            logging.error(f"Depth estimation model '{self.model_name}' not supported")
            return

        depth_models = self.model_config["models"]

        for depth_model in depth_models:
            depth_model_clear = self._sanitize_model_name(depth_model)
            logging.info(f"Running depth estimation for model: {depth_model}")

            # `pred_de_` prefix marks depth estimation predictions.
            label_field = f"pred_de_{depth_model_clear}"

            self._inference_depth_estimation(self.dataset, depth_model, label_field)

        logging.info(f"Depth estimation completed for all '{self.model_name}' models.")

    def _inference_depth_estimation(self, dataset, model_name, label_field):
        """Predict a depth heatmap for each sample and store it in *label_field*.

        Args:
            dataset: Dataset exposing ``iter_samples``.
            model_name: Identifier passed to the ``from_pretrained`` loaders.
            label_field: Sample field that receives the ``fo.Heatmap``.

        Raises:
            ValueError: If ``self.model_name`` has no normalisation rule.
        """
        logging.info(f"Starting depth estimation for HF model '{model_name}'")
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForDepthEstimation.from_pretrained(
            model_name,
            ignore_mismatched_sizes=True,
        )

        # autosave=True persists the field assignments, so no explicit
        # per-sample save() call is needed here.
        for sample in dataset.iter_samples(autosave=True, progress=True):
            depth_map = self._predict_depth_map(sample.filepath, model, image_processor)
            logging.info(
                f"Saving depth estimation result to field '{label_field}' for sample ID {sample.id}"
            )
            sample[label_field] = fo.Heatmap(map=depth_map)

        logging.info(f"Depth estimation inference finished for '{model_name}'")

    def _predict_depth_map(self, filepath, depth_model, processor):
        """Return the normalised depth map for the image at *filepath*."""
        # Close the file handle as soon as the RGB copy exists.
        with Image.open(filepath) as raw_image:
            pil_image = raw_image.convert("RGB")

        depth_inputs = processor(images=pil_image, return_tensors="pt")

        with torch.no_grad():
            predicted_depth = depth_model(**depth_inputs).predicted_depth
            # PIL reports (width, height); interpolate expects (height, width).
            resized_depth = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=pil_image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )

        return self._normalize_depth_map(resized_depth.squeeze().cpu().numpy())

    def _normalize_depth_map(self, depth_map):
        """Convert a raw depth array to the uint8 map stored in the heatmap.

        Different model families use different output scaling, so the rule is
        selected by ``self.model_name``.

        Raises:
            ValueError: If ``self.model_name`` has no normalisation rule.
        """
        if self.model_name in self._INVERTED_DEPTH_MODELS:
            peak = np.max(depth_map)
            if peak > 0:  # avoid division by zero; map is returned unchanged otherwise
                depth_map = (255 - depth_map * 255 / peak).astype("uint8")
            return depth_map

        if self.model_name in self._SCALED_DEPTH_MODELS:
            # NOTE(review): values above 1.0 would wrap around in the uint8
            # cast -- confirm the expected output range of these models.
            return (depth_map * 255).astype("uint8")

        logging.error(f"Unsupported model: {self.model_name}")
        raise ValueError(f"Unsupported model: {self.model_name}")