tests.workflow_auto_label_masks_test
import fiftyone as fo
import pytest
from fiftyone.utils.huggingface import load_from_hub

import config.config
from main import workflow_auto_label_mask
from utils.logging import configure_logging


@pytest.fixture(scope="session")
def dataset_v51():
    """Return the shared test dataset, creating it from the Hub if needed.

    Reuses the FiftyOne dataset ``fisheye8k_mask_test`` when ``fo.load_dataset``
    finds it. Otherwise downloads a single sample of ``Voxel51/fisheye8k``
    under that name.
    """
    dataset_name_hub = "Voxel51/fisheye8k"
    dataset_name = "fisheye8k_mask_test"
    try:
        dataset = fo.load_dataset(dataset_name)
    except ValueError:
        dataset = load_from_hub(
            repo_id=dataset_name_hub,
            max_samples=1,
            name=dataset_name,
        )
    return dataset


@pytest.fixture(autouse=True)
def setup_logging():
    """Configure project logging before every test."""
    configure_logging()


@pytest.fixture(autouse=True)
def deactivate_wandb_sync(monkeypatch):
    """Set ``config.config.WANDB_ACTIVE`` to False for the duration of each test.

    ``monkeypatch`` restores the previous value after the test. This keeps the
    change from leaking into other test modules.
    """
    monkeypatch.setattr(config.config, "WANDB_ACTIVE", False)


def _get_field_or_none(sample, field_name):
    """Return ``sample[field_name]``, or ``None`` if the field is missing.

    A field that exists but holds ``None`` is also reported as ``None``. Either
    way, callers treat it as "no prediction written".
    """
    try:
        return sample[field_name]
    except KeyError:
        return None


