Accurately diagnosing dementia is difficult and requires multiple tests and expert consultation. Although magnetic resonance images of the brain do not tell the full story, they reveal changes in brain structure. A quick, accurate diagnosis is essential for timely intervention.
Dementia, a syndrome characterized by a decline in memory, thinking, and behavior, is one of the most significant health challenges facing aging populations. It places a heavy burden on healthcare systems, caregivers, and the families of affected individuals. Furthermore, the global prevalence of dementia is forecast to increase substantially by 2050 [5].
Conventional diagnostic methods for dementia often rely on clinical assessments, cognitive tests, and neuroimaging techniques such as magnetic resonance imaging (MRI). While these approaches provide insight into the structural and functional changes associated with dementia, they can be subjective, time-consuming, and dependent on clinicians' expertise. Machine learning (ML) algorithms, particularly convolutional neural networks (CNNs), excel at learning intricate patterns and relationships in image data, which makes them well suited to analyzing large volumes of MRI images and identifying signs of dementia. CNNs may also identify subtle details in brain structure that are not apparent to human observers, increasing the sensitivity of dementia detection. In this paper, we explore the application of ML techniques to classifying the degree of dementia from MRI images of the brain. We investigate how effectively two multi-class ML models, implemented in TensorFlow, distinguish between different degrees of dementia severity. By analyzing the strengths and limitations of these approaches to dementia classification, we hope to contribute to ongoing efforts to improve the early and accurate detection of this condition.
The dataset used in this study was obtained from Kaggle [1] and contains 6400 preprocessed grayscale MRI images of the brain. Each image is \(256 \times 256\) pixels and is labelled with one of four classes: non-demented, very mildly demented, mildly demented, and moderately demented. An example of a non-demented MRI is shown in Figure 2. Each image shows the structure of a horizontal (axial) slice of the brain, as depicted in Figure 3.
The first machine learning model we use is a convolutional neural network. Its architecture, shown in Figure 4, consists of two convolutional layers with ReLU activation followed by a fully connected layer with softmax activation. The first convolutional layer has 32 filters of size \(5\times 5\), and the second has 16 filters of size \(3 \times 3\). The output of the second convolutional layer is flattened and fed into a fully connected layer with four outputs, and the softmax activation function converts these four outputs into probabilities for the four labelled classes. The dataset is split into training, validation, and test sets in an 80-10-10 ratio, and the data are processed in batches of 64. The network is trained with a sparse categorical cross-entropy loss function and the Adam optimizer.
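To make the output layer and loss concrete, the sketch below converts four raw outputs (logits) into class probabilities with softmax and then computes the sparse categorical cross-entropy, which takes integer class labels rather than one-hot vectors. It is a plain-Python illustration of the formulas, not TensorFlow's implementation; the logit values and the label order (0 = non-demented, ..., 3 = moderately demented) are hypothetical.

```python
import math


def softmax(logits):
    """p_i = exp(z_i) / sum_j exp(z_j), computed stably by subtracting max(z)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_cross_entropy(probs_batch, labels):
    """Mean of -log p[y] over a batch, where each y is an integer class index."""
    losses = [-math.log(probs[y]) for probs, y in zip(probs_batch, labels)]
    return sum(losses) / len(losses)


# Hypothetical raw outputs (logits) of the final 4-unit layer for two images,
# with an assumed label order 0 = non-demented, ..., 3 = moderately demented.
logits_batch = [
    [2.0, 0.5, 0.1, -1.0],
    [0.2, 1.5, 0.3, -0.5],
]
labels = [0, 1]  # integer labels, not one-hot vectors

probs_batch = [softmax(z) for z in logits_batch]
loss = sparse_categorical_cross_entropy(probs_batch, labels)
# The predicted class is the one with the highest probability.
predictions = [max(range(4), key=lambda i: p[i]) for p in probs_batch]
print(predictions)  # [0, 1]
print(loss)
```

In Keras, the same loss is available as `tf.keras.losses.SparseCategoricalCrossentropy`; with `from_logits=True` it applies the softmax internally, which is numerically more stable than taking the log of already-normalized probabilities.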
The second machine learning model we use is a fully connected neural network. Its architecture, shown in Figure 5, consists of two fully connected layers. First, the image is flattened into a one-dimensional tensor and fed into a fully connected layer with 4000 hidden units and ReLU activation. The second fully connected layer has 4 units with softmax activation, which output the classification probabilities for each class. The data are again split 80-10-10 and processed in batches of 64. As with the CNN, the network is trained with a sparse categorical cross-entropy loss function and the Adam optimizer.
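With both architectures defined, a count of output shapes and trainable parameters shows how differently they use their weights. This is a minimal sketch under assumptions the paper does not state: a single-channel \(256\times 256\) input, stride 1, 'valid' (no) padding, which is the Keras default for `Conv2D`, no pooling layers (none are described), and a bias term in every layer. The function names are hypothetical helpers for this illustration.

```python
def conv2d_shape_and_params(h, w, c_in, n_filters, k):
    """Output shape and parameter count of a stride-1 conv layer with 'valid' padding."""
    out_shape = (h - k + 1, w - k + 1, n_filters)
    params = (k * k * c_in + 1) * n_filters  # k*k*c_in weights plus one bias per filter
    return out_shape, params


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def cnn_summary(h=256, w=256):
    shape1, p1 = conv2d_shape_and_params(h, w, 1, 32, 5)   # 32 filters of 5x5
    shape2, p2 = conv2d_shape_and_params(*shape1, 16, 3)   # 16 filters of 3x3
    flat = shape2[0] * shape2[1] * shape2[2]               # size after flattening
    p3 = dense_params(flat, 4)                             # 4 softmax outputs
    return {"conv1": shape1, "conv2": shape2, "flatten": flat,
            "params": [p1, p2, p3], "total": p1 + p2 + p3}


def fcn_summary(h=256, w=256):
    flat = h * w                                           # flattened grayscale image
    p1 = dense_params(flat, 4000)                          # 4000 ReLU hidden units
    p2 = dense_params(4000, 4)                             # 4 softmax outputs
    return {"flatten": flat, "params": [p1, p2], "total": p1 + p2}


cnn, fcn = cnn_summary(), fcn_summary()
print(cnn["conv1"], cnn["conv2"], cnn["flatten"])  # (252, 252, 32) (250, 250, 16) 1000000
print(cnn["total"])  # 4005460
print(fcn["total"])  # 262164004
```

Under these assumptions nearly all of the CNN's roughly 4.0 million parameters sit in its final dense layer, because nothing shrinks the \(250\times 250\times 16\) feature map before flattening, while the fully connected network has about 262 million, roughly 65 times more. If the models were built in Keras as described, `model.summary()` should report matching per-layer counts.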
After training both models for 25 epochs, the CNN achieved considerably higher accuracy than the fully connected network on both the validation and test datasets. The training and validation curves for each model are shown in Figure 6.
Our results indicate that a CNN is a promising ML algorithm for classifying dementia from MRI images of the brain. The architecture used in this study performed well on the dataset we considered. We attribute this to the convolutional filters learning key predictors in particular regions of the image. Unlike fully connected networks, which learn a separate weight for every pixel, CNNs apply each filter at every position in the image; this lets them capture relationships between neighboring pixels and detect the same local feature wherever it appears (a property known as translation equivariance). Changes to the CNN architecture, such as the number of layers, the activation functions, and the filter sizes, may further improve accuracy. One potential limitation of the trained model is that the dataset is imbalanced: only a small fraction of the examples are labelled moderately demented. The distribution of labels in the dataset is shown in Figure 9.
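The weight-sharing argument above can be checked on a toy example. The sketch below slides one small filter over every position of a 1D signal (a 'valid' cross-correlation, which is what convolutional layers in deep learning libraries compute) and shows that shifting the input shifts the feature map by the same amount. The signal and filter values are hypothetical, and 1D is used only for brevity; the same property holds along both axes of a 2D image, away from the borders.

```python
def conv1d_valid(x, w):
    """Slide filter w over x with stride 1 and no padding (cross-correlation)."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]


def shift_right(x, s):
    """Shift a sequence right by s positions, filling the start with zeros."""
    return [0] * s + x[: len(x) - s]


signal = [0, 0, 1, 3, 2, 0, 0, 0]  # a small "feature" surrounded by background
edge_filter = [1, -1]              # computes x[i] - x[i + 1], i.e. neighbour differences

original = conv1d_valid(signal, edge_filter)
shifted = conv1d_valid(shift_right(signal, 2), edge_filter)
print(original)                             # [0, -1, -2, 1, 2, 0, 0]
print(shifted)                              # [0, 0, 0, -1, -2, 1, 2]
print(shifted == shift_right(original, 2))  # True
```

The same two filter weights respond to the feature at its new location without any retraining; a fully connected layer has a separate weight for each input position, so it offers no such guarantee.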
In summary, our exploration of dementia severity classification using preprocessed MRI images highlights the promise of machine learning, particularly convolutional neural networks (CNNs), for improving diagnostic accuracy. The CNN model outperformed the fully connected network, achieving \(95.0\%\) accuracy on the test dataset and perfect accuracy on the training dataset. This result underscores the ability of CNNs to discern intricate patterns and spatial relationships within MRI images that carry information about dementia severity, and it is consistent with our hypotheses before testing. Despite these promising outcomes, several considerations remain for future research. Addressing the dataset imbalance, especially the scarcity of moderately demented examples, is essential for training the network more effectively. Incorporating stacks of MRI slices could offer a more comprehensive view of brain structure, although classifying them may require more complex CNNs. Furthermore, saliency maps may shed light on the biological markers associated with dementia and other neurodegenerative diseases. Time-series ML models that use longitudinal brain scans also merit exploration for personalized diagnosis and intervention strategies tailored to individual patients' needs. Our findings contribute to the ongoing effort toward early and accurate dementia detection, which is crucial for optimizing patient care and management, and we believe the combination of machine learning and MRI will play an important role in that effort.
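One common way to address a class imbalance is to weight each class's contribution to the loss inversely to its frequency. The sketch below computes "balanced" weights, \(n_\text{samples} / (n_\text{classes} \cdot n_c)\), the same heuristic scikit-learn's `compute_class_weight` uses with `class_weight="balanced"`. The helper name and the label counts are hypothetical placeholders, not the distribution shown in Figure 9.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight for class c = n_samples / (n_classes * count_c).

    Only classes that actually occur in `labels` receive a weight.
    """
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


# Hypothetical toy labels (0..3), deliberately imbalanced with class 3 rare,
# mirroring the scarcity of moderately demented examples.
toy_labels = [0] * 50 + [1] * 35 + [2] * 13 + [3] * 2
weights = balanced_class_weights(toy_labels)
print(weights)  # rare classes get larger weights; each class's total weight is equal
```

In Keras, such a dictionary could be passed as `model.fit(train_data, epochs=25, class_weight=weights)`, where `model` and `train_data` are placeholders here. Oversampling the rare class or collecting more moderately demented scans are alternatives that change the data rather than the loss.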
[1] Alzheimer MRI Preprocessed Dataset. URL: https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset, Accessed on: July 9, 2024.
[2] Micheau, Hoa. Anatomy of the encephalon (MRI) in axial slices. URL: https://www.imaios.com/en/e-anatomy/brain/mri-axial-brain, Accessed on: July 9, 2024.
[3] Qiu, Miller, Joshi, et al. Multimodal deep learning for Alzheimer’s disease dementia assessment. Nature Communications (2022).
[4] Kavitha, Mani, Srividhya, et al. Early-Stage Alzheimer’s Disease Prediction Using Machine Learning Models. Frontiers in Public Health (2022).
[5] Nichols, Steinmetz, Vollset, et al. Estimation of the global prevalence of dementia in 2019 and forecasted prevalence in 2050: an analysis for the Global Burden of Disease Study 2019. The Lancet Public Health (2022).