77
import csv
import hashlib
import random

import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader, Dataset, Subset

# Restore the project-local backbone import: main() still instantiates
# MobileNet, but a refactor dropped this line, causing a NameError at runtime.
from MobileNet import MobileNet
2520
# ----------------------------
# Constants
# ----------------------------
SAMPLING_RATE = 16000    # Hz; target sample rate passed to librosa.load
NUM_CLASSES = 10         # spoken digits 0-9
MAX_AUDIO_LENGTH = 16000 # samples; 1 second at 16 kHz — pad/truncate target
3327
# ----------------------------
# Audio Preprocessing
# ----------------------------
def normalize_audio(x):
    """Scale the waveform so its peak absolute amplitude is 1.0.

    A silent (all-zero) input is returned unchanged to avoid division by zero.
    """
    peak = np.max(np.abs(x))
    if peak > 0:
        return x / peak
    return x
3934
def pad_audio(audio, max_len=MAX_AUDIO_LENGTH):
    """Force `audio` to exactly `max_len` samples: truncate if longer, zero-pad at the end if shorter."""
    if len(audio) > max_len:
        return audio[:max_len]
    return np.pad(audio, (0, max_len - len(audio)), 'constant')
4237
4338# ----------------------------
44- # Dataset
39+ # Dataset & Wrapper
4540# ----------------------------
46- class AudioMNISTDataset (Dataset ):
41+ class AudioMNISTBaseDataset (Dataset ):
42+ """Loads all audio into memory ONCE. Returns raw numpy arrays."""
4743 def __init__ (self , data_path ):
4844 self .data = []
4945 self .labels = []
5046
5147 wav_files = glob .glob (os .path .join (data_path , '*' , '*.wav' ))
52- # Deterministic shuffle using md5 hash of path
5348 wav_files = sorted (wav_files , key = lambda x : hashlib .md5 (x .encode ()).hexdigest ())
54- self .wav_files = wav_files .copy () # store for TSV
49+ self .wav_files = wav_files .copy ()
5550
5651 for audio_path in tqdm (wav_files , desc = "Loading audio files" ):
5752 audio , _ = librosa .load (audio_path , sr = SAMPLING_RATE )
@@ -65,12 +60,60 @@ def __len__(self):
6560 return len (self .data )
6661
6762 def __getitem__ (self , idx ):
68- audio = torch .tensor (self .data [idx ], dtype = torch .float32 ).unsqueeze (0 )
69- label = self .labels [idx ]
70- return audio , label
63+ return self .data [idx ], self .labels [idx ]
64+
class AudioSubsetWrapper(Dataset):
    """Wrap a Subset of raw (ndarray, label) samples: optionally augment, then emit float32 tensors."""

    def __init__(self, subset, augment=False):
        self.subset = subset    # underlying Subset yielding (waveform ndarray, label)
        self.augment = augment  # when True, perturb each sample on every access

    def __len__(self):
        return len(self.subset)

    def apply_augmentation(self, x):
        """Independently apply noise, circular time shift, and random gain, each with probability 0.5."""
        if random.random() < 0.5:
            noise = np.random.randn(len(x)) * 0.005
            x = np.clip(x + noise, -1.0, 1.0)
        if random.random() < 0.5:
            shift = np.random.randint(-200, 200)
            x = np.roll(x, shift)
        if random.random() < 0.5:
            gain = np.random.uniform(0.8, 1.2)
            x = np.clip(x * gain, -1.0, 1.0)
        return x

    def __getitem__(self, idx):
        waveform, label = self.subset[idx]
        if self.augment:
            waveform = self.apply_augmentation(waveform)
        tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)  # (1, length)
        return tensor, label
89+
def load_data(data_path, batch_size, augment_train=False, split_tsv="split_indices_standard.tsv"):
    """Build train/test DataLoaders over AudioMNIST with a fixed 80/20 split.

    The base dataset is already deterministically hash-shuffled, so slicing by
    position gives a reproducible split. Augmentation is confined to the train
    wrapper; the test split is always clean. The split assignment (index,
    split, label, file path) is also written to `split_tsv`.
    """
    base_dataset = AudioMNISTBaseDataset(data_path)

    n_total = len(base_dataset)
    n_train = int(0.8 * n_total)
    train_indices = list(range(n_train))
    test_indices = list(range(n_train, n_total))

    # The wrapper isolates augmentation from the shared in-memory base dataset.
    train_loader = DataLoader(
        AudioSubsetWrapper(Subset(base_dataset, train_indices), augment=augment_train),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        AudioSubsetWrapper(Subset(base_dataset, test_indices), augment=False),
        batch_size=batch_size,
        shuffle=False,
    )

    with open(split_tsv, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["index", "split", "label", "file_path"])
        for split_name, indices in (("train", train_indices), ("test", test_indices)):
            for idx in indices:
                writer.writerow([idx, split_name, base_dataset.labels[idx], base_dataset.wav_files[idx]])
    print(f"Saved split information to {split_tsv}")

    return train_loader, test_loader
