workflows.anomaly_detection

import logging
import os

import anomalib.models
import fiftyone as fo
import torch
from anomalib import TaskType
from anomalib.data.image.folder import Folder
from anomalib.data.utils import read_image
from anomalib.deploy import ExportType, TorchInferencer
from anomalib.engine import Engine
from anomalib.loggers import AnomalibTensorBoardLogger
from huggingface_hub import HfApi, hf_hub_download
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torchvision.transforms.v2 import Compose, Resize

from config.config import GLOBAL_SEED, HF_DO_UPLOAD, HF_ROOT, NUM_WORKERS

class Anodec:
    """Anomaly detection model class for managing training, inference, and evaluation of anomaly detection models using Anomalib.

    Relevant links:
    - https://docs.voxel51.com/tutorials/anomaly_detection.html
    - https://medium.com/@enrico.randellini/anomalib-a-library-for-image-anomaly-detection-and-localization-fb363639104f
    - https://github.com/openvinotoolkit/anomalib
    - https://anomalib.readthedocs.io/en/stable/
    """

    def __init__(
        self,
        dataset,
        eval_metrics,
        dataset_info,
        config,
        tensorboard_output,
        anomalib_output_root="./output/models/anomalib/",
    ):
        """Initialize the anomaly detection module with dataset, evaluation metrics, config, and output paths."""
        # Use Tensor Cores; addresses a PyTorch matmul precision warning
        torch.set_float32_matmul_precision("medium")
        self.config = config
        self.dataset = dataset
        self.eval_metrics = eval_metrics
        self.normal_data = dataset.match_tags("train")
        self.abnormal_data = dataset.match_tags(["val", "test"])
        self.dataset_name = dataset_info["name"]
        self.TASK = TaskType.SEGMENTATION
        self.model_name = self.config["model_name"]
        self.image_size = self.config["image_size"]
        self.batch_size = self.config["batch_size"]
        self.tensorboard_output = os.path.abspath(tensorboard_output)
        self.anomalib_output_root = os.path.abspath(anomalib_output_root)
        self.model_path = os.path.join(
            self.anomalib_output_root,
            self.model_name,
            self.dataset_name,
            "weights/torch/model.pt",
        )
        self.field_gt_anomaly_mask = "ground_truth_anomaly_mask"

        self.hf_repo_name = f"{HF_ROOT}/{self.dataset_name}_anomalib_{self.model_name}"

        # Anomalib objects
        self.inferencer = None
        self.engine = None
        self.datamodule = None
        self.anomalib_logger = None

    def __del__(self):
        """Destructor that unlinks symlinks and finalizes the anomaly detection logger."""
        try:
            self.unlink_symlinks()
            self.anomalib_logger.finalize("success")
        except Exception:
            pass

    def create_datamodule(self, transform):
        """Create the datamodule for anomaly detection by preparing and symlinking images/masks for the Anomalib datamodule."""

        # Symlink the images and masks to the directory Anomalib expects.
        logging.info("Preparing images and masks for Anomalib")
        for sample in self.abnormal_data.iter_samples(progress=True, autosave=True):
            # Add ground-truth mask
            base_filename = sample.filename
            mask_filename = os.path.basename(base_filename).replace(".jpg", ".png")

            mask_path = os.path.join(self.mask_dir, mask_filename)
            logging.debug(f"Assigned mask {mask_path} to sample {base_filename}")

            if not os.path.exists(mask_path):
                logging.error(f"Mask file not found: {mask_path}")

            sample[self.field_gt_anomaly_mask] = fo.Segmentation(mask_path=mask_path)

            # Prefix filenames with their parent directory to avoid collisions
            dir_name = os.path.dirname(sample.filepath).split("/")[-1]
            new_filename = f"{dir_name}_{base_filename}"
            if not os.path.exists(os.path.join(self.abnormal_dir, new_filename)):
                os.symlink(
                    sample.filepath, os.path.join(self.abnormal_dir, new_filename)
                )

            if not os.path.exists(os.path.join(self.mask_dir, new_filename)):
                os.symlink(
                    sample[self.field_gt_anomaly_mask].mask_path,
                    os.path.join(self.mask_dir, new_filename),
                )

        logging.info(f"{len(self.normal_data)} normal images in train split.")
        self.datamodule = Folder(
            name=self.dataset_name,
            normal_dir=self.normal_dir,
            abnormal_dir=self.abnormal_dir,
            mask_dir=self.mask_dir,
            task=self.TASK,
            transform=transform,
            train_batch_size=self.batch_size,
            eval_batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            seed=GLOBAL_SEED,
        )

        self.datamodule.setup()

    def unlink_symlinks(self):
        """Removes symbolic links for abnormal samples and masks."""
        for sample in self.abnormal_data.iter_samples(progress=True):
            base_filename = sample.filename
            dir_name = os.path.dirname(sample.filepath).split("/")[-1]
            new_filename = f"{dir_name}_{base_filename}"

            try:
                os.unlink(os.path.join(self.abnormal_dir, new_filename))
            except Exception as e:
                logging.debug(
                    f"Unlinking of {os.path.join(self.abnormal_dir, new_filename)} failed: {e}"
                )

            try:
                os.unlink(os.path.join(self.mask_dir, new_filename))
            except Exception as e:
                logging.debug(
                    f"Unlinking of {os.path.join(self.mask_dir, new_filename)} failed: {e}"
                )

    def train_and_export_model(self):
        """Train an anomaly detection model if not already trained and export it, optionally uploading to Hugging Face."""

        MAX_EPOCHS = self.config["epochs"]
        PATIENCE = self.config["early_stop_patience"]

        # Set folders
        data_root = os.path.abspath(self.config["data_root"])
        dataset_folder_ano_dec_masks = f"{self.dataset_name}_anomaly_detection_masks/"
        filepath_masks = os.path.join(data_root, dataset_folder_ano_dec_masks)

        filepath_train = self.normal_data.take(1).first().filepath
        filepath_val = self.abnormal_data.take(1).first().filepath

        self.normal_dir = os.path.dirname(filepath_train)
        self.abnormal_dir = os.path.dirname(filepath_val)
        # filepath_masks ends with "/", so dirname() returns the masks directory itself
        self.mask_dir = os.path.dirname(filepath_masks)

        # Resize images if a size is defined in the config
        if self.image_size is not None:
            transform = Compose([Resize(self.image_size, antialias=True)])
        else:
            transform = None

        self.create_datamodule(transform=transform)
        if not os.path.exists(self.model_path):
            self.model = getattr(anomalib.models, self.model_name)()

            os.makedirs(self.anomalib_output_root, exist_ok=True)
            os.makedirs(self.tensorboard_output, exist_ok=True)
            self.unlink_symlinks()
            self.anomalib_logger = AnomalibTensorBoardLogger(
                save_dir=self.tensorboard_output,
            )

            # Callbacks
            callbacks = [
                ModelCheckpoint(
                    mode="max",
                    monitor="pixel_AUROC",
                    save_last=True,
                    verbose=True,
                    auto_insert_metric_name=True,
                    every_n_epochs=1,
                ),
                EarlyStopping(monitor="pixel_AUROC", mode="max", patience=PATIENCE),
            ]
            self.engine = Engine(
                task=self.TASK,
                default_root_dir=self.anomalib_output_root,
                logger=self.anomalib_logger,
                max_epochs=MAX_EPOCHS,
                callbacks=callbacks,
                # image_metrics=self.eval_metrics,  # classification metrics for the whole image
                pixel_metrics=self.eval_metrics,
                accelerator="auto",
            )
            self.engine.fit(model=self.model, datamodule=self.datamodule)

            # Export and generate inferencer
            export_root = self.model_path.replace("weights/torch/model.pt", "")
            self.engine.export(
                model=self.model,
                export_root=export_root,
                export_type=ExportType.TORCH,
                ckpt_path=self.engine.trainer.checkpoint_callback.best_model_path,
            )

            # Upload model to Hugging Face
            if HF_DO_UPLOAD:
                logging.info(f"Uploading model to Hugging Face: {self.hf_repo_name}")
                api = HfApi()
                api.create_repo(
                    self.hf_repo_name, private=True, repo_type="model", exist_ok=True
                )
                api.upload_file(
                    path_or_fileobj=self.model_path,
                    path_in_repo="model.pt",
                    repo_id=self.hf_repo_name,
                    repo_type="model",
                )

        else:
            logging.warning(
                f"Skipping model {self.model_name}, training results are already in {self.model_path}."
            )

    def validate_model(self):
        """Test the anomaly detection model using the designated testing dataset and log the performance results."""
        if self.engine:
            test_results = self.engine.test(
                model=self.model,
                datamodule=self.datamodule,
                ckpt_path=self.engine.trainer.checkpoint_callback.best_model_path,
            )
            logging.info(f"Model test results: {test_results}")
        else:
            logging.error(f"Engine '{self.engine}' not available.")

    def run_inference(self, mode):
        """Run anomaly detection inference on either the val/test split ('train' mode) or the full dataset ('inference' mode)."""
        logging.info("Running inference")
        try:
            if os.path.exists(self.model_path):
                file_path = self.model_path
                logging.info(f"Loading model {self.model_name} from disk: {file_path}")
            else:
                download_dir = self.model_path.replace("model.pt", "")
                logging.info(
                    f"Downloading model {self.hf_repo_name} from Hugging Face to {download_dir}"
                )
                file_path = hf_hub_download(
                    repo_id=self.hf_repo_name,
                    filename="model.pt",
                    local_dir=download_dir,
                )
        except Exception as e:
            logging.error(f"Failed to load or download model: {str(e)}.")
            return False

        device = "cuda" if torch.cuda.is_available() else "cpu"
        inferencer = TorchInferencer(path=file_path, device=device)
        self.inferencer = inferencer

        if mode == "train":
            dataset = self.abnormal_data
            logging.info(f"{len(self.abnormal_data)} images in evaluation split.")
        elif mode == "inference":
            dataset = self.dataset
        else:
            logging.error(f"Mode {mode} is not supported during inference.")
            return False

        field_pred_anomaly_score = f"pred_anomaly_score_{self.model_name}"
        field_pred_anomaly_map = f"pred_anomaly_map_{self.model_name}"
        field_pred_anomaly_mask = f"pred_anomaly_mask_{self.model_name}"

        for sample in dataset.iter_samples(autosave=True, progress=True):
            image = read_image(sample.filepath, as_tensor=True)
            output = self.inferencer.predict(image)

            # Store results in the Voxel51 dataset
            sample[field_pred_anomaly_score] = output.pred_score
            sample[field_pred_anomaly_map] = fo.Heatmap(map=output.anomaly_map)
            sample[field_pred_anomaly_mask] = fo.Segmentation(mask=output.pred_mask)

    def eval_v51(self):
        """Evaluates segmentation performance of the anomaly detection model on the abnormal dataset."""

        eval_seg = self.abnormal_data.evaluate_segmentations(
            f"pred_anomaly_mask_{self.model_name}",
            gt_field=self.field_gt_anomaly_mask,
            eval_key=f"eval_seg_{self.model_name}",
        )
        eval_seg.print_report(classes=[0, 255])
class Anodec:
Anomaly detection model class for managing training, inference, and evaluation of anomaly detection models using Anomalib.

Relevant links:
- https://docs.voxel51.com/tutorials/anomaly_detection.html
- https://medium.com/@enrico.randellini/anomalib-a-library-for-image-anomaly-detection-and-localization-fb363639104f
- https://github.com/openvinotoolkit/anomalib
- https://anomalib.readthedocs.io/en/stable/
Anodec(dataset, eval_metrics, dataset_info, config, tensorboard_output, anomalib_output_root='./output/models/anomalib/')

Initialize the anomaly detection module with dataset, evaluation metrics, config, and output paths.

config
dataset
eval_metrics
normal_data
abnormal_data
dataset_name
TASK
model_name
image_size
batch_size
tensorboard_output
anomalib_output_root
model_path
field_gt_anomaly_mask
hf_repo_name
inferencer
engine
datamodule
anomalib_logger
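
A minimal usage sketch under stated assumptions: the config keys shown are exactly the ones `__init__` and `train_and_export_model` read, while the dataset name, metric choice, and output paths are hypothetical; "Padim" stands in for any model class name exposed by `anomalib.models`.

import fiftyone as fo

# Hypothetical config; the keys are those read by Anodec.
config = {
    "model_name": "Padim",        # any model class from anomalib.models
    "image_size": [256, 256],     # or None to skip resizing
    "batch_size": 32,
    "epochs": 10,
    "early_stop_patience": 3,
    "data_root": "./datasets",
}

# The dataset must carry "train" and "val"/"test" sample tags.
dataset = fo.load_dataset("my_dataset")

anodec = Anodec(
    dataset=dataset,
    eval_metrics=["AUROC", "F1Score"],
    dataset_info={"name": "my_dataset"},
    config=config,
    tensorboard_output="./output/tensorboard",
)
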
def create_datamodule(self, transform):

Create the datamodule for anomaly detection by preparing and symlinking images/masks for the Anomalib datamodule.
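
For orientation, a standalone sketch of the `Folder` datamodule this method assembles; the directory paths here are hypothetical placeholders for the FiftyOne media directories and the symlinked masks that Anodec derives from the dataset's filepaths.

from anomalib import TaskType
from anomalib.data.image.folder import Folder

# Hypothetical paths; in Anodec these come from the sample filepaths.
datamodule = Folder(
    name="my_dataset",
    normal_dir="/data/my_dataset/train",
    abnormal_dir="/data/my_dataset/val",
    mask_dir="/data/my_dataset_anomaly_detection_masks",
    task=TaskType.SEGMENTATION,
    train_batch_size=32,
    eval_batch_size=32,
)
datamodule.setup()  # builds the train/val/test splits from the folders
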

def train_and_export_model(self):

Train an anomaly detection model if not already trained and export it, optionally uploading to Hugging Face.
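
A call sketch; the output locations follow from `model_path` in `__init__` and from the Hugging Face upload block, with angle-bracket parts as placeholders.

anodec.train_and_export_model()
# On a fresh run this trains, then exports the Torch model to:
#   <anomalib_output_root>/<model_name>/<dataset_name>/weights/torch/model.pt
# With HF_DO_UPLOAD set, the file is also pushed as "model.pt" to the
# private repo <HF_ROOT>/<dataset_name>_anomalib_<model_name>.
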

def validate_model(self):

Test the anomaly detection model using the designated testing dataset and log the performance results.
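
Since `self.engine` is only created inside `train_and_export_model`, validation presupposes a training run in the same session; a sketch of the expected call order:

anodec.train_and_export_model()  # sets self.engine and self.model
anodec.validate_model()          # runs engine.test() on the best checkpoint
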

def run_inference(self, mode):

Run anomaly detection inference on either the val/test split ('train' mode) or the full dataset ('inference' mode).
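
After inference, each sample carries model-specific prediction fields following the naming pattern in the source; a minimal read-back sketch (dataset and config names hypothetical):

anodec.run_inference(mode="train")  # predict on the val/test split

model_name = config["model_name"]
for sample in anodec.abnormal_data.take(5):
    score = sample[f"pred_anomaly_score_{model_name}"]   # float anomaly score
    heatmap = sample[f"pred_anomaly_map_{model_name}"]   # fo.Heatmap
    mask = sample[f"pred_anomaly_mask_{model_name}"]     # fo.Segmentation
    print(sample.filepath, score)
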

def eval_v51(self):

Evaluates segmentation performance of the anomaly detection model on the abnormal dataset.