Predicting Sudden Cardiac Arrest: Time Series Classification with LSTM Recurrent Neural Networks

June 12, 2022


Recurrent Neural Networks (RNNs) are powerful models for time-series classification, language translation, and other tasks. This tutorial walks through building a simple end-to-end model using RNNs, training it on patients' vitals and static data, and using it to predict sudden cardiac arrest.

The project partner: The use case in this case study stems from seed-stage startup Transformative.ai, which hosted an Omdena Challenge as part of Omdena's AI Incubator for impact startups.

Introduction to Time-series Data

Time-series data is a sequence of observations collected over a defined time frame. Such sequences can describe weather readings, customers' shopping patterns, word sequences, and so on. Manual analysis of these sequences becomes challenging as the volume of data grows, making it difficult to find patterns.

Finding patterns and predicting outcomes today relies on various machine learning techniques developed for time-series data. The use of deep learning for time-series and sequence data has also risen sharply, and recurrent neural networks are among the most popular deep learning techniques for analyzing and predicting outcomes from such data.

What are Recurrent Neural Networks?

Recurrent Neural Networks (RNNs), originally a Natural Language Processing technique, are powerful artificial neural networks that maintain a memory of their inputs. Because RNNs maintain this memory, they can solve problems involving sequential data with long-term dependencies. They demonstrate promising performance on time-series machine learning problems ranging from weather prediction to sentiment analysis, machine translation, and speech recognition.

In a Recurrent Neural Network, the input features arrive in sequential order (i.e., as a time series), and the model tries to find the underlying pattern to predict the desired outcome.

But some classification and regression tasks involve a combination of time-series and static features. One such use case is predicting cardiac arrest in patients from their static data and vitals.

The patient's static features include age, ethnic origin, gender, medical history, and medications. The vitals are time-series features such as heart rate, systolic blood pressure, diastolic blood pressure, and temperature.

Vitals are measured frequently once a patient is admitted to a hospital ICU. This article describes how to combine time-series features with static features to construct a custom RNN + SLP (single-layer perceptron) neural network that predicts cardiac arrest in ICU patients.
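To make the two modalities concrete, the model will consume two arrays per batch: a 3-D array of vitals over time for the RNN branch and a 2-D array of static features for the SLP branch. A minimal sketch of those shapes (variable names here are illustrative, not from the project code):

import numpy as np

# Illustrative shapes for one batch of ICU stays (names are hypothetical)
batch_size = 32
n_timesteps = 8    # hourly vitals over an 8-hour window
n_vitals = 7       # heart rate, blood pressures, temperature, ...
n_static = 3       # age, ethnicity, gender (encoded numerically)

vitals_batch = np.zeros((batch_size, n_timesteps, n_vitals))  # RNN branch input
static_batch = np.zeros((batch_size, n_static))               # SLP branch input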

What is the problem statement? 

The SCA (sudden cardiac arrest) prediction model constructed as part of the core Omdena Challenge is described in detail here. SCA is a medical emergency in which the heart suddenly stops beating, killing the patient within minutes. In-hospital survival rates for SCA are below 25%. Identifying and treating the underlying cause can prevent SCA.

The project's purpose was to expand the cardiac arrest prediction algorithm to pulseless electrical activity and asystole, providing an all-cause cardiac arrest prediction algorithm that covers more than 90% of patients.

The Data

MIMIC-III and eICU were the data sources for building the cardiac arrest prediction models. MIMIC-III ("Medical Information Mart for Intensive Care") is an extensive, single-center database comprising information on patients admitted to critical care units at a large tertiary care hospital.

The eICU Collaborative Research Database is populated with data from many critical care units across the continental United States, covering patients admitted to critical care units in 2014 and 2015.

The PhysioNet website provides access to both databases. Once access is granted, the data can be queried in Google BigQuery, the big data analytics platform.
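As an illustration, vitals can be pulled with the BigQuery Python client along these lines (the table and column names below follow the public eicu_crd schema and are assumptions, not the project's exact queries):

# Hypothetical sketch: querying raw vitals from eICU on BigQuery.
# Table/column names follow the public eicu_crd schema and may differ by release.
from google.cloud import bigquery

client = bigquery.Client()
query = """
    SELECT patientunitstayid, observationoffset,
           heartrate, systemicsystolic, systemicdiastolic,
           temperature, sao2, respiration
    FROM `physionet-data.eicu_crd.vitalperiodic`
    ORDER BY patientunitstayid, observationoffset
    LIMIT 1000
"""
vitals_df = client.query(query).to_dataframe()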

Building the machine learning model requires extracting, cleaning, and processing the eICU and MIMIC-III data. We will not cover the data processing activity here, but you can read about it in the article here.

After data processing, a total of 10 features were created. A single patient has three static features (age, ethnicity, gender) and seven time-series vital features (systolic blood pressure, diastolic blood pressure, temperature, heart rate, oxygen saturation, respiratory rate, and Glasgow Coma Scale). Figure 1 below shows this data for a patient's ICU stay.

Figure 1: Time-series vitals and static features for a patient for an ICU stay.
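Because raw vitals arrive at irregular intervals, they are typically aligned onto an hourly grid before being fed to the model. A minimal pandas sketch of that step (column names and values are illustrative, not from the project pipeline):

import pandas as pd

# Hypothetical raw measurements for one ICU stay, one row per observation
raw = pd.DataFrame({
    "charttime": pd.to_datetime(["2014-01-01 00:05", "2014-01-01 00:40",
                                 "2014-01-01 02:10"]),
    "heart_rate": [88.0, 92.0, 95.0],
    "sbp": [120.0, 118.0, 115.0],
})

# Average within each hour, then forward-fill hours with no measurement
hourly = raw.set_index("charttime").resample("1H").mean().ffill()
print(hourly)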

Our Model: The Recurrent Neural Network + Single Layer Perceptron

For this problem, we need a deep learning model capable of learning from both time-series and static features. Hence, we construct a single-layer perceptron (SLP) and a bi-directional LSTM using Keras and TensorFlow.

The intuition behind this approach is that the bi-directional RNN will learn the relationships among the time-series features, while the single-layer perceptron focuses on the static features. The model is inspired by work on the early prediction of circulatory failure, which can be accessed here.

The model architecture consists of an SLP branch and stacked bidirectional LSTM layers, followed by a concatenation layer that combines the outputs of the RNN and SLP branches. The combined output then passes through another dense layer and an output layer with sigmoid activation that predicts whether cardiac arrest occurs. The model architecture is shown in the image below.

Figure: RNN + SLP model architecture, a deep learning model capable of learning from time-series and static features.

For the time window under consideration, static_input contains the static features (i.e., values that do not change with time) such as age, ethnicity, and gender. recurrent_input contains the time-series features, the vitals, whose measurements change every hour within the window. For this model, we consider an 8-hour window.
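For concreteness, here is a sliding-window sketch of how hourly vitals could be cut into the 8-hour inputs the RNN expects (illustrative only, not the project's exact preprocessing):

import numpy as np

def make_windows(hourly_vitals, n_timesteps=8):
    """Slice an (n_hours, n_features) array of hourly vitals into
    overlapping (n_timesteps, n_features) windows."""
    windows = [hourly_vitals[i:i + n_timesteps]
               for i in range(len(hourly_vitals) - n_timesteps + 1)]
    return np.stack(windows)

# e.g. 24 hours of 7 vitals -> 17 windows of shape (8, 7)
print(make_windows(np.zeros((24, 7))).shape)   # (17, 8, 7)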

# Imports (tensorflow_addons provides the focal loss)
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Input, Dense, LSTM, Bidirectional,
                                     Dropout, Concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

# Define timesteps and the number of features
n_timesteps = 8   # 8-hour window of hourly vitals
n_features = 7    # seven time-series vital features
n_outputs = 1     # binary outcome: cardiac arrest or not

# RNN + SLP Model
# Define the input layers: one for the vitals sequence, one for the static data
recurrent_input = Input(shape=(n_timesteps, n_features), name="TIMESERIES_INPUT")
static_input = Input(shape=(x_train_over_static.shape[1],), name="STATIC_INPUT")

# RNN Layers
# layer - 1
rec_layer_one = Bidirectional(LSTM(128, kernel_regularizer=l2(0.01),
                                   recurrent_regularizer=l2(0.01),
                                   return_sequences=True),
                              name="BIDIRECTIONAL_LAYER_1")(recurrent_input)
rec_layer_one = Dropout(0.1, name="DROPOUT_LAYER_1")(rec_layer_one)

# layer - 2
rec_layer_two = Bidirectional(LSTM(64, kernel_regularizer=l2(0.01),
                                   recurrent_regularizer=l2(0.01)),
                              name="BIDIRECTIONAL_LAYER_2")(rec_layer_one)
rec_layer_two = Dropout(0.1, name="DROPOUT_LAYER_2")(rec_layer_two)

# SLP Layers
static_layer_one = Dense(64, kernel_regularizer=l2(0.001), activation='relu',
                         name="DENSE_LAYER_1")(static_input)

# Combine layers - RNN + SLP
combined = Concatenate(axis=1, name="CONCATENATED_TIMESERIES_STATIC")(
    [rec_layer_two, static_layer_one])
combined_dense_two = Dense(64, activation='relu', name="DENSE_LAYER_2")(combined)
output = Dense(n_outputs, activation='sigmoid', name="OUTPUT_LAYER")(combined_dense_two)

# Compile Model
model = Model(inputs=[recurrent_input, static_input], outputs=[output])

# binary cross entropy loss
# (f1_m, precision_m, and recall_m are custom metrics defined below)
model.compile(loss='binary_crossentropy', optimizer='adam',
              metrics=['accuracy', f1_m, precision_m, recall_m])

# focal loss
def focal_loss_custom(alpha, gamma):
    def binary_focal_loss(y_true, y_pred):
        fl = tfa.losses.SigmoidFocalCrossEntropy(alpha=alpha, gamma=gamma)
        return fl(y_true, y_pred)
    return binary_focal_loss

model.compile(loss=focal_loss_custom(alpha=0.2, gamma=2.0), optimizer='adam',
              metrics=['accuracy', f1_m, precision_m, recall_m])

model.summary()

The model summary shows the constructed model, as seen in Figure 2.

Figure 2: Bidirectional RNN + SLP model summary

Since the input data is imbalanced, we also define precision, recall, and F1-score calculations for evaluating the model.

# Define metrics for evaluating the model - recall, precision, and F1-score
from tensorflow.keras import backend as K

def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

Results

Let's train our model. We first train it using binary cross-entropy loss and then using focal loss. Focal loss applies a modulating term to the cross-entropy loss to focus learning on hard negative examples; you can read about it in detail in the original paper.
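To make the modulating term concrete, here is a toy calculation (not from the project code, and with the alpha weighting omitted for simplicity): with gamma = 2, a well-classified example at p_t = 0.9 has its cross-entropy contribution scaled by (1 - 0.9)^2 = 0.01, while a harder example at p_t = 0.6 is scaled by (1 - 0.6)^2 = 0.16, so hard examples dominate the gradient.

import numpy as np

# Focal loss per the paper (alpha omitted): FL(p_t) = -(1 - p_t)**gamma * log(p_t)
def focal_ce(p_t, gamma=2.0):
    return -((1 - p_t) ** gamma) * np.log(p_t)

for p_t in (0.9, 0.6):
    print(f"p_t={p_t}: plain CE={-np.log(p_t):.3f}, focal CE={focal_ce(p_t):.4f}")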

# fit network
import numpy as np
from matplotlib import pyplot

epochs, batch_size, verbose = 50, 32, 1   # example training settings

history = model.fit(
    [np.asarray(x_train_reshape).astype('float32'),
     np.asarray(x_train_over_static).astype('float32')],
    y_train_reshape, epochs=epochs, batch_size=batch_size, verbose=verbose,
    validation_data=([np.asarray(x_val_reshape).astype('float32'),
                      np.asarray(x_val_static).astype('float32')],
                     y_val_reshape))

# summarize history for accuracy
pyplot.plot(history.history['accuracy'])
pyplot.plot(history.history['val_accuracy'])
pyplot.title('model accuracy')
pyplot.ylabel('accuracy')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper left')
pyplot.show()

# summarize history for loss
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper left')
pyplot.show()

# evaluate model on the test set
loss, accuracy, f1_score, precision, recall = model.evaluate(
    [np.asarray(x_test_reshape).astype('float32'),
     np.asarray(x_test_static).astype('float32')],
    y_test_reshape, batch_size=batch_size, verbose=0)

# print output
print("Accuracy:{} , F1_Score:{}, Precision:{}, Recall:{}".format(
    accuracy, f1_score, precision, recall))
Figure 3: Loss and accuracy Plots using binary cross-entropy loss.

Figure 3 shows the loss and accuracy plots for the training and validation sets.

The results on the test set using binary cross-entropy loss are: Accuracy: 0.95, F1-score: 0.18, Precision: 0.12, Recall: 0.41.

Figure 4: Loss and accuracy Plots using focal loss.

The results on the test set using focal loss are: Accuracy: 0.99, F1-score: 0.35, Precision: 0.37, Recall: 0.35. Using focal loss improves the model's performance significantly.

Conclusion

RNNs are proven to work exceptionally well with time-series data. In real-world data, supplementary static features are often available that cannot be directly incorporated into RNNs because of their non-sequential nature. The method described here adds static features alongside the RNN to influence the learning process. A previous approach to this problem was to build a separate model for each modality and combine them at the prediction level. Combining both modalities in a single architecture allows the model to learn simultaneously from the static and temporal features.

We conclude that adding static features improves the RNN's performance beyond what either the sequential or static features achieve alone.

This article was written by Sanjana Tule and Sijuade Oguntayo.