71114
72115# ----------------------------
73- # Model Definition (Updated Wrapper)
116+ # Model Definition
74117# ----------------------------
75118class AudioMNISTModel (nn .Module ):
76119 """
@@ -97,41 +140,12 @@ def forward(self, x):
97140 x = x .view (x .size (0 ), * self .reshape_dims )
98141 return self .backbone (x )
99142
100- # ----------------------------
101- # Load Data
102- # ----------------------------
103- def load_data (data_path , batch_size , split_tsv = "split_indices_model1.tsv" ):
104- dataset = AudioMNISTDataset (data_path )
105- # Fixed 80/20 split (after deterministic shuffle)
106- train_size = int (0.8 * len (dataset ))
107- train_indices = list (range (0 , train_size ))
108- test_indices = list (range (train_size , len (dataset )))
109-
110- train_dataset = torch .utils .data .Subset (dataset , range (0 , train_size ))
111- test_dataset = torch .utils .data .Subset (dataset , range (train_size , len (dataset )))
112-
113- train_loader = DataLoader (train_dataset , batch_size = batch_size , shuffle = True )
114- test_loader = DataLoader (test_dataset , batch_size = batch_size , shuffle = False )
115-
116- # --- Write split info to TSV ---
117- with open (split_tsv , "w" , newline = "" ) as f :
118- writer = csv .writer (f , delimiter = "\t " )
119- writer .writerow (["index" , "split" , "label" , "file_path" ])
120- for idx in train_indices :
121- writer .writerow ([idx , "train" , dataset .labels [idx ], dataset .wav_files [idx ]])
122- for idx in test_indices :
123- writer .writerow ([idx , "test" , dataset .labels [idx ], dataset .wav_files [idx ]])
124- print (f"Saved split information to { split_tsv } " )
125-
126- return train_loader , test_loader
127-
128143# ----------------------------
129144# Training loop
130145# ----------------------------
131146def train (model , train_loader , device , epochs = 10 , lr = 0.001 ):
132147 criterion = nn .CrossEntropyLoss ()
133148 optimizer = optim .Adam (model .parameters (), lr = lr )
134-
135149 model .to (device )
136150 model .train ()
137151
@@ -144,7 +158,7 @@ def train(model, train_loader, device, epochs=10, lr=0.001):
144158 for images , labels in tqdm (train_loader , desc = f"Epoch { epoch + 1 } /{ epochs } " , unit = "batch" ):
145159 images , labels = images .to (device ), labels .to (device )
146160
147- optimizer .zero_grad ()
161+ optimizer .zero_grad (set_to_none = True )
148162 outputs = model (images )
149163 loss = criterion (outputs , labels )
150164 loss .backward ()
@@ -160,7 +174,6 @@ def train(model, train_loader, device, epochs=10, lr=0.001):
160174 elapsed = time .time () - start_time
161175 print (f"Epoch { epoch + 1 } finished in { elapsed :.2f} s - Loss: { avg_loss :.4f} , Accuracy: { avg_acc :.4f} " )
162176
163-
164177# ----------------------------
165178# Evaluation
166179# ----------------------------
@@ -196,9 +209,8 @@ def evaluate_model(model, test_loader, device):
196209
197210 y_true = np .array (y_true )
198211 y_pred = np .array (y_pred )
199-
200- # compute AUROC and AUPRC
201- y_true_onehot = np .eye (10 )[y_true ]
212+ y_true_onehot = np .eye (NUM_CLASSES )[y_true ]
213+
202214 auroc = roc_auc_score (y_true_onehot , y_pred , multi_class = "ovr" )
203215 auprc = average_precision_score (y_true_onehot , y_pred )
204216
@@ -207,42 +219,29 @@ def evaluate_model(model, test_loader, device):
207219 print (f"Test auROC: { auroc :.4f} " )
208220 print (f"Test auPRC: { auprc :.4f} " )
209221
210-
# ----------------------------
# Main
# ----------------------------
def main():
    """Parse CLI options, train the augmented AudioMNIST model, save it, and evaluate on the clean test split."""
    parser = argparse.ArgumentParser(description="AudioMNIST Augmented Training")
    parser.add_argument("--data", type=str, default="./data/AudioMNIST", help="Path to dataset")
    parser.add_argument("--output", type=str, default="audiomnist_aug.pt", help="Model output name")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=10)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AudioMNISTModel(backbone_class=MobileNet, num_classes=NUM_CLASSES)

    # ENABLE DATA AUGMENTATION HERE
    train_loader, test_loader = load_data(
        args.data,
        args.batch_size,
        augment_train=True,
        split_tsv="split_indices_aug.tsv",
    )

    train(model, train_loader, device, epochs=args.epochs)

    torch.save(model.state_dict(), args.output)
    print(f"Model saved to {args.output}")

    print("Model statistics on clean test dataset")
    evaluate_model(model, test_loader, device)
246-
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
0 commit comments