A simple implementation of BadNets: Identifying Vulnerabilities in the Machine Learning Model Supply Chain.
The CNN I built consists of three convolutional layers followed by two fully connected layers. Because it must support both the MNIST and CIFAR10 datasets, it has one more convolutional layer than the model presented in the original paper.
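One practical consequence of supporting both datasets is that the input size of the first fully connected layer depends on the input resolution (28x28 for MNIST, 32x32 for CIFAR10). The sketch below computes that size with the standard output-size formula `(size + 2*padding - kernel) // stride + 1`. The layer configuration is hypothetical and not necessarily this repository's: 16/32/64 channels, 3x3 kernels with padding 1, and 2x2 max pooling after each convolution.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def flatten_dim(height, width, channels=(16, 32, 64), kernel=3, padding=1, pool=2):
    """Number of features entering the first fully connected layer.

    Hypothetical architecture: each conv (stride 1, 'same' padding) is followed
    by a max pool with kernel == stride == `pool`.
    """
    h, w = height, width
    for _ in channels:
        h = conv2d_out(h, kernel, stride=1, padding=padding)  # conv keeps the size
        w = conv2d_out(w, kernel, stride=1, padding=padding)
        h = conv2d_out(h, pool, stride=pool)  # pooling halves it (rounding down)
        w = conv2d_out(w, pool, stride=pool)
    return channels[-1] * h * w


if __name__ == "__main__":
    print(flatten_dim(28, 28))  # MNIST:   28 -> 14 -> 7 -> 3, so 64 * 3 * 3 = 576
    print(flatten_dim(32, 32))  # CIFAR10: 32 -> 16 -> 8 -> 4, so 64 * 4 * 4 = 1024
```

A model that handles both datasets therefore has to size its first fully connected layer from the dataset's image shape rather than hard-coding it.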
The implementation proceeds as follows:
- First, train the model on a training set in which a randomly selected portion of the samples has been poisoned.
- Then, evaluate the model on a clean test set and on a test set in which every sample carries the trigger, which yields the Benign Accuracy (BA) and the Attack Success Rate (ASR), respectively.
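The two steps above can be sketched in pure Python. This is a minimal illustration, not the repository's code: images are plain H x W lists of pixel values, and the trigger (a 3x3 white square in the bottom-right corner), the helper names `add_trigger`, `poison_dataset`, `benign_accuracy` and `attack_success_rate`, and the `poisoning_rate`/`target_label` parameters are all hypothetical. As in the BadNets attack, poisoned samples are relabeled to the attacker's target class. For ASR, the sketch follows a common convention of skipping test samples whose true label is already the target; the repository may compute it differently.

```python
import random

# Hypothetical trigger: a small white square stamped in the bottom-right corner.
TRIGGER_SIZE = 3
TRIGGER_VALUE = 255


def add_trigger(image):
    """Return a copy of `image` (an H x W list of lists) with the trigger stamped on it."""
    poisoned = [row[:] for row in image]  # copy rows so the original is not mutated
    h, w = len(poisoned), len(poisoned[0])
    for i in range(h - TRIGGER_SIZE, h):
        for j in range(w - TRIGGER_SIZE, w):
            poisoned[i][j] = TRIGGER_VALUE
    return poisoned


def poison_dataset(images, labels, poisoning_rate, target_label, seed=0):
    """Stamp the trigger on a random fraction of samples and relabel them as `target_label`."""
    rng = random.Random(seed)
    n_poison = int(len(images) * poisoning_rate)
    poison_idx = set(rng.sample(range(len(images)), n_poison))
    new_images, new_labels = [], []
    for idx, (img, lbl) in enumerate(zip(images, labels)):
        if idx in poison_idx:
            new_images.append(add_trigger(img))
            new_labels.append(target_label)
        else:
            new_images.append(img)
            new_labels.append(lbl)
    return new_images, new_labels, poison_idx


def benign_accuracy(predict, images, labels):
    """BA: fraction of clean test samples classified correctly."""
    correct = sum(predict(x) == y for x, y in zip(images, labels))
    return correct / len(images)


def attack_success_rate(predict, images, labels, target_label):
    """ASR: fraction of triggered samples classified as `target_label`.

    Samples whose true label already equals the target are skipped (a common convention).
    """
    triggered = [add_trigger(x) for x, y in zip(images, labels) if y != target_label]
    hits = sum(predict(x) == target_label for x in triggered)
    return hits / len(triggered)


def toy_backdoored_model(x):
    """Stand-in for a backdoored model: obeys the trigger, otherwise reads the label pixel."""
    return 0 if x[-1][-1] == TRIGGER_VALUE else x[0][0]


if __name__ == "__main__":
    # Toy 5x5 images whose top-left pixel encodes the label (illustration only).
    images = [[[0] * 5 for _ in range(5)] for _ in range(10)]
    for lbl, img in enumerate(images):
        img[0][0] = lbl
    labels = list(range(10))

    _, poisoned_labels, idx = poison_dataset(images, labels, poisoning_rate=0.2, target_label=0)
    print(len(idx))  # 2 samples poisoned
    print(benign_accuracy(toy_backdoored_model, images, labels))  # 1.0
    print(attack_success_rate(toy_backdoored_model, images, labels, 0))  # 1.0
```

A successful backdoor shows both a high BA (the model behaves normally on clean inputs) and a high ASR (the trigger reliably forces the target class).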
To install the required packages, run `pip install -r requirements.txt`.

To run the code, use `python main.py`.

You can customize various parameters on the command line; run `python main.py --help` to list them.

This project is licensed under the MIT License - see the LICENSE file for details.