Dementia Classification Using a Convolutional Neural Network

Accurately diagnosing dementia is difficult and requires multiple tests and expert consultation. Magnetic resonance images of the brain do not tell the whole story, but they do reveal changes in the brain's structure. A quick, accurate diagnosis is essential for timely intervention.

Introduction

Dementia, a syndrome characterized by a decline in memory, thinking, and behavior, is one of the most significant health challenges facing aging populations. It places considerable strain on healthcare systems, caregivers, and the families of those affected, and its global prevalence is forecast to rise substantially by 2050 [5].

Figure 1. Dementia prevalence over time (https://www.thelancet.com/journals/lanpub/article/PIIS2468-2667(21)00249-8/fulltext)
The timely and accurate diagnosis of dementia is crucial for several reasons. Early detection allows for interventions that may help slow the progression of the disease, thereby improving the quality of life for patients and their caregivers. In addition, accurate diagnosis enables healthcare professionals to tailor treatment plans and provide appropriate support services, optimizing patient care and management.

Conventional diagnostic methods for dementia often rely on clinical assessments, cognitive tests, and neuroimaging techniques such as magnetic resonance imaging (MRI). While these approaches provide insight into the structural and functional changes associated with dementia, they can be subjective, time-consuming, and dependent on clinicians' expertise. Machine learning (ML) algorithms, particularly convolutional neural networks (CNNs), excel at learning intricate patterns and relationships in image data, which makes them well suited to analyzing large volumes of MRI images for signs of dementia. CNNs may also detect subtle details in brain structure that are not apparent to human observers, which could increase the sensitivity of dementia detection.

In this paper, we explore the application of ML techniques to classifying the severity of dementia from MRI images of the brain. We compare two multi-class ML models, both implemented in TensorFlow, on their ability to distinguish between degrees of dementia severity. By analyzing the strengths and limitations of these approaches, we hope to contribute to ongoing efforts to improve the early and accurate detection of this condition.

MRI Dataset

The dataset used in this study was obtained from Kaggle [1] and contains 6400 preprocessed grayscale MRI images of the brain. Each image is 256 × 256 pixels and is labelled with one of four classes: non-demented, very mildly demented, mildly demented, and moderately demented. An example of a non-demented MRI is shown in Figure 2. Each image shows a horizontal (axial) slice of the brain, as depicted in Figure 3.

Figure 2. Example MRI scan from the non-demented class.
Figure 3. The dataset consists of axial (horizontal) MRI slices of the brain.

CNN architecture and training

The first model in this study is a convolutional neural network. Its architecture, shown in Figure 4, consists of two convolutional layers with ReLU activation followed by a fully connected layer with softmax activation. The first convolutional layer has 32 filters of size \( 5\times 5\), and the second has 16 filters of size \(3 \times 3\). The output of the second convolutional layer is flattened and fed into a fully connected layer with four outputs, and a softmax activation converts these four outputs into probabilities for the four classes. The dataset is split into training, validation, and test sets in an 80-10-10 ratio, and the network is trained with a batch size of 64 using a sparse categorical cross-entropy loss and the Adam optimizer.
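
To make the layer sizes concrete, the following pure-Python sketch tallies output shapes and parameter counts for the CNN described above. It is a hypothetical helper, not the authors' code, and it assumes stride-1 convolutions with no padding and no pooling, none of which the text specifies.

```python
# Hypothetical helper (not the authors' code) that tallies output shapes and
# parameter counts for the CNN described above.
# ASSUMPTIONS not stated in the text: stride-1 convolutions, no padding
# ("valid"), no pooling, and a single-channel 256x256 input.


def conv2d_stats(h, w, c_in, filters, k):
    """Output shape and parameter count of a stride-1, unpadded k x k convolution."""
    out_shape = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out_shape, params


def dense_params(n_in, units):
    """Parameter count of a fully connected layer: weights plus biases."""
    return n_in * units + units


def cnn_summary(size=256):
    rows = []
    shape, p = conv2d_stats(size, size, 1, filters=32, k=5)
    rows.append(("conv 32 @ 5x5, ReLU", shape, p))
    shape, p = conv2d_stats(*shape, filters=16, k=3)
    rows.append(("conv 16 @ 3x3, ReLU", shape, p))
    flat = shape[0] * shape[1] * shape[2]
    rows.append(("flatten", (flat,), 0))
    rows.append(("dense 4, softmax", (4,), dense_params(flat, 4)))
    return rows


if __name__ == "__main__":
    total = 0
    for name, shape, params in cnn_summary():
        total += params
        print(f"{name:<22}{str(shape):<18}{params:>12,}")
    print(f"{'total':<40}{total:>12,}")
```

Under these assumptions the network has 4,005,460 parameters, and all but 5,456 of them sit in the final dense layer, which maps the 1,000,000 flattened features to the four classes. With "same" padding the flattened vector would instead have 256 × 256 × 16 = 1,048,576 entries.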

Figure 4. Convolutional network architecture.
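
The 80-10-10 split and batch size of 64 described above translate into concrete set sizes as follows. This is a hypothetical sketch, not the authors' code. It assumes the split is computed by rounding down with the test set taking the remainder, and the exact counts depend on the splitting utility actually used.

```python
import math

# Hypothetical helpers (not the authors' code) for the 80-10-10 split and a
# batch size of 64. ASSUMPTION: train and validation sizes are rounded down and
# the test set takes the remainder; the splitting utility actually used may
# round differently.


def split_sizes(n, train_pct=80, val_pct=10):
    n_train = n * train_pct // 100
    n_val = n * val_pct // 100
    n_test = n - n_train - n_val
    return n_train, n_val, n_test


def steps_per_epoch(n_examples, batch_size=64):
    # A final partial batch still counts as one step, hence the ceiling.
    return math.ceil(n_examples / batch_size)


n_train, n_val, n_test = split_sizes(6400)
print(n_train, n_val, n_test)         # 5120 640 640
print(steps_per_epoch(n_train))       # 80 optimizer steps per epoch
print(25 * steps_per_epoch(n_train))  # 2000 steps over 25 epochs
```

Under this assumption each test image is worth 1/640 ≈ 0.16 percentage points of test accuracy.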

FC architecture and training

The second model in this study is a fully connected neural network with two fully connected layers, shown in Figure 5. The image is first flattened into a one-dimensional tensor and fed into a fully connected layer of 4000 hidden units with ReLU activation. The second fully connected layer has four units with softmax activation, which outputs a probability for each class. As with the CNN, the data are split 80-10-10 and the network is trained with a batch size of 64, a sparse categorical cross-entropy loss, and the Adam optimizer.
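
For comparison with the CNN, the same parameter arithmetic for the fully connected network shows where its parameters go. This hypothetical sketch assumes a single-channel 256 × 256 input flattened into one vector of pixels.

```python
# Hypothetical parameter count (not the authors' code) for the fully connected
# network described above. ASSUMPTION: a single-channel 256x256 input that is
# flattened into one vector of 65,536 pixels.


def dense_params(n_in, units):
    return n_in * units + units  # weights + biases


def fc_param_counts(size=256, hidden=4000, classes=4):
    n_in = size * size  # every pixel is an input to every hidden unit
    hidden_params = dense_params(n_in, hidden)
    output_params = dense_params(hidden, classes)
    return hidden_params, output_params


hidden_params, output_params = fc_param_counts()
print(f"hidden layer: {hidden_params:,}")                  # 262,148,000
print(f"output layer: {output_params:,}")                  # 16,004
print(f"total:        {hidden_params + output_params:,}")  # 262,164,004
```

That is roughly 262 million parameters, far more than the CNN, whose parameter count stays in the low millions for the architecture described in the previous section. Each hidden unit has its own weight for every pixel, so nothing learned at one image position is shared with another.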

Figure 5. Fully connected network architecture.
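
Both networks end in a four-unit softmax layer trained with sparse categorical cross-entropy (the `sparse_categorical_crossentropy` loss in Keras). The pure-Python sketch below shows what those two pieces compute for a single example. The logits are made up for illustration, and frameworks compute the same quantity over whole batches and average it.

```python
import math


def softmax(logits):
    # Subtracting the largest logit avoids overflow in exp() without
    # changing the resulting probabilities.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, label):
    # "Sparse" means the true label is an integer class index (0-3 here)
    # rather than a one-hot vector; the loss is -log of that class's probability.
    return -math.log(probs[label])


# Made-up outputs of the final four-unit dense layer for one image.
logits = [2.0, 1.0, 0.1, -1.0]
probs = softmax(logits)
loss = sparse_categorical_crossentropy(probs, label=0)
print([round(p, 3) for p in probs], round(loss, 3))
```

A confident, correct prediction drives the loss toward 0, while a uniform prediction over the four classes gives a loss of ln 4 ≈ 1.386.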

Results

After both models were trained for 25 epochs, the CNN substantially outperformed the fully connected network in accuracy on both the validation and test sets. The training and validation accuracy curves for each model are shown in Figure 6.

Figure 6. Training and validation accuracy of both models.
Both models took roughly 25 minutes to train. The CNN converges after about 10 epochs and reaches 95% validation accuracy, and its validation curve shows no clear sign of overfitting. In contrast, the fully connected network does not appear to converge even after 25 epochs. The confusion matrices for the CNN and the fully connected network on the test set are shown in Figure 7 and Figure 8, respectively.
Figure 7. Confusion matrix for the CNN results.
Figure 8. Confusion matrix for the FC results.
The CNN achieved 95.0% accuracy on the test set, while the fully connected network achieved 72.4%. The CNN performed relatively well across all classes, and most of its errors come from predicting a milder degree of dementia than the true label. The fully connected network, on the other hand, performed very poorly on mildly and moderately demented cases: Figure 8 shows that only two test examples were predicted as mildly or moderately demented, even though the test set contained 91 mildly demented and 7 moderately demented examples.
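
The observation that the CNN's errors mostly predict a milder degree than the truth can be checked directly from a confusion matrix whose classes are ordered by severity: entries below the diagonal are predictions milder than the true label, and entries above it are more severe. The matrix in this sketch is invented for illustration and does not reproduce Figure 7.

```python
# Hypothetical confusion matrix: rows are true classes, columns are predicted
# classes, both ordered by severity (non, very mild, mild, moderate).
# These counts are INVENTED for illustration and do not reproduce Figure 7.
CM = [
    [50, 2, 0, 0],
    [4, 40, 1, 0],
    [1, 3, 20, 0],
    [0, 0, 1, 3],
]


def error_direction(cm):
    """Split misclassifications into 'predicted milder' and 'predicted more severe'."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))
    # Below the diagonal (column < row): predicted class is milder than the truth.
    milder = sum(cm[i][j] for i in range(n) for j in range(i))
    # Above the diagonal (column > row): predicted class is more severe.
    more_severe = sum(cm[i][j] for i in range(n) for j in range(i + 1, n))
    return correct / total, milder, more_severe


acc, milder, more_severe = error_direction(CM)
print(f"accuracy={acc:.3f}, predicted milder={milder}, predicted more severe={more_severe}")
```

For this made-up matrix, 9 of the 12 errors fall below the diagonal, which is the pattern described for the CNN.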

Discussion

Our results suggest that CNNs are a promising approach for classifying dementia severity from brain MRI images. The architecture used in this study performed well on the dataset we considered, which we attribute to the convolutional filters learning localized features that are predictive of dementia. Unlike fully connected networks, CNNs explicitly model relationships between neighboring pixels, and because each filter is applied at every position in the image, a feature learned in one region can be recognized wherever it appears. Changes to the CNN architecture, such as the number of layers, the activation functions, and the filter sizes, may further improve accuracy.

One potential limitation of the trained model comes from the dataset, which is imbalanced: only a small fraction of the examples are labelled moderately demented. The distribution of labels in the dataset is shown in Figure 9.

Figure 9. Class distribution in the dataset.
Although the CNN correctly predicted most of the test cases whose true label was moderately demented (Figure 7), a model could still achieve high validation accuracy while misclassifying every moderately demented example. Despite this imbalance, it is notable that the CNN produced accurate results from only a single slice of each brain MRI. We expect that using stacks of MRI slices covering the entire brain volume could further improve the CNN's classification performance, which we leave to future work. Saliency maps could also be used to better understand the biological markers that predict dementia, which may lead to a greater understanding of this syndrome. Finally, further work is needed to evaluate time-series ML models that predict dementia from brain scans of the same individual taken at multiple points in life. Such models could enable more individualized diagnosis if the predictive features in brain scans vary between individuals in a population.
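
The imbalance concern raised above can be made concrete with the counts from the Results. Assuming a 640-image test set (10% of 6400), a model that classifies every other test image correctly but misses all 7 moderately demented examples still reaches 98.9% accuracy. Per-class recall, and its unweighted average (macro-averaged recall, also known as balanced accuracy), exposes the failure that overall accuracy hides. This sketch illustrates the metric, not the behavior of the trained model.

```python
# Worked example of the class-imbalance concern. From the Results: the test set
# held 91 mildly and 7 moderately demented examples. ASSUMPTION: the test set
# has 640 images (10% of 6400); the first two row totals below are made up so
# that the matrix sums to 640.
# Rows = true class, columns = predicted class (non, very mild, mild, moderate).
CM = [
    [320, 0, 0, 0],
    [0, 222, 0, 0],
    [0, 0, 91, 0],
    [0, 0, 7, 0],  # every moderately demented example predicted as mild
]


def accuracy(cm):
    total = sum(sum(row) for row in cm)
    return sum(cm[i][i] for i in range(len(cm))) / total


def per_class_recall(cm):
    # Fraction of each true class that was predicted correctly.
    return [row[i] / sum(row) for i, row in enumerate(cm)]


def macro_recall(cm):
    # Unweighted mean of per-class recalls (also called balanced accuracy).
    recalls = per_class_recall(cm)
    return sum(recalls) / len(recalls)


print(f"accuracy:      {accuracy(CM):.1%}")      # 98.9%
print(f"recall:        {per_class_recall(CM)}")  # [1.0, 1.0, 1.0, 0.0]
print(f"macro recall:  {macro_recall(CM):.1%}")  # 75.0%
```

Reporting per-class recall alongside accuracy makes a collapse on the smallest class visible even when the headline number looks excellent.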

Conclusion

In summary, our exploration of dementia severity classification using preprocessed MRI images highlights the promise of machine learning, particularly convolutional neural networks (CNNs), for improving diagnostic accuracy. The CNN outperformed the fully connected network, achieving 95.0% accuracy on the test set and perfect accuracy on the training set. This result underscores the ability of CNNs to discern intricate patterns and spatial relationships within MRI images that are informative about dementia severity, and it is consistent with our expectations before testing. Despite these promising outcomes, several considerations remain for future research. Addressing the class imbalance, especially the small number of moderately demented examples, is a priority for training a more robust network. Incorporating stacks of MRI slices could offer a more comprehensive view of brain structure, although classifying them may require more complex CNNs. Saliency maps may shed light on the biological markers associated with dementia and other neurodegenerative diseases. Time-series ML models that use longitudinal brain scans also merit exploration for personalized diagnosis and intervention strategies tailored to individual patients' needs. Our findings contribute to ongoing efforts toward early and accurate dementia detection, which is crucial for optimizing patient care and management, and we believe the combination of machine learning and MRI will play an important role in these efforts.

References

[1] Alzheimer MRI Preprocessed Dataset, URL: https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset, Accessed on: July 9, 2024.

[2] Anatomy of the encephalon (MRI) in axial slices, Antoine Micheau, Denis Hoa, URL: https://www.imaios.com/en/e-anatomy/brain/mri-axial-brain, Accessed on: July 9, 2024.

[3] Qiu, Miller, Joshi, et al. Multimodal deep learning for Alzheimer’s disease dementia assessment. Nature Communications (2022).

[4] Kavitha, Mani, Srividhya, et al. Early-Stage Alzheimer’s Disease Prediction Using Machine Learning Models. Frontiers in Public Health (2022).

[5] Nichols, Steinmetz, Vollset, et al. Estimation of the global prevalence of dementia in 2019 and forecasted prevalence in 2050: an analysis for the Global Burden of Disease Study 2019. The Lancet Public Health (2022).