workflows.class_mapping
Source code for `workflows.class_mapping`:

```python
import datetime
import logging
import os

import fiftyone as fo
import torch
import torch.nn.functional as F
import wandb
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import (
    AlignModel,
    AlignProcessor,
    AltCLIPModel,
    AltCLIPProcessor,
    AutoConfig,
    AutoModel,
    AutoModelForZeroShotImageClassification,
    AutoProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
)

from config.config import WORKFLOWS
from utils.dataset_loader import load_dataset


class ClassMapper:
    """Class mapper that uses various HuggingFace models to align class labels between the source and target datasets."""

    def __init__(self, dataset, model_name, config=None):
        """Initialize the ClassMapper with dataset and model configuration."""
        self.dataset = dataset
        self.model_name = model_name

        # Get the default config from WORKFLOWS and update it with any provided config.
        self.config = WORKFLOWS["class_mapping"].copy()
        if config:
            self.config.update(config)
        # Read the flag from the merged config so config=None or a missing key does not raise.
        self.change_labels = self.config.get("change_labels", False)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = None
        self.model = None
        self.stats = {
            "total_processed": 0,
            "changes_made": 0,
            "source_class_counts": {},  # Count of processed detections per source class (e.g., Car, Truck)
            "tags_added_per_category": {},  # Detailed counts of tags added per target class (e.g., Van, Pickup)
        }

    def load_model(self):
        """Load the model and processor from HuggingFace."""
        try:
            self.hf_model_config = AutoConfig.from_pretrained(self.model_name)
            self.hf_model_config_name = type(self.hf_model_config).__name__

            if self.hf_model_config_name == "SiglipConfig":
                self.model = AutoModel.from_pretrained(self.model_name)
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "AlignConfig":
                self.processor = AlignProcessor.from_pretrained(self.model_name)
                self.model = AlignModel.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "AltCLIPConfig":
                self.processor = AltCLIPProcessor.from_pretrained(self.model_name)
                self.model = AltCLIPModel.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "CLIPSegConfig":
                self.processor = CLIPSegProcessor.from_pretrained(self.model_name)
                self.model = CLIPSegForImageSegmentation.from_pretrained(self.model_name)

            elif self.hf_model_config_name in ["Blip2Config", "CLIPConfig"]:
                self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_name)
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            else:
                # Fail early instead of calling .to() on a model that was never loaded.
                logging.error(f"Invalid model name: {self.model_name}")
                raise ValueError(f"Unsupported model config '{self.hf_model_config_name}' for '{self.model_name}'")

            self.model.to(self.device)
            logging.info(f"Successfully loaded model {self.model_name}")

        except Exception as e:
            logging.error(f"Failed to load model: {str(e)}")
            raise

    def process_detection(self, image, detection, candidate_labels):
        """Process a single detection with the model."""
        # Convert the bounding box to pixel coordinates.
        img_width, img_height = image.size
        bbox = detection.bounding_box
        min_x, min_y, width, height = bbox
        x1, y1 = int(min_x * img_width), int(min_y * img_height)
        x2, y2 = int((min_x + width) * img_width), int((min_y + height) * img_height)

        # Crop the image to the detection region.
        image_patch = image.crop((x1, y1, x2, y2))

        # Prepare inputs for the model.
        if self.hf_model_config_name == "SiglipConfig":
            target_size = (384, 384)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.processor(
                text=candidate_labels, images=image_patch, padding="max_length", return_tensors="pt"
            )

        elif self.hf_model_config_name == "CLIPSegConfig":
            target_size = (224, 224)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.processor(
                text=candidate_labels,
                images=[image_patch] * len(candidate_labels),
                padding="max_length",
                return_tensors="pt",
            )

        else:
            target_size = (224, 224)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.processor(
                images=image_patch, text=candidate_labels, return_tensors="pt", padding=True
            )

        # Move all tensors in the processed inputs to the designated device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate the classification output.
        with torch.no_grad():
            outputs = self.model(**inputs)

        predicted_label = None
        confidence_score = None

        if self.hf_model_config_name == "SiglipConfig":
            # Apply sigmoid, then softmax to normalize the scores across the candidate labels.
            logits = outputs.logits_per_image
            probs = torch.sigmoid(logits)
            probs = torch.softmax(probs, dim=1)
            max_prob, predicted_idx = probs[0].max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            confidence_score = max_prob.item()

        elif self.hf_model_config_name in ["AlignConfig", "AltCLIPConfig"]:
            logits = outputs.logits_per_image
            probs = torch.softmax(logits, dim=1)
            max_prob, predicted_idx = probs[0].max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            confidence_score = max_prob.item()

        elif self.hf_model_config_name == "CLIPSegConfig":
            # Get the masks and ensure a batch dimension exists.
            masks = torch.sigmoid(outputs.logits)
            if masks.dim() == 2:  # Handle the single-example edge case
                masks = masks.unsqueeze(0)  # Add a batch dimension

            # Verify the mask dimensions.
            batch_size, mask_height, mask_width = masks.shape

            # Convert the detection box to mask coordinates.
            box_x1 = int(x1 * mask_width / img_width)
            box_y1 = int(y1 * mask_height / img_height)
            box_x2 = int(x2 * mask_width / img_width)
            box_y2 = int(y2 * mask_height / img_height)

            # Calculate a score for each candidate label.
            scores = []
            for i in range(batch_size):
                # Extract the relevant mask region.
                label_mask = masks[i, box_y1:box_y2, box_x1:box_x2]

                # Handle empty regions gracefully.
                if label_mask.numel() == 0:
                    scores.append(0.0)
                    continue
                scores.append(label_mask.mean().item())

            # Find the best match.
            confidence_score = max(scores)
            predicted_idx = scores.index(confidence_score)
            predicted_label = candidate_labels[predicted_idx]

        else:
            logits = outputs.logits_per_image
            max_logit, predicted_idx = logits.max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            confidence_score = max_logit.item()

        return predicted_label, confidence_score

    def run_mapping(self, test_dataset_source, test_dataset_target, label_field="ground_truth"):
        """Run the class mapping process between the source dataset and the target dataset."""
        if not self.model:
            self.load_model()

        dataset_source_name = self.config.get("dataset_source")
        dataset_target_name = self.config.get("dataset_target")

        if not dataset_source_name or not dataset_target_name:
            logging.error("Both 'dataset_source' and 'dataset_target' must be specified in the config.")
            raise ValueError("Both 'dataset_source' and 'dataset_target' must be specified in the config.")

        # Load the datasets from FiftyOne.
        try:
            if test_dataset_source is None:
                SELECTED_DATASET = {
                    "name": dataset_source_name,
                    "n_samples": None,  # 'None' (full dataset) or 'int' (subset of the dataset)
                }
                source_dataset, source_dataset_info = load_dataset(SELECTED_DATASET)
            else:
                source_dataset = test_dataset_source

        except Exception as e:
            logging.error(f"Failed to load dataset_source '{dataset_source_name}': {e}")
            raise ValueError(f"Failed to load dataset_source '{dataset_source_name}': {e}")

        try:
            if test_dataset_target is None:
                SELECTED_DATASET = {
                    "name": dataset_target_name,
                    "n_samples": None,  # 'None' (full dataset) or 'int' (subset of the dataset)
                }
                target_dataset, target_dataset_info = load_dataset(SELECTED_DATASET)
            else:
                target_dataset = test_dataset_target

        except Exception as e:
            logging.error(f"Failed to load dataset_target '{dataset_target_name}': {e}")
            raise ValueError(f"Failed to load dataset_target '{dataset_target_name}': {e}")

        # Get the distinct labels present in each dataset.
        source_labels = source_dataset.distinct(f"{label_field}.detections.label")
        target_labels = target_dataset.distinct(f"{label_field}.detections.label")

        # Access the candidate labels from the config.
        candidate_labels = self.config["candidate_labels"]

        # Check that all labels from candidate_labels exist in dataset_source.
        input_source_labels = list(candidate_labels.keys())
        missing_source_labels = [p for p in input_source_labels if p not in source_labels]
        if missing_source_labels:
            error_msg = (
                f"Missing labels in dataset_source '{dataset_source_name}': {missing_source_labels}\n"
                f"Expected source labels: {input_source_labels}\n"
                f"Found in dataset_source: {source_labels}"
            )
            logging.error(error_msg)
            raise ValueError(error_msg)

        # Check that all labels from candidate_labels exist in dataset_target.
        all_target_labels = []
        for input_target_labels in candidate_labels.values():
            all_target_labels.extend(input_target_labels)
        missing_target = [target_temp for target_temp in all_target_labels if target_temp not in target_labels]
        if missing_target:
            error_msg = (
                f"Missing labels in dataset_target '{dataset_target_name}': {missing_target}\n"
                f"Expected target classes: {all_target_labels}\n"
                f"Found in dataset_target: {target_labels}"
            )
            logging.error(error_msg)
            raise ValueError(error_msg)

        one_to_one_mapping = {
            source: target[0]  # Only pick mappings with exactly one target class
            for source, target in candidate_labels.items()
            if len(target) == 1
        }

        log_root = "./logs/"
        experiment_name = (
            f"{self.model_name}_class_mapping_{os.getpid()}_"
            f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        )
        tensorboard_root = os.path.join(log_root, "tensorboard/class_mapping")
        dataset_name = getattr(self.dataset, "name", "default_dataset")
        log_directory = os.path.join(tensorboard_root, dataset_name, experiment_name)

        tb_writer = SummaryWriter(log_dir=log_directory)

        threshold = self.config["thresholds"]["confidence"]
        sample_count = 0  # For logging steps

        for sample in self.dataset.iter_samples(progress=True, autosave=True):
            sample_count += 1
            try:
                image = Image.open(sample.filepath)
            except Exception as e:
                logging.error(f"Error opening image {sample.filepath}: {str(e)}")
                continue

            detections = sample[label_field]
            for det in detections.detections:
                current_label = det.label

                # Check whether this is a one-to-one mapping.
                if current_label in one_to_one_mapping:
                    new_label = one_to_one_mapping[current_label]
                    tag = f"new_class_{new_label}_changed_from_{current_label}"

                    # Apply the new label and add a tag.
                    if tag not in det.tags:
                        det.tags.append(tag)
                        det.label = new_label  # Replace the label directly

                    self.stats["tags_added_per_category"][new_label] = (
                        self.stats["tags_added_per_category"].get(new_label, 0) + 1
                    )
                    # Log tag changes.
                    tb_writer.add_scalar(
                        f"Tags_Added/{new_label}",
                        self.stats["tags_added_per_category"].get(new_label, 0),
                        sample_count,
                    )
                    continue

                if current_label not in candidate_labels:
                    continue

                current_candidate_labels = candidate_labels[current_label]
                predicted_label, confidence = self.process_detection(image, det, current_candidate_labels)

                self.stats["source_class_counts"][current_label] = (
                    self.stats["source_class_counts"].get(current_label, 0) + 1
                )
                self.stats["total_processed"] += 1

                # If the prediction meets the threshold and differs from the original label, add a tag.
                if confidence > threshold and current_label.lower() != predicted_label.lower():
                    tag = f"new_class_{self.model_name}_{predicted_label}"
                    if tag not in det.tags:
                        det.tags.append(tag)

                    # Change the label if the flag is set to True.
                    if self.change_labels:
                        det.label = predicted_label
                    self.stats["changes_made"] += 1

                    self.stats["tags_added_per_category"][predicted_label] = (
                        self.stats["tags_added_per_category"].get(predicted_label, 0) + 1
                    )

                    tb_writer.add_scalar(
                        f"Tags_Added/{predicted_label}",
                        self.stats["tags_added_per_category"].get(predicted_label, 0),
                        sample_count,
                    )
                    tb_writer.add_scalar(
                        "Total_Tags_Added",
                        self.stats["changes_made"],
                        sample_count,
                    )

                tb_writer.add_scalar(
                    "Total_Processed_Samples",
                    self.stats["total_processed"],
                    sample_count,
                )

            # Log source class counts and target tag counts based on candidate_labels.
            for input_source_label, target_temp_labels in self.config["candidate_labels"].items():
                tb_writer.add_scalar(
                    f"Source_Class_Count/{input_source_label}",
                    self.stats["source_class_counts"].get(input_source_label, 0),
                    sample_count,
                )
                for target_temp_label in target_temp_labels:
                    tb_writer.add_scalar(
                        f"Tags_Added/{target_temp_label}",
                        self.stats["tags_added_per_category"].get(target_temp_label, 0),
                        sample_count,
                    )

        tb_writer.close()

        return self.stats
```
class ClassMapper
Class mapper that uses various HuggingFace models to align class labels between the source and target datasets.
ClassMapper(dataset, model_name, config=None)
Initialize the ClassMapper with dataset and model configuration.
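The constructor merges any user-supplied dict over the `WORKFLOWS["class_mapping"]` defaults. The sketch below lists the keys that `__init__` and `run_mapping` actually read; the dataset names, label names, checkpoint, and threshold are illustrative placeholders, not project defaults.

```python
import fiftyone as fo

from workflows.class_mapping import ClassMapper

dataset = fo.load_dataset("my_dataset")  # placeholder FiftyOne dataset name

# Hypothetical override dict; every key shown is read by ClassMapper.
class_mapping_config = {
    "dataset_source": "source_dataset_name",  # FiftyOne dataset providing the source labels
    "dataset_target": "target_dataset_name",  # FiftyOne dataset providing the target labels
    "candidate_labels": {
        # source label -> candidate target labels; a single-entry list is applied
        # as a direct one-to-one relabeling without calling the model
        "Car": ["Car"],
        "Truck": ["Truck", "Pickup", "Van"],
    },
    "thresholds": {"confidence": 0.8},  # minimum score before a tag is written
    "change_labels": False,             # if True, det.label is overwritten in addition to tagging
}

mapper = ClassMapper(dataset, "openai/clip-vit-base-patch32", config=class_mapping_config)
```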
load_model(self)
Load the model and processor from HuggingFace.
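`load_model` dispatches on the class name of the checkpoint's `AutoConfig`, so the branch a given model will take can be checked up front. A minimal sketch; the checkpoint IDs are public examples of supported config types, not defaults of this workflow.

```python
from transformers import AutoConfig

# Example checkpoints and the config class name load_model() dispatches on.
for checkpoint in [
    "openai/clip-vit-base-patch32",    # CLIPConfig    -> AutoModelForZeroShotImageClassification
    "google/siglip-base-patch16-224",  # SiglipConfig  -> AutoModel
    "CIDAS/clipseg-rd64-refined",      # CLIPSegConfig -> CLIPSegForImageSegmentation
]:
    cfg = AutoConfig.from_pretrained(checkpoint)
    print(f"{checkpoint}: {type(cfg).__name__}")
```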
process_detection(self, image, detection, candidate_labels)
Process a single detection with the model.
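FiftyOne stores `Detection.bounding_box` as `[top-left-x, top-left-y, width, height]` in relative coordinates, which is why the method rescales by the image size before cropping. A standalone sketch of the same conversion, with made-up numbers and a placeholder image path:

```python
from PIL import Image

image = Image.open("example.jpg")  # placeholder path
img_width, img_height = image.size

# FiftyOne-style relative box: [top-left-x, top-left-y, width, height]
min_x, min_y, width, height = 0.25, 0.40, 0.10, 0.20

x1, y1 = int(min_x * img_width), int(min_y * img_height)
x2, y2 = int((min_x + width) * img_width), int((min_y + height) * img_height)

# This crop is the patch passed to the zero-shot model together with the candidate labels.
image_patch = image.crop((x1, y1, x2, y2))
```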
run_mapping(self, test_dataset_source, test_dataset_target, label_field='ground_truth')
Run the class mapping process between the source dataset and the target dataset.
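A sketch of a full run and of reviewing the result, assuming `WORKFLOWS["class_mapping"]` already provides `dataset_source`, `dataset_target`, `candidate_labels`, and `thresholds`; the dataset name and checkpoint are placeholders. Tags follow the patterns `new_class_{target}_changed_from_{source}` for one-to-one mappings and `new_class_{model_name}_{predicted}` for model-based suggestions, and TensorBoard scalars are written under `./logs/tensorboard/class_mapping/<dataset>/<experiment>`.

```python
import fiftyone as fo

from workflows.class_mapping import ClassMapper

dataset = fo.load_dataset("my_dataset")  # placeholder FiftyOne dataset name
mapper = ClassMapper(dataset, "openai/clip-vit-base-patch32")

# Passing None makes run_mapping load dataset_source / dataset_target
# from the config for the label consistency checks.
stats = mapper.run_mapping(test_dataset_source=None, test_dataset_target=None)

print(stats["total_processed"])          # detections scored by the model
print(stats["changes_made"])             # above-threshold suggestions
print(stats["tags_added_per_category"])  # tag counts per target class

# Review the tagged detections, e.g. in the FiftyOne App.
print(dataset.count_label_tags())
session = fo.launch_app(dataset)
```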