tests.workflow_auto_label_masks_test
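
Integration test for the workflow_auto_label_mask entry point in main.py. It loads a single
sample of the Voxel51/fisheye8k dataset from the Hugging Face Hub and verifies that the
workflow adds the expected prediction fields for three configurations: SAM2 without a prompt,
SAM2 prompted by the ground-truth "detections" field, and DPT depth estimation. Run with,
e.g., pytest tests/workflow_auto_label_masks_test.py -s (the file path is assumed from the
module name above).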

import fiftyone as fo
import pytest
from fiftyone.utils.huggingface import load_from_hub

import config.config
from main import workflow_auto_label_mask
from utils.logging import configure_logging

@pytest.fixture(scope="session")
def dataset_v51():
    """Loads a one-sample test split of the fisheye8k dataset.

    Reuses the locally cached dataset if it exists; otherwise downloads
    a single sample from the Hugging Face Hub.
    """
    dataset_name_hub = "Voxel51/fisheye8k"
    dataset_name = "fisheye8k_mask_test"
    try:
        dataset = fo.load_dataset(dataset_name)
    except ValueError:
        dataset = load_from_hub(
            repo_id=dataset_name_hub,
            max_samples=1,
            name=dataset_name,
        )
    return dataset

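Note that the fixture is session-scoped, so the same dataset instance is shared across all
parametrized cases below; prediction fields added by one case persist on the dataset for
subsequent cases.
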
@pytest.fixture(autouse=True)
def setup_logging():
    """Configures logging for every test in this module."""
    configure_logging()

@pytest.fixture(autouse=True)
def deactivate_wandb_sync():
    """Disables Weights & Biases syncing during tests."""
    config.config.WANDB_ACTIVE = False

@pytest.mark.parametrize("workflow_config", [
    {   # Test 1: SAM2 without a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": None,
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {   # Test 2: SAM2 with a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": "detections",
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {   # Test 3: Depth estimation only
        "semantic_segmentation": {},
        "depth_estimation": {
            "dpt": {
                # A list, for consistency with the SAM2 configs above
                "models": ["Intel/dpt-swinv2-tiny-256"],
            }
        },
    },
])
def test_auto_label_mask(dataset_v51, workflow_config):
    dataset_info = {"name": "fisheye8k_mask_test"}

    print("\n[TEST] Starting workflow_auto_label_mask integration test")
    print(f"[TEST] Dataset name: {dataset_v51.name}")
    print(f"[TEST] Number of samples: {len(dataset_v51)}")

    # Run the workflow only if the config requests at least one task
    if workflow_config["semantic_segmentation"] or workflow_config["depth_estimation"]:
        workflow_auto_label_mask(dataset_v51, dataset_info, workflow_config)
        print("[TEST] workflow_auto_label_mask completed successfully!")
    else:
        print("[TEST] Skipping workflow, running assertion checks on unmodified dataset.")

    assert len(dataset_v51) > 0, "The dataset should not be empty after processing"

    if workflow_config["semantic_segmentation"]:
        prompt_field = workflow_config["semantic_segmentation"]["sam2"]["prompt_field"]
    else:
        prompt_field = None

    # Expected output field names (model names with non-alphanumeric
    # characters replaced by underscores)
    expected_depth_field = "pred_de_Intel_dpt_swinv2_tiny_256"
    if prompt_field is None:
        field_masks = "pred_ss_segment_anything_2_1_hiera_tiny_image_torch_noprompt"
    else:
        field_masks = f"pred_ss_segment_anything_2_1_hiera_tiny_image_torch_prompt_{prompt_field}"

    field_gt = "detections"

    classes_gt = set()
    classes_sam = set()
    n_found_fields = 0

    for sample in dataset_v51:
        print(f"Fields in sample: {sample.field_names}")

        if workflow_config["semantic_segmentation"]:
            bboxes_gt = sample[field_gt]
            try:
                sam_masks = sample[field_masks]
                if prompt_field is not None:
                    # With a prompt, the SAM masks inherit the prompt labels,
                    # so their label set should match the ground truth
                    for bbox in bboxes_gt.detections:
                        classes_gt.add(bbox.label)
                    for mask in sam_masks.detections:
                        classes_sam.add(mask.label)
                n_found_fields += 1
            except KeyError:
                print(f"Field {field_masks} not found in sample {sample.id}")

        if workflow_config["depth_estimation"]:
            try:
                _ = sample[expected_depth_field]
                n_found_fields += 1
            except KeyError:
                print(f"Field {expected_depth_field} not found in sample {sample.id}")

    assert classes_gt == classes_sam, (
        f"Ground-truth classes {classes_gt} do not match SAM mask classes {classes_sam}"
    )
    print(f"[TEST] Found {n_found_fields} new fields in the dataset")
    assert n_found_fields > 0, "No new fields were added to the dataset"

    print("Class comparison:")
    print(f"\tGround truth classes: {classes_gt}")
    print(f"\tSAM mask classes: {classes_sam}")

    print("[TEST] Verified that new fields are present in the dataset.")
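
For reference, the assertions above assume an output-field naming convention of the form
pred_<task>_<model>, with a prompt suffix for segmentation fields. A minimal sketch of that
convention, assuming model names are sanitized by replacing non-alphanumeric characters with
underscores (this helper is illustrative only and not part of main.py):

import re

def expected_field(task_code, model_name, prompt_field=None):
    # Sanitize the model name, e.g. "Intel/dpt-swinv2-tiny-256"
    # becomes "Intel_dpt_swinv2_tiny_256"
    sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", model_name)
    if task_code == "ss":  # semantic segmentation carries a prompt suffix
        suffix = f"prompt_{prompt_field}" if prompt_field else "noprompt"
        return f"pred_{task_code}_{sanitized}_{suffix}"
    return f"pred_{task_code}_{sanitized}"  # e.g. "de" for depth estimation

expected_field("de", "Intel/dpt-swinv2-tiny-256")
# -> "pred_de_Intel_dpt_swinv2_tiny_256"
expected_field("ss", "segment-anything-2.1-hiera-tiny-image-torch", "detections")
# -> "pred_ss_segment_anything_2_1_hiera_tiny_image_torch_prompt_detections"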