@pytest.mark.parametrize("workflow_config", [
    {  # Test 1: SAM2 without a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": None,
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {  # Test 2: SAM2 with a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": "detections",
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {  # Test 3: Depth Estimation
        "semantic_segmentation": {},
        "depth_estimation": {
            "dpt": {
                # A list (not a set literal), consistent with the SAM2 configs.
                "models": ["Intel/dpt-swinv2-tiny-256"],
            }
        },
    },
])
def test_auto_label_mask(dataset_v51, workflow_config):
    """Run ``workflow_auto_label_mask`` for one config and check its output fields.

    For every sample of ``dataset_v51``:

    * if SAM2 segmentation is configured, look for the SAM2 prediction field
      (``..._noprompt`` or ``..._prompt_<prompt_field>``);
    * if DPT depth estimation is configured, look for the depth prediction field;
    * for prompted SAM2 runs, collect the ground-truth ``detections`` labels and
      the predicted mask labels. The two label sets must be equal.

    At least one expected field must be found across the dataset.
    """
    dataset_info = {"name": "fisheye8k_mask_test"}
    seg_config = workflow_config["semantic_segmentation"]
    depth_config = workflow_config["depth_estimation"]

    print("\n[TEST] Starting workflow_auto_label_mask integration test")
    print(f"[TEST] Dataset Name: {dataset_v51.name}")
    print(f"[TEST] Number of samples: {len(dataset_v51)}")

    # Run if valid config
    if seg_config or depth_config:
        workflow_auto_label_mask(dataset_v51, dataset_info, workflow_config)
        print("[TEST] workflow_auto_label_mask completed successfully!")
    else:
        print("[TEST] Skipping workflow, running assertion checks on unmodified dataset.")

    assert len(dataset_v51) > 0, "The dataset should not be empty after processing"

    prompt_field = seg_config["sam2"]["prompt_field"] if seg_config else None

    # Field names the workflow is expected to write for the configured models.
    expected_depth_field = "pred_de_Intel_dpt_swinv2_tiny_256"
    sam_field_stem = "pred_ss_segment_anything_2_1_hiera_tiny_image_torch"
    if prompt_field is None:
        field_masks = f"{sam_field_stem}_noprompt"
    else:
        field_masks = f"{sam_field_stem}_prompt_{prompt_field}"
    field_gt = "detections"

    classes_gt = set()
    classes_sam = set()
    n_found_fields = 0

    for sample in dataset_v51:
        print(f"Fields in sample: {sample.field_names}")

        if seg_config:
            sam_masks = _get_field_or_none(sample, field_masks)
            if sam_masks is None:
                # Left uncounted, so the final n_found_fields assertion reports it.
                print(f"Field {field_masks} not found in sample {sample}")
            else:
                n_found_fields += 1
                if prompt_field is not None:
                    # Prompted masks should carry the labels of the boxes used as prompts.
                    bboxes_gt = _get_field_or_none(sample, field_gt)
                    if bboxes_gt is not None:
                        classes_gt.update(bbox.label for bbox in bboxes_gt.detections)
                    classes_sam.update(mask.label for mask in sam_masks.detections)

        if depth_config:
            if _get_field_or_none(sample, expected_depth_field) is None:
                print(f"Field {expected_depth_field} not found in sample {sample}")
            else:
                n_found_fields += 1

    # Both sets stay empty (and trivially equal) unless a prompted SAM2 run was checked.
    assert classes_gt == classes_sam, f"Classes in Ground Truth {classes_gt} and SAM Masks {classes_sam}"
    print(f"[TEST] Found {n_found_fields} new fields in the dataset")
    assert n_found_fields > 0, "No new fields were added to the dataset"

    print("Class Distribution in Ground Truth Bounding Boxes:")
    print(f"\tGround Truth Classes: {classes_gt}")
    print(f"\tSAM Masks: {classes_sam}")

    print("[TEST] Verified that new fields are present in the dataset.")
@pytest.fixture(scope="session")
def dataset_v51():
    """Session-wide FiftyOne dataset used by the auto-labelling tests.

    Returns the existing ``fisheye8k_mask_test`` dataset when ``fo.load_dataset``
    can find it. On ``ValueError``, pulls one sample of ``Voxel51/fisheye8k``
    from the Hugging Face Hub and stores it under that name.
    """
    local_name = "fisheye8k_mask_test"
    hub_repo = "Voxel51/fisheye8k"
    try:
        return fo.load_dataset(local_name)
    except ValueError:
        # Not loadable locally: download only a single sample.
        return load_from_hub(
            repo_id=hub_repo,
            max_samples=1,
            name=local_name,
        )
@pytest.fixture(autouse=True)
def setup_logging():
    """Configure project logging before every test (autouse fixture)."""
    configure_logging()
@pytest.fixture(autouse=True)
def deactivate_wandb_sync(monkeypatch):
    """Set ``config.config.WANDB_ACTIVE`` to False for the duration of each test.

    Uses ``monkeypatch`` so the previous value is restored after the test. A
    plain assignment would leave the flag changed for any later test module.
    """
    monkeypatch.setattr(config.config, "WANDB_ACTIVE", False)
def _get_field_or_none(sample, field_name):
    """Return ``sample[field_name]``, or ``None`` if the field is missing.

    A field that exists but holds ``None`` is also reported as ``None``. Either
    way, callers treat it as "no prediction written".
    """
    try:
        return sample[field_name]
    except KeyError:
        return None


@pytest.mark.parametrize("workflow_config", [
    {  # Test 1: SAM2 without a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": None,
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {  # Test 2: SAM2 with a prompt
        "semantic_segmentation": {
            "sam2": {
                "prompt_field": "detections",
                "models": ["segment-anything-2.1-hiera-tiny-image-torch"],
            }
        },
        "depth_estimation": {},
    },
    {  # Test 3: Depth Estimation
        "semantic_segmentation": {},
        "depth_estimation": {
            "dpt": {
                # A list (not a set literal), consistent with the SAM2 configs.
                "models": ["Intel/dpt-swinv2-tiny-256"],
            }
        },
    },
])
def test_auto_label_mask(dataset_v51, workflow_config):
    """Run ``workflow_auto_label_mask`` for one config and check its output fields.

    For every sample of ``dataset_v51``:

    * if SAM2 segmentation is configured, look for the SAM2 prediction field
      (``..._noprompt`` or ``..._prompt_<prompt_field>``);
    * if DPT depth estimation is configured, look for the depth prediction field;
    * for prompted SAM2 runs, collect the ground-truth ``detections`` labels and
      the predicted mask labels. The two label sets must be equal.

    At least one expected field must be found across the dataset.
    """
    dataset_info = {"name": "fisheye8k_mask_test"}
    seg_config = workflow_config["semantic_segmentation"]
    depth_config = workflow_config["depth_estimation"]

    print("\n[TEST] Starting workflow_auto_label_mask integration test")
    print(f"[TEST] Dataset Name: {dataset_v51.name}")
    print(f"[TEST] Number of samples: {len(dataset_v51)}")

    # Run if valid config
    if seg_config or depth_config:
        workflow_auto_label_mask(dataset_v51, dataset_info, workflow_config)
        print("[TEST] workflow_auto_label_mask completed successfully!")
    else:
        print("[TEST] Skipping workflow, running assertion checks on unmodified dataset.")

    assert len(dataset_v51) > 0, "The dataset should not be empty after processing"

    prompt_field = seg_config["sam2"]["prompt_field"] if seg_config else None

    # Field names the workflow is expected to write for the configured models.
    expected_depth_field = "pred_de_Intel_dpt_swinv2_tiny_256"
    sam_field_stem = "pred_ss_segment_anything_2_1_hiera_tiny_image_torch"
    if prompt_field is None:
        field_masks = f"{sam_field_stem}_noprompt"
    else:
        field_masks = f"{sam_field_stem}_prompt_{prompt_field}"
    field_gt = "detections"

    classes_gt = set()
    classes_sam = set()
    n_found_fields = 0

    for sample in dataset_v51:
        print(f"Fields in sample: {sample.field_names}")

        if seg_config:
            sam_masks = _get_field_or_none(sample, field_masks)
            if sam_masks is None:
                # Left uncounted, so the final n_found_fields assertion reports it.
                print(f"Field {field_masks} not found in sample {sample}")
            else:
                n_found_fields += 1
                if prompt_field is not None:
                    # Prompted masks should carry the labels of the boxes used as prompts.
                    bboxes_gt = _get_field_or_none(sample, field_gt)
                    if bboxes_gt is not None:
                        classes_gt.update(bbox.label for bbox in bboxes_gt.detections)
                    classes_sam.update(mask.label for mask in sam_masks.detections)

        if depth_config:
            if _get_field_or_none(sample, expected_depth_field) is None:
                print(f"Field {expected_depth_field} not found in sample {sample}")
            else:
                n_found_fields += 1

    # Both sets stay empty (and trivially equal) unless a prompted SAM2 run was checked.
    assert classes_gt == classes_sam, f"Classes in Ground Truth {classes_gt} and SAM Masks {classes_sam}"
    print(f"[TEST] Found {n_found_fields} new fields in the dataset")
    assert n_found_fields > 0, "No new fields were added to the dataset"

    print("Class Distribution in Ground Truth Bounding Boxes:")
    print(f"\tGround Truth Classes: {classes_gt}")
    print(f"\tSAM Masks: {classes_sam}")

    print("[TEST] Verified that new fields are present in the dataset.")