"""utils.custom_view

Custom view builders for FiftyOne datasets.

Every function in this module MUST expect a dataset and return a view.
"""

import logging

import fiftyone.utils.random as four
from fiftyone import ViewField as F

from config.config import GLOBAL_SEED


def keep_val_no_crowd_no_added(dataset, include_added=False):
    """Tag as 'train' every sample not reserved for eval, keeping the val set.

    Samples tagged 'val' or 'crowd' are excluded from the train pool;
    samples tagged 'added' are excluded as well unless ``include_added``.

    Args:
        dataset: FiftyOne dataset to process
        include_added: if True, samples tagged 'added' may also become 'train'

    Returns:
        View over the dataset containing all 'train' and 'val' samples.
    """
    # Candidate train pool: everything not reserved for eval ('val'/'crowd')
    view_pot_train = dataset.match_tags(["val", "crowd"], bool=False)
    view_pot_train.untag_samples(["train"])

    if not include_added:
        view_train = view_pot_train.match_tags(["added"], bool=False)
    else:
        view_train = view_pot_train
    view_train.tag_samples(["train"])

    # BUG FIX: match on the full dataset so the fixed 'val' samples are
    # included in the returned view (view_train had already excluded them);
    # consistent with subset_splits() below.
    view_train_val = dataset.match_tags(["train", "val"])
    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(view_train_val)}"
    )
    return view_train_val


def subset_splits(
    dataset,
    n_iteration,
    target_train_perc=0.8,
    target_test_perc=0,
    target_val_perc=0.2,
):
    """Select a growing train subset for iterative training runs.

    Iteration ``n`` tags ``(n + 1) * 10%`` of the non-reserved samples as
    'train' (capped at 100%); the 'val' set stays fixed.

    Args:
        dataset: FiftyOne dataset to process
        n_iteration: zero-based iteration index controlling the subset size
        target_train_perc: unused; kept for interface compatibility
        target_test_perc: unused; kept for interface compatibility
        target_val_perc: unused; kept for interface compatibility

    Returns:
        View over the dataset containing all 'train' and 'val' samples.
    """
    dataset_perct_target = (n_iteration + 1) / 10
    # BUG FIX: value is a fraction — format with '%' instead of appending a
    # literal percent sign (previously logged e.g. "0.1%" for 10%)
    logging.info(f"Trying to select {dataset_perct_target:.0%} of the dataset.")
    dataset_perct = min(dataset_perct_target, 1)

    # All samples that are not reserved for eval
    view_pot_train = dataset.match_tags(["val", "crowd"], bool=False)
    view_pot_train.untag_samples(["train"])  # keep fixed val set (two cameras)

    n_samples_total = len(view_pot_train)
    n_samples_target = int(n_samples_total * dataset_perct)

    # Take percentage for train (deterministic via fixed global seed)
    view_train = view_pot_train.take(n_samples_target, seed=GLOBAL_SEED)
    view_train.tag_samples(["train"])

    view_train_val = dataset.match_tags(["train", "val"])
    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(view_train_val)}"
    )
    logging.info(f"Split Distribution: {view_train_val.count_sample_tags()}")

    return view_train_val


def max_detections(
    dataset,
    target_train_perc=0.6,
    target_test_perc=0.2,
    target_val_perc=0.2,
    max_detections=7,
    gt_field="ground_truth",
):
    """Keep only frames with at most ``max_detections`` ground-truth boxes
    and draw fresh random train/test/val splits over them.

    Args:
        dataset: FiftyOne dataset to process
        target_train_perc: fraction of samples tagged 'train'
        target_test_perc: fraction of samples tagged 'test'
        target_val_perc: fraction of samples tagged 'val'
        max_detections: maximum number of detections per frame to keep
        gt_field: name of the ground-truth detections field

    Returns:
        View over the dataset restricted to low-detection-count frames.
    """
    # Only consider frames with a max. number of detections
    num_objects = F(f"{gt_field}.detections").length()
    max_detections_view = dataset.match(num_objects <= max_detections)

    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(max_detections_view)}"
    )

    # Generate new splits (discard any pre-existing split tags first)
    max_detections_view.untag_samples(["train", "test", "val"])
    four.random_split(
        max_detections_view,
        {
            "train": target_train_perc,
            "test": target_test_perc,
            "val": target_val_perc,
        },
    )

    return max_detections_view


