Predicting Sudden Cardiac Arrest: Time Series Classification with LSTM Recurrent Neural Networks
June 12, 2022
Recurrent Neural Networks (RNNs) are powerful models for time-series classification, language translation, and other tasks. This tutorial walks through building a simple end-to-end RNN model, training it on patients' vitals and static data, and using it to predict sudden cardiac arrest.
Introduction to Time-series Data
Time-series data is a sequence of observations collected over a defined time frame. Such sequences can be weather readings, customers' shopping patterns, sequences of words, and so on. Analyzing them manually quickly becomes impractical: as the volume of data grows, patterns become hard to spot.
Today, a range of machine learning techniques is used to find patterns in time-series data and to predict outcomes from it. Deep learning in particular has grown rapidly as a tool for analyzing time-series and other sequence data, and recurrent neural networks are among the most popular deep learning approaches for this task.
What are Recurrent Neural Networks?
Recurrent Neural Networks (RNNs), widely used in natural language processing, are artificial neural networks that maintain an internal state, or memory, of the inputs they have already seen. Because of this memory, RNNs can solve problems involving sequential data; the Long Short-Term Memory (LSTM) variant used in this article adds gates that help it capture long-term dependencies. RNNs have shown promising performance on sequence problems ranging from weather prediction to sentiment analysis, machine translation, and speech recognition.
In an RNN, the input features are presented in sequential order (i.e., as a time series), and the model tries to find the underlying pattern to predict the desired outcome.
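To make the idea of "memory" concrete, here is a minimal sketch of a one-unit vanilla RNN recurrence, h_t = tanh(w_x·x_t + w_h·h_{t-1} + b). The weights are arbitrary illustrative values rather than learned ones, and an LSTM adds gates on top of this basic recurrence.

```python
import math

def rnn_final_state(inputs, w_x=0.5, w_h=0.8, b=0.0):
    """Run a one-unit vanilla RNN over a sequence and return the final hidden state.

    The weights are illustrative constants, not trained parameters.
    """
    h = 0.0  # initial hidden state (the network's "memory")
    for x in inputs:
        # The new state mixes the current input with the previous state,
        # so information from earlier steps carries forward.
        h = math.tanh(w_x * x + w_h * h + b)
    return h

# Two sequences that differ only in their first value.
seq_a = [1.0, 0.0, 0.0, 0.0]
seq_b = [-1.0, 0.0, 0.0, 0.0]
print(rnn_final_state(seq_a), rnn_final_state(seq_b))
```

The final states have opposite signs even though the last three inputs are identical, so the first input is still "remembered". Its influence shrinks with every step in this example, which is the kind of decay that LSTM gating is designed to mitigate.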
However, some classification and regression tasks involve a combination of time-series and static features. One example is predicting cardiac arrest in patients from both their static data and their vitals.
A patient's static features include age, ethnic origin, gender, medical history, and medications. The vitals are time-series features such as heart rate, systolic blood pressure, diastolic blood pressure, and temperature.
Vitals are measured frequently once a patient is admitted to a hospital's intensive care unit (ICU). This article describes how to combine time-series features with static features in a custom RNN + SLP (single-layer perceptron) neural network that predicts cardiac arrest in ICU patients.
What is the problem statement?
The SCA (sudden cardiac arrest) prediction model described here was built as part of the core Omdena Challenge. SCA is a medical emergency in which the heart suddenly stops beating; without immediate treatment, it is fatal within minutes. In-hospital survival rates for SCA are below 25%. Identifying and treating the underlying cause can prevent SCA.
The project’s purpose was to expand the cardiac arrest prediction algorithm to pulseless electrical activity and asystole, providing an all-cause cardiac arrest prediction algorithm for more than 90% of patients.
The Data
MIMIC-III and eICU were the data sources for building the cardiac arrest prediction models. MIMIC-III (Medical Information Mart for Intensive Care) is an extensive, single-center database of patients admitted to critical care units at a large tertiary care hospital.
The eICU Collaborative Research Database is populated with data from many critical care units throughout the continental United States and covers patients admitted to those units in 2014 and 2015.
Both databases are available through the PhysioNet website. Once access is granted, the data can be queried in Google BigQuery, Google's big data analytics platform.
Building the model requires extracting, cleaning, and processing the eICU and MIMIC-III data. We do not cover data processing here; it is described in a separate article.
After data processing, 10 features were available for each patient: three static features (age, ethnicity, and gender) and seven time-series vital features (systolic blood pressure, diastolic blood pressure, temperature, heart rate, oxygen saturation, respiratory rate, and Glasgow Coma Scale). Figure 1 below shows this data for a patient's ICU stay.
Our Model: The Recurrent Neural Network + Single Layer Perceptron
For this problem, we need a deep learning model that can learn from both time-series and static features. We therefore combine a single-layer perceptron (SLP) with a bidirectional LSTM, built with Keras and TensorFlow.
The intuition is that the bidirectional RNN learns the relationships among the time-series features, while the single-layer perceptron focuses on the patient's static features. The model is inspired by work on the early prediction of circulatory failure.
The model architecture consists of one SLP layer and two bidirectional LSTM layers, followed by a concatenation layer that combines the outputs of the RNN and SLP branches. The combined output is passed to another dense layer and then to an output layer with sigmoid activation that predicts whether cardiac arrest occurs. The image below shows the model architecture.
For the time window under consideration, static_input contains static features (i.e., features whose values do not change over time) such as age, ethnicity, and gender. recurrent_input contains time-series features such as the vitals, which are measured every hour within the window. For this model, we use an 8-hour window.
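As a sketch of how hourly vitals might be cut into fixed 8-hour windows of shape (n_timesteps, n_features), the hypothetical helper below slides a window over one patient's hourly measurements. The article does not describe its exact windowing (stride, overlap, handling of missing hours), so the make_windows function, its stride parameter, and the assumption of complete, evenly spaced hourly rows are all illustrative.

```python
def make_windows(hourly_rows, n_timesteps=8, stride=1):
    """Split a patient's hourly vitals into fixed-length windows (hypothetical helper).

    hourly_rows: list of per-hour feature lists, assumed complete and in time order.
    Returns a list of windows, each a list of n_timesteps rows.
    Stays shorter than n_timesteps produce no windows.
    """
    windows = []
    for start in range(0, len(hourly_rows) - n_timesteps + 1, stride):
        windows.append(hourly_rows[start:start + n_timesteps])
    return windows

# Hypothetical 10-hour ICU stay with 7 vitals per hour (placeholder values).
n_features = 7
stay = [[float(hour)] * n_features for hour in range(10)]
windows = make_windows(stay)
print(len(windows), len(windows[0]), len(windows[0][0]))
```

Stacking the windows (for example with np.asarray) gives a 3-D array of shape (n_windows, 8, 7), which is the shape the TIMESERIES_INPUT layer below expects for a batch.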
# Imports (TensorFlow 2.x, plus TensorFlow Addons for the focal loss)
import tensorflow_addons as tfa
from tensorflow.keras.layers import Input, LSTM, Bidirectional, Dropout, Dense, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

# Define timesteps and the number of features
n_timesteps = 8   # 8-hour window, one row per hour
n_features = 7    # vitals per hour
n_outputs = 1     # single sigmoid unit: cardiac arrest or not

# RNN + SLP Model
# Define input layers; x_train_over_static comes from the data-processing step
recurrent_input = Input(shape=(n_timesteps, n_features), name="TIMESERIES_INPUT")
static_input = Input(shape=(x_train_over_static.shape[1],), name="STATIC_INPUT")

# RNN Layers
# layer - 1: return_sequences=True passes the full sequence to the next LSTM
rec_layer_one = Bidirectional(LSTM(128, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), return_sequences=True), name="BIDIRECTIONAL_LAYER_1")(recurrent_input)
rec_layer_one = Dropout(0.1, name="DROPOUT_LAYER_1")(rec_layer_one)
# layer - 2: returns only the output for the last timestep
rec_layer_two = Bidirectional(LSTM(64, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01)), name="BIDIRECTIONAL_LAYER_2")(rec_layer_one)
rec_layer_two = Dropout(0.1, name="DROPOUT_LAYER_2")(rec_layer_two)

# SLP Layers
static_layer_one = Dense(64, kernel_regularizer=l2(0.001), activation='relu', name="DENSE_LAYER_1")(static_input)

# Combine layers - RNN + SLP
combined = Concatenate(axis=1, name="CONCATENATED_TIMESERIES_STATIC")([rec_layer_two, static_layer_one])
combined_dense_two = Dense(64, activation='relu', name="DENSE_LAYER_2")(combined)
output = Dense(n_outputs, activation='sigmoid', name="OUTPUT_LAYER")(combined_dense_two)

# Compile Model
# f1_m, precision_m and recall_m are defined in the metrics snippet that follows;
# run that snippet before compiling.
model = Model(inputs=[recurrent_input, static_input], outputs=[output])

# Option 1: binary cross entropy loss
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy', f1_m, precision_m, recall_m])

# Option 2: focal loss (compiling again replaces the binary cross entropy setup above)
def focal_loss_custom(alpha, gamma):
    def binary_focal_loss(y_true, y_pred):
        fl = tfa.losses.SigmoidFocalCrossEntropy(alpha=alpha, gamma=gamma)
        return fl(y_true, y_pred)
    return binary_focal_loss

model.compile(loss=focal_loss_custom(alpha=0.2, gamma=2.0), optimizer='adam', metrics=['accuracy', f1_m, precision_m, recall_m])
model.summary()
Figure 2 shows the resulting model summary.
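The parameter counts in a Keras model summary can be checked by hand. A Keras LSTM layer with u units and input dimension d has 4·(u·(d + u) + u) weights (four gates, each with input weights, recurrent weights, and a bias), Bidirectional doubles that, and a Dense layer has d·u + u. The sketch below applies these formulas to the layers defined above, assuming the static input has exactly 3 columns (age, ethnicity, gender) and n_outputs = 1; if ethnicity is one-hot encoded, x_train_over_static.shape[1] and the DENSE_LAYER_1 count will differ.

```python
def lstm_params(units, input_dim):
    # Four gates, each with input weights, recurrent weights and a bias vector.
    return 4 * (units * (input_dim + units) + units)

def dense_params(units, input_dim):
    # Weight matrix plus bias vector.
    return input_dim * units + units

n_features = 7
n_static = 3  # assumption: three raw static columns

bi_1 = 2 * lstm_params(128, n_features)  # BIDIRECTIONAL_LAYER_1, 2*128 features per step
bi_2 = 2 * lstm_params(64, 2 * 128)      # BIDIRECTIONAL_LAYER_2, outputs 2*64 = 128 features
dense_1 = dense_params(64, n_static)     # DENSE_LAYER_1 on the static input
concat_dim = 2 * 64 + 64                 # width of CONCATENATED_TIMESERIES_STATIC
dense_2 = dense_params(64, concat_dim)   # DENSE_LAYER_2
out = dense_params(1, 64)                # OUTPUT_LAYER with n_outputs = 1

# Input, Dropout and Concatenate layers have no trainable parameters.
total = bi_1 + bi_2 + dense_1 + dense_2 + out
print(bi_1, bi_2, dense_1, concat_dim, dense_2, out, total)
```

Under these assumptions the two bidirectional layers account for 139,264 and 164,352 parameters, and the whole model has 316,289 trainable parameters. Almost all of the capacity sits in the recurrent branch.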
Because the data is imbalanced, accuracy alone is misleading, so we also define precision, recall, and F1-score for evaluating the model.
from tensorflow.keras import backend as K

# Define metrics for evaluating the model - recall, precision and f1-score.
# Note: Keras evaluates these per batch and reports the average over batches.
def recall_m(y_true, y_pred):
    # K.round turns scores above 0.5 into positive predictions
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))
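Because these metrics are averaged over batches, and a batch may contain few or no positive cases, the reported values can differ from precision, recall, and F1 computed over the whole test set. A minimal pure-Python sketch of the dataset-level computation is shown below; it uses the same rule as K.round (scores above 0.5 count as positive) and the same 1e-7 epsilon as K.epsilon(). The dataset_metrics helper and the toy labels and scores are hypothetical.

```python
def dataset_metrics(y_true, y_score, threshold=0.5):
    """Precision, recall and F1 over a whole dataset (not averaged per batch)."""
    y_pred = [1 if s > threshold else 0 for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    eps = 1e-7  # avoids division by zero, like K.epsilon()
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

# Toy example: 2 true positives, 1 false positive, 1 false negative.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_score = [0.9, 0.7, 0.2, 0.6, 0.1, 0.3, 0.05, 0.4]
print(dataset_metrics(y_true, y_score))
```

In practice the same computation can be applied to the full vector of test predictions after model.predict, so that every positive case counts once.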
Results
Let's train the model. We first train it with binary cross-entropy loss and then with focal loss. Focal loss adds a modulating term to the cross-entropy loss that down-weights well-classified (easy) examples, so that learning focuses on hard, misclassified ones; the original paper, "Focal Loss for Dense Object Detection" (Lin et al.), describes it in detail.
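For a single example with predicted probability p and label y, focal loss can be written as FL = -α_t (1 - p_t)^γ log(p_t), where p_t = p if y = 1 and 1 - p otherwise, and α_t = α for positives and 1 - α for negatives. This is also the weighting convention of tfa.losses.SigmoidFocalCrossEntropy. The pure-Python sketch below follows that definition to illustrate the effect; it is not the library's implementation.

```python
import math

def focal_loss(y, p, alpha=0.2, gamma=2.0, eps=1e-7):
    """Binary focal loss for one example (y in {0, 1}, p = predicted P(y = 1))."""
    p = min(max(p, eps), 1 - eps)             # clip to avoid log(0)
    p_t = p if y == 1 else 1 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1 - alpha  # class weighting term
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

def bce(y, p, eps=1e-7):
    """Plain binary cross-entropy for one example, for comparison."""
    p = min(max(p, eps), 1 - eps)
    return -math.log(p if y == 1 else 1 - p)

# An easy, correctly classified negative versus a hard, misclassified positive.
for y, p in [(0, 0.05), (1, 0.1)]:
    print(y, p, round(bce(y, p), 4), round(focal_loss(y, p), 6))
```

With gamma = 2, the easy negative's loss shrinks by a factor of several hundred relative to cross-entropy, while the hard positive keeps a much larger share of its loss. Note that with alpha=0.2, as passed to focal_loss_custom above, positives get weight 0.2 and negatives 0.8; the shift toward hard cases comes from the (1 - p_t)^γ term.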
import numpy as np
from matplotlib import pyplot

# epochs, batch_size and verbose, plus the reshaped train/validation/test arrays,
# come from the data-processing step (not shown here)

# fit network
history = model.fit(
    [np.asarray(x_train_reshape).astype('float32'), np.asarray(x_train_over_static).astype('float32')],
    y_train_reshape,
    epochs=epochs,
    batch_size=batch_size,
    verbose=verbose,
    validation_data=([np.asarray(x_val_reshape).astype('float32'), np.asarray(x_val_static).astype('float32')], y_val_reshape))

# summarize history for accuracy
pyplot.plot(history.history['accuracy'])
pyplot.plot(history.history['val_accuracy'])
pyplot.title('model accuracy')
pyplot.ylabel('accuracy')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper left')
pyplot.show()

# summarize history for loss
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper left')
pyplot.show()

# evaluate model; values come back in the order given to model.compile
loss, accuracy, f1_score, precision, recall = model.evaluate(
    [np.asarray(x_test_reshape).astype('float32'), np.asarray(x_test_static).astype('float32')],
    y_test_reshape, batch_size=batch_size, verbose=0)

# print output
print("Accuracy:{} , F1_Score:{}, Precision:{}, Recall:{}".format(accuracy, f1_score, precision, recall))
Figure 3 shows the loss and accuracy curves for the training and validation sets.
On the test set, the model trained with binary cross-entropy loss achieves Accuracy: 0.95, F1_Score: 0.18, Precision: 0.12, Recall: 0.41.
With focal loss, the test results are Accuracy: 0.99, F1_Score: 0.35, Precision: 0.37, Recall: 0.35. Focal loss roughly doubles the F1-score, mainly by raising precision from 0.12 to 0.37, at the cost of a small drop in recall (0.41 to 0.35).
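The high accuracy values should be read against the class imbalance. As a purely hypothetical illustration (the article does not report the test set's positive rate), if 1% of windows were cardiac-arrest cases, a model that always predicted "no arrest" would reach 99% accuracy while detecting no cases at all. The always_negative_baseline helper below is illustrative only.

```python
def always_negative_baseline(n_samples, positive_rate):
    """Accuracy and recall of a classifier that always predicts the negative class."""
    n_pos = round(n_samples * positive_rate)
    n_neg = n_samples - n_pos
    accuracy = n_neg / n_samples  # every negative is correct, every positive is missed
    recall = 0.0                  # no positive case is ever detected
    return accuracy, recall

# Hypothetical test set: 10,000 windows, 1% positive.
print(always_negative_baseline(10_000, 0.01))
```

This is why precision, recall, and F1-score are the more informative numbers for comparing the two losses.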
Conclusion
RNNs have been shown to work well on time-series data. Real-world data, however, often includes supplementary static features that cannot be fed directly into an RNN because they are not sequential. The method described here adds static features to an RNN so that they influence the learning process. A previous approach to the problem was to train a separate model for each modality and combine them at the prediction level. Combining both feature types in a single model architecture instead allows the model to learn from the static and temporal features simultaneously.
We conclude that adding static features improves the RNN's performance compared with using the sequential or static features alone.