Making safer biomedical predictions with Deep Learning
This is the fourth article in the series Deep Learning for Life Sciences. In the previous posts, I showed how to use Deep Learning on Ancient DNA, Deep Learning for Single Cell Biology and Deep Learning for Data Integration. Now we are going to dive into Biomedicine and learn why and how we should use Bayesian Deep Learning for patient safety.
Next Generation Sequencing (NGS) provided a major advance in our understanding of the pathogenic mechanisms behind common human diseases. Nevertheless, the amount of data still remains a bottleneck for analysis in Biomedicine. In contrast to mainstream Data Science, data sets with millions of examples are rather uncommon in Biomedicine, while high-dimensional data are quite typical; as a result, Machine Learning has very limited applications in Biomedicine. The lack of data combined with a high-dimensional parameter space hinders precision in clinical diagnostics, producing many false predictions which do not hold up in clinical trials. When data are scarce, noisy and high-dimensional, Bayesian Statistics helps to make generalizable predictions.
Here we will discuss how to implement Bayesian Deep Learning with PyMC3 in order to ensure patient safety and provide more accurate and intelligent predictions for clinical diagnostics.
Why be Bayesian when doing Deep Learning?
In the previous post, I explained that when performing a statistical analysis you should pay particular attention to the balance between the number of statistical observations, N, and the dimension of your space, i.e. the number of features, P. Depending on the amount of data, you can select between Bayesian Statistics, Frequentist Statistics and Machine/Deep Learning.
So it makes sense to use Deep Learning when you have a lot of data because you can abandon the dull world of Linear Algebra and jump into the rabbit hole of non-linear mathematics. In contrast, Biomedicine usually works in the opposite limit, N<<P, and needs Priors to compensate for the lack of data. This is the first reason why Biomedical analysis should be Bayesian.
Now imagine for a moment that you have some Biomedical Big Data. This is uncommon but not impossible if you work with Imaging or Single Cell Biology, and here you can and should do Deep Learning. But why would you want to be Bayesian in this case?
Here comes the second reason: the need to produce less categorical predictions (compared to traditional Frequentist-based Deep Learning) by incorporating uncertainties into the model. This is of tremendous importance in areas where false predictions carry a very high cost, such as self-driving cars, modelling the stock market, earthquake forecasting and, particularly, clinical diagnostics.
Why not Frequentist analysis for Biomedicine?
There are many reasons to be cautious when applying Frequentist Statistics to clinical diagnostics. It relies heavily on the normality assumption and is hence sensitive to outliers, and it operates with descriptive statistics which do not always reflect the underlying data distributions and therefore fail to capture the differences between the data sets in Anscombe's quartet.
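The point about Anscombe's quartet can be checked directly. Below is a minimal numpy sketch using the published quartet values (Anscombe, 1973): all four data sets share virtually identical means and correlations despite their very different shapes.

```python
import numpy as np

# Anscombe's quartet: four data sets with (near-)identical descriptive
# statistics but very different shapes (values from Anscombe, 1973)
x123 = [10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5]
quartet = {
    'I':   (x123, [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]),
    'II':  (x123, [9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74]),
    'III': (x123, [7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73]),
    'IV':  ([8]*7 + [19] + [8]*3,
            [6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89]),
}

for name, (x, y) in quartet.items():
    x, y = np.asarray(x, float), np.asarray(y, float)
    r = np.corrcoef(x, y)[0, 1]
    print(f"{name}: mean_x={x.mean():.2f} mean_y={y.mean():.2f} r={r:.3f}")
# All four sets print mean_x=9.00, mean_y=7.50 and r ≈ 0.816
```

Plotting the four sets makes the failure of these summary statistics obvious: a noisy line, a smooth curve, a line with one outlier and a vertical cluster with one outlier all share the same numbers.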
In contrast, Bayesian probabilistic modelling of Anscombe's data sets would reveal their differences through clearly distinct posterior probability distributions.
Intelligence is to know how much you do not know
There are also a few famous examples, commonly referred to as the Datasaurus, which further demonstrate that Frequentist Statistics cannot capture the difference between groups of samples with identical descriptive statistics such as the mean, standard deviation or Pearson's correlation coefficient.
Therefore, simplistic Frequentist analysis should not be used for clinical diagnostics, where we cannot afford to make false predictions that can damage people's lives.
Bayesian Deep Learning on scRNAseq with PyMC3
Here I will use scRNAseq data on Cancer Associated Fibroblasts (CAFs) and apply Bayesian Deep Learning to classify them into malignant and non-malignant cell types. In a similar manner, diabetes patients can be assigned to certain disease sub-types for accurate treatment prescription. We will start by downloading the expression data from here, loading them into Python, splitting them into training and validation subsets and visualizing them with tSNE. As usual, the rows of the expression matrix are samples/cells, the columns are features/genes, and the last column contains cell labels derived from unbiased DBSCAN clustering.
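A rough sketch of this preparation step might look as follows; a small synthetic matrix stands in for the downloaded CAFs file, so the shapes and values here are purely illustrative:

```python
import numpy as np
from sklearn.manifold import TSNE

# Synthetic stand-in for the CAFs expression matrix:
# rows = cells, columns = genes, with DBSCAN-style labels kept separately
# (the real data would be read from the downloaded file instead)
rng = np.random.default_rng(1)
n_cells, n_genes = 120, 30
X = rng.poisson(lam=2.0, size=(n_cells, n_genes)).astype(float)
labels = rng.integers(0, 4, size=n_cells)  # 4 cell sub-types

# Split into training and validation subsets (80 / 20)
idx = rng.permutation(n_cells)
n_train = int(0.8 * n_cells)
X_train, X_valid = X[idx[:n_train]], X[idx[n_train:]]
y_train, y_valid = labels[idx[:n_train]], labels[idx[n_train:]]

# 2D tSNE embedding of all cells for visualization
embedding = TSNE(n_components=2, perplexity=30,
                 random_state=0).fit_transform(X)
```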
Four clusters are clearly distinguishable in the tSNE plot. Next we are going to construct a Bayesian Neural Network (BNN) model with one hidden layer of 16 neurons. This is done by assigning Normal priors to the weights and biases and initializing them with random values.
For building the BNN, I am going to use PyMC3 and follow the approach described in the fantastic blog of Thomas Wiecki. Within the model, we also define the likelihood, which is a Categorical distribution since we are dealing with a multi-class (4 classes) scRNAseq classification problem.
By putting priors on the weights and biases, we let the model know that those parameters have uncertainties, so the MCMC sampler will build posterior distributions for them. Now we are going to define a function which draws samples from the posteriors of the parameters of the Bayesian Neural Network using NUTS, a Hamiltonian Monte Carlo algorithm that is much faster than e.g. Metropolis when gradients with respect to the parameters can be calculated. Sampling is the training of the BNN.
Now we are going to validate the predictions of the Bayesian Neural Network model using the Posterior Predictive Check (PPC) procedure. For this purpose, we will use the trained model and draw a decision boundary on the tSNE plot for the test subset. The decision boundary is created by building a 100 x 100 grid on the tSNE plot and running the model prediction for each point of the grid. Next, we calculate the mean and the standard deviation of the probability of assigning each point on the grid to one of the 4 cell sub-types and visualize both the mean probability and the uncertainty of the probability.
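The grid-based part of this step can be illustrated with plain numpy; here simulated posterior draws of a simple linear softmax classifier stand in for the BNN posterior samples, and the grid limits are arbitrary:

```python
import numpy as np

# Simulated posterior samples of a linear softmax classifier in the
# 2D tSNE space (a stand-in for the BNN posterior): weights and biases
rng = np.random.default_rng(0)
n_draws, n_classes = 200, 4
W = rng.normal(size=(n_draws, 2, n_classes))
b = rng.normal(size=(n_draws, n_classes))

# 100 x 100 grid spanning the tSNE plot
xs = np.linspace(-50, 50, 100)
ys = np.linspace(-50, 50, 100)
grid = np.stack(np.meshgrid(xs, ys), axis=-1).reshape(-1, 2)  # (10000, 2)

# Class probabilities for every grid point under every posterior draw
logits = np.einsum('gd,ndk->ngk', grid, W) + b[:, None, :]  # (draws, points, classes)
logits -= logits.max(axis=-1, keepdims=True)                # numerical stability
probs = np.exp(logits)
probs /= probs.sum(axis=-1, keepdims=True)

# Mean probability (what a Frequentist net reports) and its uncertainty
mean_prob = probs.mean(axis=0).reshape(100, 100, n_classes)
std_prob = probs.std(axis=0).reshape(100, 100, n_classes)
```

Here `mean_prob` feeds the mean-probability heatmap and `std_prob` the uncertainty heatmap discussed next.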
The plots above show the tSNE of the test subset (upper left); the tSNE with the mean probability of assigning each point to any of the 4 cell sub-types (upper right), which is essentially what a Maximum Likelihood / Frequentist Neural Network would predict; and the tSNE with the uncertainty of that probability of assignment (lower right), which is an output specific to the Bayesian Neural Network. Here red and blue imply high and low probability, respectively, of assigning points of the tSNE to any cell sub-type. Darker areas on the uncertainty heatmap indicate regions of higher uncertainty.
What we can immediately see is that the mean probability heatmap contains two cells from the yellow class assigned with 100% probability to the purple cluster. This is a severe misclassification and a demonstration of the failure of the Maximum Likelihood / Frequentist Neural Network. In contrast, the uncertainty heatmap shows that the two yellow cells fall within a relatively dark area, meaning that the Bayesian Neural Network was not at all sure about assigning those cells to any cluster. This example demonstrates the power of Bayesian Deep Learning for making safer and less radical classifications, which is of particular importance for clinical diagnostics.
Here we have learnt that Bayesian Deep Learning is a more accurate and safer way of making predictions, which makes a lot of sense in clinical diagnostics, where we are not allowed to make mistakes with treatment prescriptions. We used PyMC3 and MCMC to build a Bayesian Neural Network model and sample from the posterior probability of assigning samples to the malignant vs. non-malignant classes. Finally, we demonstrated the superiority of Bayesian Deep Learning over the Frequentist approach in utilizing uncertainty information to avoid sample misclassification.
As usual, let me know in the comments if you have a favorite area in Life Sciences which you would like to see addressed within the Deep Learning framework. Follow me at Medium Nikolay Oskolkov, on Twitter @NikolayOskolkov, and check out the code for this post on my github. I plan to write the next post about Deep Learning for Microscopy Image Analysis, stay tuned.