workflows.class_mapping

import datetime
import logging
import os

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AlignModel,
    AlignProcessor,
    AltCLIPModel,
    AltCLIPProcessor,
    AutoConfig,
    AutoModel,
    AutoModelForZeroShotImageClassification,
    AutoProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
)

from config.config import WORKFLOWS
from utils.dataset_loader import load_dataset


class ClassMapper:
    """Class mapper that uses various HuggingFace models to align class labels between the source and target datasets."""

    def __init__(self, dataset, model_name, config=None):
        """Initialize the ClassMapper with dataset and model configuration."""
        self.dataset = dataset
        self.model_name = model_name
        # Get the default config from WORKFLOWS and update it with any provided config.
        self.config = WORKFLOWS["class_mapping"].copy()
        if config:
            self.config.update(config)
        # Read the flag from the merged config so a missing key or a None config cannot raise.
        self.change_labels = self.config.get("change_labels", False)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = None
        self.model = None
        self.stats = {
            "total_processed": 0,
            "changes_made": 0,
            "source_class_counts": {},  # Count of processed detections per source class (e.g., Car, Truck)
            "tags_added_per_category": {},  # Counts of tags added per target class (e.g., Van, Pickup)
        }

    def load_model(self):
        """Load the model and processor from HuggingFace."""
        try:
            self.hf_model_config = AutoConfig.from_pretrained(self.model_name)
            self.hf_model_config_name = type(self.hf_model_config).__name__

            if self.hf_model_config_name == "SiglipConfig":
                self.model = AutoModel.from_pretrained(self.model_name)
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "AlignConfig":
                self.processor = AlignProcessor.from_pretrained(self.model_name)
                self.model = AlignModel.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "AltCLIPConfig":
                self.processor = AltCLIPProcessor.from_pretrained(self.model_name)
                self.model = AltCLIPModel.from_pretrained(self.model_name)

            elif self.hf_model_config_name == "CLIPSegConfig":
                self.processor = CLIPSegProcessor.from_pretrained(self.model_name)
                self.model = CLIPSegForImageSegmentation.from_pretrained(self.model_name)

            elif self.hf_model_config_name in ["Blip2Config", "CLIPConfig"]:
                self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_name)
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            else:
                # Fail fast here; otherwise the .to() call below would crash on a None model.
                raise ValueError(f"Unsupported model config for model: {self.model_name}")

            self.model.to(self.device)
            logging.info(f"Successfully loaded model {self.model_name}")

        except Exception as e:
            logging.error(f"Failed to load model: {str(e)}")
            raise

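    # Example checkpoints that resolve to these config classes (assumed, not an
    # exhaustive or tested list): "openai/clip-vit-base-patch32" (CLIPConfig),
    # "google/siglip-base-patch16-384" (SiglipConfig), "kakaobrain/align-base"
    # (AlignConfig), "BAAI/AltCLIP" (AltCLIPConfig), and
    # "CIDAS/clipseg-rd64-refined" (CLIPSegConfig).
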
    def process_detection(self, image, detection, candidate_labels):
        """Process a single detection with the model."""
        # Convert the relative bounding box to pixel coordinates.
        img_width, img_height = image.size
        bbox = detection.bounding_box
        min_x, min_y, width, height = bbox
        x1, y1 = int(min_x * img_width), int(min_y * img_height)
        x2, y2 = int((min_x + width) * img_width), int((min_y + height) * img_height)
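        # For example, a 640x480 image with bounding_box = [0.1, 0.2, 0.3, 0.4]
        # (FiftyOne's relative [x, y, width, height] format) yields
        # (x1, y1) = (64, 96) and (x2, y2) = (256, 288).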

        # Crop the image to the detection region.
        image_patch = image.crop((x1, y1, x2, y2))

        # Prepare inputs for the model.
        if self.hf_model_config_name == "SiglipConfig":
            target_size = (384, 384)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.processor(text=candidate_labels, images=image_patch, padding="max_length", return_tensors="pt")

        elif self.hf_model_config_name == "CLIPSegConfig":
            target_size = (224, 224)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            # CLIPSeg expects one image per text prompt, so the patch is repeated.
            inputs = self.processor(text=candidate_labels, images=[image_patch] * len(candidate_labels), padding="max_length", return_tensors="pt")

        else:
            target_size = (224, 224)  # Adjust per model
            image_patch = image_patch.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.processor(images=image_patch, text=candidate_labels, return_tensors="pt", padding=True)

        # Move all tensors in the processed inputs to the designated device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate the classification output.
        with torch.no_grad():
            outputs = self.model(**inputs)

        predicted_label = None
        confidence_score = None

        if self.hf_model_config_name == "SiglipConfig":
            # Apply sigmoid to the logits, then softmax across the candidates so
            # the scores form a distribution over the candidate labels.
            logits = outputs.logits_per_image
            probs = torch.sigmoid(logits)
            probs = torch.softmax(probs, dim=1)
            max_prob, predicted_idx = probs[0].max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            confidence_score = max_prob.item()

        elif self.hf_model_config_name in ["AlignConfig", "AltCLIPConfig"]:
            logits = outputs.logits_per_image
            probs = torch.softmax(logits, dim=1)
            max_prob, predicted_idx = probs[0].max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            confidence_score = max_prob.item()

        elif self.hf_model_config_name == "CLIPSegConfig":
            # Get the masks and ensure a batch dimension exists.
            masks = torch.sigmoid(outputs.logits)
            if masks.dim() == 2:  # Handle the single-example edge case
                masks = masks.unsqueeze(0)  # Add batch dimension

            # One mask per candidate label.
            batch_size, mask_height, mask_width = masks.shape

            # Convert the detection box to mask coordinates.
            box_x1 = int(x1 * mask_width / img_width)
            box_y1 = int(y1 * mask_height / img_height)
            box_x2 = int(x2 * mask_width / img_width)
            box_y2 = int(y2 * mask_height / img_height)

            # Score each candidate label by its mean mask activation inside the box.
            scores = []
            for i in range(batch_size):
                # Extract the relevant mask region.
                label_mask = masks[i, box_y1:box_y2, box_x1:box_x2]

                # Handle empty regions gracefully.
                if label_mask.numel() == 0:
                    scores.append(0.0)
                    continue
                scores.append(label_mask.mean().item())

            # Find the best match.
            confidence_score = max(scores)
            predicted_idx = scores.index(confidence_score)
            predicted_label = candidate_labels[predicted_idx]

        else:
            logits = outputs.logits_per_image
            max_logit, predicted_idx = logits.max(dim=-1)
            predicted_label = candidate_labels[predicted_idx.item()]
            # Note: this confidence is a raw logit, not a normalized probability.
            confidence_score = max_logit.item()

        return predicted_label, confidence_score

    def run_mapping(self, test_dataset_source, test_dataset_target, label_field="ground_truth"):
        """Run the class mapping process between the source dataset and the target dataset."""
        if not self.model:
            self.load_model()

        dataset_source_name = self.config.get("dataset_source")
        dataset_target_name = self.config.get("dataset_target")

        if not dataset_source_name or not dataset_target_name:
            logging.error("Both 'dataset_source' and 'dataset_target' must be specified in the config.")
            raise ValueError("Both 'dataset_source' and 'dataset_target' must be specified in the config.")

        # Load the datasets from FiftyOne.
        try:
            if test_dataset_source is None:
                SELECTED_DATASET = {
                    "name": dataset_source_name,
                    "n_samples": None,  # 'None' (full dataset) or 'int' (subset of the dataset)
                }
                source_dataset, source_dataset_info = load_dataset(SELECTED_DATASET)
            else:
                source_dataset = test_dataset_source

        except Exception as e:
            logging.error(f"Failed to load dataset_source '{dataset_source_name}': {e}")
            raise ValueError(f"Failed to load dataset_source '{dataset_source_name}': {e}")

        try:
            if test_dataset_target is None:
                SELECTED_DATASET = {
                    "name": dataset_target_name,
                    "n_samples": None,  # 'None' (full dataset) or 'int' (subset of the dataset)
                }
                target_dataset, target_dataset_info = load_dataset(SELECTED_DATASET)
            else:
                target_dataset = test_dataset_target

        except Exception as e:
            logging.error(f"Failed to load dataset_target '{dataset_target_name}': {e}")
            raise ValueError(f"Failed to load dataset_target '{dataset_target_name}': {e}")

        # Get the distinct labels present in each dataset.
        source_labels = source_dataset.distinct(f"{label_field}.detections.label")
        target_labels = target_dataset.distinct(f"{label_field}.detections.label")

        # Access the candidate labels from the config.
        candidate_labels = self.config["candidate_labels"]

        # Check that all source labels from candidate_labels exist in dataset_source.
        input_source_labels = list(candidate_labels.keys())
        missing_source_labels = [p for p in input_source_labels if p not in source_labels]
        if missing_source_labels:
            error_msg = (f"Missing labels in dataset_source '{dataset_source_name}': {missing_source_labels}\n"
                         f"Expected source labels: {input_source_labels}\n"
                         f"Found in dataset_source: {source_labels}")
            logging.error(error_msg)
            raise ValueError(error_msg)

        # Check that all target labels from candidate_labels exist in dataset_target.
        all_target_labels = []
        for input_target_labels in candidate_labels.values():
            all_target_labels.extend(input_target_labels)
        missing_target = [target_temp for target_temp in all_target_labels if target_temp not in target_labels]
        if missing_target:
            error_msg = (f"Missing labels in dataset_target '{dataset_target_name}': {missing_target}\n"
                         f"Expected target classes: {all_target_labels}\n"
                         f"Found in dataset_target: {target_labels}")
            logging.error(error_msg)
            raise ValueError(error_msg)

        one_to_one_mapping = {
            source: target[0]  # Only pick mappings with exactly one target class
            for source, target in candidate_labels.items()
            if len(target) == 1
        }
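        # Illustration with hypothetical candidate_labels (not the shipped config):
        # {"Car": ["Car"], "Truck": ["Van", "Pickup"]} yields
        # one_to_one_mapping == {"Car": "Car"}, while "Truck" keeps both
        # candidates and is resolved per detection by the zero-shot model below.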

        log_root = "./logs/"
        # Replace any '/' in HuggingFace model names so they do not create nested directories.
        experiment_name = f"{self.model_name.replace('/', '_')}_class_mapping_{os.getpid()}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        tensorboard_root = os.path.join(log_root, "tensorboard/class_mapping")
        dataset_name = getattr(self.dataset, "name", "default_dataset")
        log_directory = os.path.join(tensorboard_root, dataset_name, experiment_name)

        tb_writer = SummaryWriter(log_dir=log_directory)

        threshold = self.config["thresholds"]["confidence"]
        sample_count = 0  # For logging steps

        for sample in self.dataset.iter_samples(progress=True, autosave=True):
            sample_count += 1
            try:
                image = Image.open(sample.filepath)
            except Exception as e:
                logging.error(f"Error opening image {sample.filepath}: {str(e)}")
                continue

            detections = sample[label_field]
            for det in detections.detections:
                current_label = det.label

                # Check whether this is a one-to-one mapping.
                if current_label in one_to_one_mapping:
                    new_label = one_to_one_mapping[current_label]
                    tag = f"new_class_{new_label}_changed_from_{current_label}"

                    # Apply the new label and add the tag.
                    if tag not in det.tags:
                        det.tags.append(tag)
                        det.label = new_label  # Replace label directly

                        self.stats["tags_added_per_category"][new_label] = (
                            self.stats["tags_added_per_category"].get(new_label, 0) + 1
                        )
                        # Log tag changes.
                        tb_writer.add_scalar(f"Tags_Added/{new_label}",
                                             self.stats["tags_added_per_category"].get(new_label, 0),
                                             sample_count)
                    continue

                if current_label not in candidate_labels:
                    continue

                current_candidate_labels = candidate_labels[current_label]
                predicted_label, confidence = self.process_detection(image, det, current_candidate_labels)

                self.stats["source_class_counts"][current_label] = self.stats["source_class_counts"].get(current_label, 0) + 1
                self.stats["total_processed"] += 1

                # If the prediction meets the threshold and differs from the original label, add a tag.
                if confidence > threshold and current_label.lower() != predicted_label.lower():
                    tag = f"new_class_{self.model_name}_{predicted_label}"
                    if tag not in det.tags:
                        det.tags.append(tag)

                        # Change labels only if the flag is set to True.
                        if self.change_labels:
                            det.label = predicted_label
                        self.stats["changes_made"] += 1

                        self.stats["tags_added_per_category"][predicted_label] = (
                            self.stats["tags_added_per_category"].get(predicted_label, 0) + 1
                        )

                        tb_writer.add_scalar(f"Tags_Added/{predicted_label}",
                                             self.stats["tags_added_per_category"].get(predicted_label, 0),
                                             sample_count)
                        tb_writer.add_scalar("Total_Tags_Added",
                                             self.stats["changes_made"],
                                             sample_count)

            tb_writer.add_scalar("Total_Processed_Samples",
                                 self.stats["total_processed"],
                                 sample_count)

            # Log source class counts and target tag counts based on candidate_labels.
            for input_source_label, target_temp_labels in self.config["candidate_labels"].items():
                tb_writer.add_scalar(f"Source_Class_Count/{input_source_label}",
                                     self.stats["source_class_counts"].get(input_source_label, 0),
                                     sample_count)
                for target_temp_label in target_temp_labels:
                    tb_writer.add_scalar(f"Tags_Added/{target_temp_label}",
                                         self.stats["tags_added_per_category"].get(target_temp_label, 0),
                                         sample_count)

        tb_writer.close()

        return self.stats
class ClassMapper:
Class mapper that uses various HuggingFace models to align class labels between the source and target datasets.

ClassMapper(dataset, model_name, config=None)
Initialize the ClassMapper with dataset and model configuration.

Attributes: dataset, model_name, config, change_labels, device, processor, model, stats

def load_model(self):
Load the model and processor from HuggingFace.

def process_detection(self, image, detection, candidate_labels):
Process a single detection with the model.

def run_mapping(self, test_dataset_source, test_dataset_target, label_field='ground_truth'):
Run the class mapping process between the source dataset and the target dataset.
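
For reference, a minimal driver for this workflow might look like the sketch below. The model name and dataset name are assumptions for illustration; any zero-shot model whose config load_model() handles (CLIP, SigLIP, ALIGN, AltCLIP, CLIPSeg, BLIP-2) can be substituted.

# Illustrative usage sketch (model and dataset names are assumptions).
from utils.dataset_loader import load_dataset
from workflows.class_mapping import ClassMapper

dataset, dataset_info = load_dataset({"name": "source_dataset_name", "n_samples": None})

mapper = ClassMapper(
    dataset,
    model_name="openai/clip-vit-base-patch32",  # handled by the CLIPConfig branch
    config={"change_labels": False},            # tag only; do not rewrite labels
)
stats = mapper.run_mapping(test_dataset_source=None, test_dataset_target=None)
print(stats["total_processed"], stats["changes_made"])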