def vru_mcity_fisheye(
    dataset,
    target_train_perc=0.7,
    target_test_perc=0.2,
    target_val_perc=0.1,
    create_splits=False,
    random_split=True,
):
    """Restrict the dataset to VRU labels and optionally regenerate splits.

    Only 'pedestrian' and 'motorbike/cycler' ground-truth labels are kept.
    With ``create_splits=True``, splits are either drawn fully at random
    (``random_split=True``) or built deterministically around the existing
    'train' tags (``random_split=False``).

    Args:
        dataset: FiftyOne dataset to process
        target_train_perc: target fraction of samples tagged 'train'
        target_test_perc: target fraction of samples tagged 'test'
        target_val_perc: target fraction of samples tagged 'val'
        create_splits: if True, (re)generate train/test/val tags
        random_split: if True, use fully random splits

    Returns:
        View with only VRU labels, or the unmodified dataset when the
        requested split sizes cannot be satisfied.
    """
    # Only select labels of VRU classes
    vru_view = dataset.filter_labels(
        "ground_truth",
        (F("label") == "pedestrian") | (F("label") == "motorbike/cycler"),
    )

    n_samples = len(vru_view)
    logging.info(f"Reduced number of samples from {len(dataset)} to {n_samples}")
    logging.info(f"Original split distribution: {vru_view.count_sample_tags()}")

    if create_splits and random_split:
        vru_view.untag_samples(["train", "test", "val"])
        four.random_split(
            vru_view,
            {
                "train": target_train_perc,
                "test": target_test_perc,
                "val": target_val_perc,
            },
        )
    elif create_splits:
        # Get target number of samples per split
        target_n_train = int(n_samples * target_train_perc)
        target_n_test = int(n_samples * target_test_perc)
        target_n_val = int(n_samples * target_val_perc)

        # Get current number of samples for 'train' split and rest
        train_view = vru_view.match_tags(["train"])
        test_val_view = vru_view.match_tags(["train"], bool=False)
        n_samples_train = len(train_view)

        if target_n_test + target_n_val > len(test_val_view):
            # BUG FIX: log the available sample count, not the view's repr
            logging.error(
                f"Target test/val count of {target_n_test + target_n_val} exceeds the number of available samples {len(test_val_view)}."
            )
            return dataset

        test_val_view.untag_samples(["test", "val"])
        needed_train_from_test_val = target_n_train - n_samples_train
        if needed_train_from_test_val < 0:
            logging.error(
                f"Already {n_samples_train} samples labeled 'train', but requested {target_n_train} samples for 'train' split"
            )
            return dataset

        # Top up 'train' from the untagged pool, deterministically
        to_be_train = test_val_view.take(needed_train_from_test_val, seed=GLOBAL_SEED)
        to_be_train.tag_samples("train")

        # Draw 'test' from what is still untagged
        test_val_view = test_val_view.match_tags("train", bool=False)
        to_be_test = test_val_view.take(target_n_test, seed=GLOBAL_SEED)
        to_be_test.tag_samples("test")

        # The remainder pool provides 'val'
        test_val_view = test_val_view.match_tags(["train", "test"], bool=False)
        to_be_val = test_val_view.take(target_n_val, seed=GLOBAL_SEED)
        to_be_val.tag_samples("val")

    logging.info(f"New split distribution: {vru_view.count_sample_tags()}")
    return vru_view
def keep_val_no_crowd_no_added(dataset, include_added=False):
    """Tag as 'train' every sample not reserved for eval, keeping the val set.

    Samples tagged 'val' or 'crowd' are excluded from the train pool;
    samples tagged 'added' are excluded as well unless ``include_added``.

    Args:
        dataset: FiftyOne dataset to process
        include_added: if True, samples tagged 'added' may also become 'train'

    Returns:
        View over the dataset containing all 'train' and 'val' samples.
    """
    # Candidate train pool: everything not reserved for eval ('val'/'crowd')
    view_pot_train = dataset.match_tags(["val", "crowd"], bool=False)
    view_pot_train.untag_samples(["train"])

    if not include_added:
        view_train = view_pot_train.match_tags(["added"], bool=False)
    else:
        view_train = view_pot_train
    view_train.tag_samples(["train"])

    # BUG FIX: match on the full dataset so the fixed 'val' samples are
    # included in the returned view (view_train had already excluded them);
    # consistent with subset_splits()
    view_train_val = dataset.match_tags(["train", "val"])
    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(view_train_val)}"
    )
    return view_train_val
def subset_splits(
    dataset,
    n_iteration,
    target_train_perc=0.8,
    target_test_perc=0,
    target_val_perc=0.2,
):
    """Select a growing train subset for iterative training runs.

    Iteration ``n`` tags ``(n + 1) * 10%`` of the non-reserved samples as
    'train' (capped at 100%); the 'val' set stays fixed.

    Args:
        dataset: FiftyOne dataset to process
        n_iteration: zero-based iteration index controlling the subset size
        target_train_perc: unused; kept for interface compatibility
        target_test_perc: unused; kept for interface compatibility
        target_val_perc: unused; kept for interface compatibility

    Returns:
        View over the dataset containing all 'train' and 'val' samples.
    """
    dataset_perct_target = (n_iteration + 1) / 10
    # BUG FIX: value is a fraction — format with '%' instead of appending a
    # literal percent sign (previously logged e.g. "0.1%" for 10%)
    logging.info(f"Trying to select {dataset_perct_target:.0%} of the dataset.")
    dataset_perct = min(dataset_perct_target, 1)

    # All samples that are not reserved for eval
    view_pot_train = dataset.match_tags(["val", "crowd"], bool=False)
    view_pot_train.untag_samples(["train"])  # keep fixed val set (two cameras)

    n_samples_total = len(view_pot_train)
    n_samples_target = int(n_samples_total * dataset_perct)

    # Take percentage for train (deterministic via fixed global seed)
    view_train = view_pot_train.take(n_samples_target, seed=GLOBAL_SEED)
    view_train.tag_samples(["train"])

    view_train_val = dataset.match_tags(["train", "val"])
    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(view_train_val)}"
    )
    logging.info(f"Split Distribution: {view_train_val.count_sample_tags()}")

    return view_train_val
