Alzheimer's Disease Stage Detection Using Deep Learning
This repository contains a deep learning model that classifies the stage of Alzheimer's disease from MRI brain scans. The project uses transfer learning: a pre-trained EfficientNetB0 network is adapted to sort images into four stages (nondemented, very mild, mild, and moderate dementia). Built with PyTorch, the model is trained on a structured dataset of MRI images and is meant as a framework for studying early detection of Alzheimer's progression.
- Model Architecture: Utilizes the pre-trained EfficientNetB0 model, fine-tuned for the specific task of classifying Alzheimer's disease stages from brain MRIs.
- Data Processing: Includes comprehensive preprocessing and augmentation techniques to optimize model training and performance.
- API Integration: Features a Flask-based REST API that enables users to upload MRI images and receive diagnostic predictions, facilitating easy interaction with the model.
- Containerization: Docker support for easy setup and deployment, ensuring consistency across different environments.
- Evaluation Metrics: Evaluates the trained model on the held-out test split; evaluate.py computes and prints test-set accuracy.
- /model: Contains the model definitions and state.
- /data: Dataset directory with train and test splits.
- /utils: Utility functions for data loading and transformations.
- /api: Contains the Flask application for the REST API.
- /tests/: Includes test scripts.
- train.py: Script for model training.
- evaluate.py: Script for model evaluation and performance metrics.
- To provide a reliable tool for early detection of Alzheimer's disease stages, aiding in better management and treatment planning.
- To contribute to the ongoing research in medical AI by demonstrating the application of advanced machine learning techniques in healthcare diagnostics.
- To offer an accessible platform for further development and validation by the research community, healthcare professionals, and technology enthusiasts.
- Train the model using train.py.
- Evaluate the model's performance with evaluate.py.
- Start the API server by running python api/app.py.
- Test the API using python tests/test_alzheimers_api.py.
This project is intended for educational and research purposes, aiming to bridge the gap between medical imaging and machine learning technologies.
Located in the /api directory, app.py sets up a Flask-based REST API for the Alzheimer's Disease Detection model:
- /predict endpoint for MRI image upload and prediction
- Image preprocessing using PyTorch transformations
- Integration with the trained AlzheimerNet model
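A rough sketch of how such an endpoint might be wired up is shown below. It is not the contents of app.py: the form field name `file`, the JSON keys, and the `create_app` and `predict_fn` names are assumptions made for illustration. The predictor is injected so the route can be exercised without loading the model; in the real application it would presumably wrap the preprocessing transforms and the trained AlzheimerNet.

```python
from flask import Flask, jsonify, request


def create_app(predict_fn):
    """Build a Flask app whose /predict route delegates to predict_fn.

    predict_fn is a hypothetical callable: raw image bytes in, label string out.
    """
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        # The multipart field name "file" is an assumption.
        uploaded = request.files.get("file")
        if uploaded is None or uploaded.filename == "":
            return jsonify({"error": "no image uploaded"}), 400
        label = predict_fn(uploaded.read())
        return jsonify({"prediction": label})

    return app
```

Calling `create_app(my_predictor).run()` would start a development server, and Flask's `app.test_client()` can post a file to `/predict` without opening a socket.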
Found in the /tests directory, test_alzheimers_api.py provides API testing functionality:
- Random selection of test images
- API endpoint testing with selected images
- Comparison of true labels with model predictions
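The three steps above could look roughly like this. The sketch assumes the test split has one subfolder per class, so the folder name serves as the true label; `pick_random_images`, `compare`, and the `predict` callable are hypothetical names, and in the actual script `predict` would presumably post the image to the API and read back the predicted label.

```python
import random
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def pick_random_images(test_dir, k, seed=None):
    """Return up to k (path, true_label) pairs sampled from test_dir.

    Assumes a <test_dir>/<class_name>/<image> layout.
    """
    candidates = sorted(
        (path, path.parent.name)
        for path in Path(test_dir).glob("*/*")
        if path.suffix.lower() in IMAGE_SUFFIXES
    )
    rng = random.Random(seed)  # a seed makes the selection reproducible
    return rng.sample(candidates, min(k, len(candidates)))


def compare(samples, predict):
    """Call predict(path) for each sample and count matches with the true label."""
    correct = 0
    for path, true_label in samples:
        predicted = predict(path)
        print(f"{path.name}: true={true_label} predicted={predicted}")
        correct += predicted == true_label
    return correct, len(samples)
```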
Located in the root directory, train.py contains the main training logic for the Alzheimer's detection model:
- Defines the train_model function for model training
- Implements the main execution flow in the main function
- Utilizes PyTorch for model training and optimization
- Loads data using custom data loaders from utils/data_utils.py
- Initializes the AlzheimerNet model, optimizer, and loss function
- Trains the model for a specified number of epochs
- Saves the trained model to the saved_models directory
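The flow listed above follows the usual PyTorch training pattern, sketched below under stated assumptions. The signature of `train_model`, the Adam optimizer, the learning rate, and the checkpoint file name `alzheimer_net.pth` are illustrative choices, not taken from the repository. Random tensors and a tiny linear model stand in for the real data loaders and AlzheimerNet so the example runs on its own.

```python
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, loader, criterion, optimizer, device, num_epochs=1):
    """Run a supervised training loop and return the mean loss of each epoch."""
    model.to(device)
    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch mean is exact.
            running_loss += loss.item() * images.size(0)
            seen += images.size(0)
        history.append(running_loss / seen)
        print(f"epoch {epoch + 1}/{num_epochs} loss={history[-1]:.4f}")
    return history


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Stand-ins for the real loaders and model (assumptions for this sketch).
    dataset = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 4, (32,)))
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    train_model(model, loader, criterion, optimizer, device, num_epochs=2)
    # Hypothetical file name; the source only states the saved_models directory.
    os.makedirs("saved_models", exist_ok=True)
    torch.save(model.state_dict(), os.path.join("saved_models", "alzheimer_net.pth"))


if __name__ == "__main__":
    main()
```

Saving the `state_dict` rather than the whole module is a common choice because the checkpoint then depends only on parameter names, not on the pickled class definition.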
Located in the root directory, evaluate.py contains the evaluation logic for the Alzheimer's detection model:
- Defines the evaluate_model function for model evaluation
- Implements the main execution flow in the main function
- Loads the trained model from the saved_models directory
- Utilizes the test data loader to evaluate model performance
- Computes and prints the accuracy of the model on the test set
- Handles device-agnostic evaluation (CPU or CUDA)
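The evaluation steps above might be sketched as follows. The signature of `evaluate_model` and the stand-in model and data are assumptions for illustration. A real run would rebuild AlzheimerNet and restore its weights with `model.load_state_dict(torch.load(path, map_location=device))`, where `path` points to a checkpoint in saved_models.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_model(model, loader, device):
    """Return top-1 accuracy of model over every batch in loader."""
    model.to(device)
    model.eval()  # disable dropout and use running batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


if __name__ == "__main__":
    # Fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Stand-ins so the sketch runs alone (not the repository's model or data).
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
    data = TensorDataset(torch.randn(16, 3, 8, 8), torch.randint(0, 4, (16,)))
    accuracy = evaluate_model(model, DataLoader(data, batch_size=8), device)
    print(f"test accuracy: {accuracy:.2%}")
```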