def max_detections(
    dataset,
    target_train_perc=0.6,
    target_test_perc=0.2,
    target_val_perc=0.2,
    max_detections=7,
    gt_field="ground_truth",
):
    """Keep only frames with at most ``max_detections`` ground-truth boxes
    and draw fresh random train/test/val splits over them.

    Args:
        dataset: FiftyOne dataset to process
        target_train_perc: fraction of samples tagged 'train'
        target_test_perc: fraction of samples tagged 'test'
        target_val_perc: fraction of samples tagged 'val'
        max_detections: maximum number of detections per frame to keep
        gt_field: name of the ground-truth detections field

    Returns:
        View over the dataset restricted to low-detection-count frames.
    """
    # Keep only frames whose detection count does not exceed the cap
    detection_count = F(f"{gt_field}.detections").length()
    max_detections_view = dataset.match(detection_count <= max_detections)

    logging.info(
        f"Reduced number of samples from {len(dataset)} to {len(max_detections_view)}"
    )

    # Drop any pre-existing split tags, then draw fresh random splits
    max_detections_view.untag_samples(["train", "test", "val"])
    split_fractions = {
        "train": target_train_perc,
        "test": target_test_perc,
        "val": target_val_perc,
    }
    four.random_split(max_detections_view, split_fractions)

    return max_detections_view
def vru_mcity_fisheye(
    dataset,
    target_train_perc=0.7,
    target_test_perc=0.2,
    target_val_perc=0.1,
    create_splits=False,
    random_split=True,
):
    """Restrict the dataset to VRU labels and optionally regenerate splits.

    Only 'pedestrian' and 'motorbike/cycler' ground-truth labels are kept.
    With ``create_splits=True``, splits are either drawn fully at random
    (``random_split=True``) or built deterministically around the existing
    'train' tags (``random_split=False``).

    Args:
        dataset: FiftyOne dataset to process
        target_train_perc: target fraction of samples tagged 'train'
        target_test_perc: target fraction of samples tagged 'test'
        target_val_perc: target fraction of samples tagged 'val'
        create_splits: if True, (re)generate train/test/val tags
        random_split: if True, use fully random splits

    Returns:
        View with only VRU labels, or the unmodified dataset when the
        requested split sizes cannot be satisfied.
    """
    # Only select labels of VRU classes
    vru_view = dataset.filter_labels(
        "ground_truth",
        (F("label") == "pedestrian") | (F("label") == "motorbike/cycler"),
    )

    n_samples = len(vru_view)
    logging.info(f"Reduced number of samples from {len(dataset)} to {n_samples}")
    logging.info(f"Original split distribution: {vru_view.count_sample_tags()}")

    if create_splits and random_split:
        vru_view.untag_samples(["train", "test", "val"])
        four.random_split(
            vru_view,
            {
                "train": target_train_perc,
                "test": target_test_perc,
                "val": target_val_perc,
            },
        )
    elif create_splits:
        # Get target number of samples per split
        target_n_train = int(n_samples * target_train_perc)
        target_n_test = int(n_samples * target_test_perc)
        target_n_val = int(n_samples * target_val_perc)

        # Get current number of samples for 'train' split and rest
        train_view = vru_view.match_tags(["train"])
        test_val_view = vru_view.match_tags(["train"], bool=False)
        n_samples_train = len(train_view)

        if target_n_test + target_n_val > len(test_val_view):
            # BUG FIX: log the available sample count, not the view's repr
            logging.error(
                f"Target test/val count of {target_n_test + target_n_val} exceeds the number of available samples {len(test_val_view)}."
            )
            return dataset

        test_val_view.untag_samples(["test", "val"])
        needed_train_from_test_val = target_n_train - n_samples_train
        if needed_train_from_test_val < 0:
            logging.error(
                f"Already {n_samples_train} samples labeled 'train', but requested {target_n_train} samples for 'train' split"
            )
            return dataset

        # Top up 'train' from the untagged pool, deterministically
        to_be_train = test_val_view.take(needed_train_from_test_val, seed=GLOBAL_SEED)
        to_be_train.tag_samples("train")

        # Draw 'test' from what is still untagged
        test_val_view = test_val_view.match_tags("train", bool=False)
        to_be_test = test_val_view.take(target_n_test, seed=GLOBAL_SEED)
        to_be_test.tag_samples("test")

        # The remainder pool provides 'val'
        test_val_view = test_val_view.match_tags(["train", "test"], bool=False)
        to_be_val = test_val_view.take(target_n_val, seed=GLOBAL_SEED)
        to_be_val.tag_samples("val")

    logging.info(f"New split distribution: {vru_view.count_sample_tags()}")
    return vru_view