Distributed Deep Learning on Apache Spark with Keras

November 15, 2017 by Horia Margarit, updated April 12th, 2024

Deep learning has been shown to produce highly effective machine learning models in a diverse range of fields. Some of the most interesting are pharmaceutical drug discovery [1], detection of illegal fishing cargo [2], mapping dark matter [3], tracking deforestation in the Amazon [4], taxi destination prediction [5], predicting lift and grasp movements from EEG recordings [6], and medical diagnosis of cancer [7, 8]. Distributed deep learning makes it possible to train on internet-scale datasets, as companies such as Facebook, Google, Microsoft, and other large enterprises do. This blog post demonstrates how an organization of any size can leverage distributed deep learning on Spark thanks to the Qubole Data Service (QDS).

This demonstration uses the Keras [9] framework to describe the structure of a deep neural network, and then leverages the Dist-Keras [10] framework to achieve data-parallel model training on Apache Spark. Keras was chosen largely because it was the dominant deep learning library at the time of writing [12, 13, 14].

The Goal: Distributed Deep Learning Integrated With Spark ML Pipelines

By the end of this blog post, an enterprising data scientist should be able to extend this demonstration to their own application-specific modeling needs. Within a Qubole notebook [19], you should be able to cross-validate your deep neural networks using the Spark ML Pipeline interface [18], with an application-specific parameter grid similar to the following:

df_train = process_csv("/fully/qualified/path/to/training/data", num_workers=50)
df_test  = process_csv("/fully/qualified/path/to/test/data", num_workers=50)

# input layer size = length of the assembled 'features' vector
input_dim = len(df_train.first()['features'])

param_grid = tuning.ParamGridBuilder() \
                   .baseOn(['regularizer', regularizers.l1_l2]) \
                   .addGrid('activations', [['tanh', 'relu']]) \
                   .addGrid('initializers', [['glorot_normal',
                                              'glorot_uniform']]) \
                   .addGrid('layer_dims', [[input_dim, 2000, 300, 1]]) \
                   .addGrid('metrics', [['mae']]) \
                   .baseOn(['learning_rate', 1e-2]) \
                   .baseOn(['reg_strength', 1e-2]) \
                   .baseOn(['reg_decay', 0.25]) \
                   .baseOn(['lr_decay', 0.90]) \
                   .addGrid('dropout_rate', [0.20, 0.35, 0.50, 0.65, 0.80]) \
                   .addGrid('loss', ['mse', 'msle']) \
                   .build()

estimator = DistKeras(trainers.ADAG,
                      {'batch_size': 256,
                       'communication_window': 3,
                       'num_epoch': 10,
                       'num_workers': 50})

evaluator = evaluation.RegressionEvaluator(metricName='r2')

cv_estimator = tuning.CrossValidator(estimator=estimator,
                                     estimatorParamMaps=param_grid,
                                     evaluator=evaluator)
cv_model = cv_estimator.fit(df_train)

df_pred_train = cv_model.transform(df_train)
df_pred_test  = cv_model.transform(df_test)
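Every list passed to addGrid multiplies the number of models, so it is worth counting training runs before launching cross validation on a cluster. The sketch below reproduces, in plain Python, the Cartesian-product expansion that ParamGridBuilder.build() performs; `expand`, `varied`, and `fixed` are hypothetical names, and this is not Spark's implementation. It assumes CrossValidator's default of three folds, after which Spark refits the best parameter map once on the full training set.

```python
from itertools import product

# Plain-Python stand-in for the grid above (hypothetical names). The
# regularizer function is left out; 'input_dim' is a placeholder string.
varied = {
    'activations': [['tanh', 'relu']],
    'initializers': [['glorot_normal', 'glorot_uniform']],
    'layer_dims': [['input_dim', 2000, 300, 1]],
    'metrics': [['mae']],
    'dropout_rate': [0.20, 0.35, 0.50, 0.65, 0.80],
    'loss': ['mse', 'msle'],
}
fixed = {'learning_rate': 1e-2, 'reg_strength': 1e-2,
         'reg_decay': 0.25, 'lr_decay': 0.90}


def expand(varied, fixed):
    """Return one parameter map per combination of the varied values."""
    keys = list(varied)
    param_maps = []
    for values in product(*(varied[key] for key in keys)):
        param_map = dict(fixed)               # baseOn(...) values appear in every map
        param_map.update(zip(keys, values))   # one choice from each addGrid(...) list
        param_maps.append(param_map)
    return param_maps


param_maps = expand(varied, fixed)
num_folds = 3                                  # CrossValidator's default numFolds
total_fits = len(param_maps) * num_folds + 1   # + one refit of the best map on all data
print(len(param_maps), total_fits)             # 10 31
```

With the grid above that is 31 distributed training jobs, so trimming the dropout or loss lists is the quickest way to reduce cost.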

For the cross validation above to work, you first need to configure your Qubole cluster [20] and then set up your Qubole notebook [19]. The rest of this blog post explains how.

Configuring Your QDS Cluster

Add the following lines to your node_bootstrap script:

# installs dist-keras, which pulls in the latest version of Keras as a dependency
pip install dist-keras
# for GPU clusters only: swap out the default tensorflow dependency
# for the GPU build of tensorflow (-y avoids an interactive prompt)
pip uninstall -y tensorflow
pip install tensorflow-gpu

Then restart your cluster. Note that the default back end for Keras is TensorFlow, but Keras also supports CNTK, MXNet, and Theano [15, 16]. To use one of these other back ends, pip install it in the node_bootstrap script and then tell Keras to switch to it [17].
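Multi-backend Keras 2.x reads its back end from the KERAS_BACKEND environment variable or, failing that, from the "backend" field of ~/.keras/keras.json. Because the Spark workers import Keras too, the setting presumably has to exist on every node, which is why it fits in node_bootstrap. A minimal sketch, assuming Python is available during bootstrap; `write_keras_config` is a hypothetical helper and 'theano' is an arbitrary example back end.

```python
import json
import os


def write_keras_config(backend, config_path=None):
    """Write a Keras 2.x keras.json that selects `backend` (hypothetical helper)."""
    if config_path is None:
        config_path = os.path.join(os.path.expanduser('~'), '.keras', 'keras.json')
    config = {
        'backend': backend,                      # e.g. 'theano' or 'cntk'
        'floatx': 'float32',                     # the remaining keys are Keras defaults
        'epsilon': 1e-07,
        'image_data_format': 'channels_last',
    }
    config_dir = os.path.dirname(config_path)
    if config_dir and not os.path.isdir(config_dir):
        os.makedirs(config_dir)
    with open(config_path, 'w') as handle:
        json.dump(config, handle, indent=4)
    return config_path


# Option 1: persist the choice once per node, e.g. from node_bootstrap:
# write_keras_config('theano')

# Option 2: override per process; this must run before Keras is first imported.
os.environ['KERAS_BACKEND'] = 'theano'
```

The environment variable takes precedence over keras.json, so it is handy for a one-off experiment, while the config file makes the choice stick across sessions.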

Setting Up Your QDS Notebook

First, import the necessary libraries:

from keras import layers, models, optimizers, regularizers, utils
from pyspark.ml import evaluation, feature, tuning
from distkeras import predictors, trainers
from pyspark.sql import functions, types
from pyspark import ml
import numpy as np 
import matplotlib 
import StringIO

Next, define the following wrappers, which integrate Dist-Keras tightly with Spark ML Pipelines [18]. The wrappers are taken from an open source gist [11].

class DistKeras(ml.Estimator):

    def __init__(self, *args, **kwargs):
        self.__trainer_klass = args[0]
        self.__trainer_params = args[1]
        super(DistKeras, self).__init__()

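    @classmethod
    def __build_keras_model(klass, *args, **kwargs):
        loss = kwargs['loss']
        metrics = kwargs['metrics']
        layer_dims = kwargs['layer_dims']
        hidden_activation, output_activation = kwargs['activations']
        hidden_init, output_init = kwargs['initializers']
        dropout_rate = kwargs['dropout_rate']
        alpha = kwargs['reg_strength']
        reg_decay = kwargs['reg_decay']
        reg = kwargs['regularizer']
        keras_model = models.Sequential()
        # hidden layers: Dense + Dropout, regularization weakening with depth
        for idx in range(1, len(layer_dims)-1, 1):
            # only the first layer needs to know the input dimension
            input_kwargs = {'input_dim': layer_dims[0]} if idx == 1 else {}
            # reg is a regularizer factory; regularizers.l1_l2(alpha) sets the
            # L1 strength and leaves L2 at the Keras default
            keras_model.add(layers.Dense(layer_dims[idx],
                                         activation=hidden_activation,
                                         kernel_initializer=hidden_init,
                                         kernel_regularizer=reg(alpha),
                                         **input_kwargs))
            keras_model.add(layers.Dropout(dropout_rate))
            alpha *= reg_decay
        # output layer (the Dist-Keras trainer compiles the model itself)
        keras_model.add(layers.Dense(layer_dims[-1],
                                     activation=output_activation,
                                     kernel_initializer=output_init))
        return keras_model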

    def __build_trainer(self, *args, **kwargs):
        loss = kwargs['loss']
        learning_rate = kwargs['learning_rate']
        lr_decay = kwargs['lr_decay']
        keras_optimizer = optimizers.SGD(learning_rate, decay=lr_decay)
        keras_model = DistKeras.__build_keras_model(**kwargs)
        self._trainer = self.__trainer_klass(keras_model, keras_optimizer,
                                             loss, **self.__trainer_params)

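    def copy(self, extra=None):
        # Estimator.fit(df, params) calls self.copy(params)._fit(df); the grid
        # uses plain string keys, so keep them on the copy for _fit to read
        that = super(DistKeras, self).copy(extra)
        that._grid_params = dict(extra or {})
        return that

    def _fit(self, *args, **kwargs):
        data_frame = args[0]
        params = args[1] if len(args) > 1 else getattr(self, '_grid_params', {})
        # build a fresh Keras model and Dist-Keras trainer for this param map
        self.__build_trainer(**params)
        keras_model = self._trainer.train(data_frame)
        return DistKerasModel(keras_model)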

class DistKerasModel(ml.Model):

    def __init__(self, *args, **kwargs):
        self._keras_model = args[0]
        self._predictor = predictors.ModelPredictor(self._keras_model)
        super(DistKerasModel, self).__init__()

    def _transform(self, *args, **kwargs):
        data_frame = args[0]
        pred_col = self._predictor.output_column
        preds = self._predictor.predict(data_frame)
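        # predictions come back as one-element vectors; unwrap them to doubles
        return preds.withColumn(pred_col, cast_to_double(preds[pred_col]))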

cast_to_double = functions.udf(lambda row: float(row[0]), types.DoubleType())
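One detail of __build_trainer is easy to misread: lr_decay is passed to Keras's SGD as `decay`, which in Keras 2.x is time-based decay applied on every mini-batch update, not a per-epoch multiplier. The sketch below, with a hypothetical helper `keras_sgd_lr`, evaluates that formula for the grid's learning_rate=1e-2 and lr_decay=0.90 so you can see how quickly the step size shrinks.

```python
def keras_sgd_lr(initial_lr, decay, iterations):
    """Learning rate Keras 2.x SGD uses after `iterations` updates (time-based decay)."""
    return initial_lr / (1.0 + decay * iterations)


# learning_rate=1e-2 and lr_decay=0.90, as in the parameter grid
for step in (0, 1, 10, 100, 1000):
    print(step, keras_sgd_lr(1e-2, 0.90, step))
```

After 10 updates the rate is already 1e-3, and after 1,000 updates it is roughly 1/900 of its starting value. If a gentler, per-epoch style schedule was intended, a much smaller lr_decay is needed.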

Last but not least, define a few important helper functions. The first is show(), which displays any matplotlib figure inline in the notebook. It is adapted from the Qubole blog post on integrating the alternative plotting library Plotly into our notebooks [22].

# select a non-interactive backend; must be done before importing pyplot or pylab
matplotlib.use('Agg')
from matplotlib import pyplot as plt

def show(fig):
    image = StringIO.StringIO()
    fig.savefig(image, format='svg')
    # getvalue() returns everything written; %html renders the output as HTML
    print("%html <div style='width:1200px'>" + image.getvalue() + "</div>")
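The StringIO module used above exists only in Python 2. On a Python 3 interpreter, a sketch of the same helper can use io.StringIO instead; the `width_px` parameter and the returned HTML string are hypothetical additions for convenience and testing, not part of the original.

```python
import io


def show(fig, width_px=1200):
    """Render a matplotlib-style figure inline as SVG (Python 3 sketch)."""
    buffer = io.StringIO()              # text buffer: SVG output is XML text
    fig.savefig(buffer, format='svg')   # serialize the figure into the buffer
    html = "<div style='width:{}px'>{}</div>".format(width_px, buffer.getvalue())
    print("%html " + html)              # %html asks the notebook to render HTML
    return html
```

In a notebook paragraph this would be used as, for example, `fig, ax = plt.subplots(); ax.plot([1, 2, 3]); show(fig)`.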

Another important helper is process_csv(), which automates the repetitive task of loading a CSV file from cloud storage [21] into a data frame with renamed columns (such as 'label' for the label column) and excluded columns (such as unused ID columns).

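def process_csv(fully_qualified_path, columns_renamed=tuple(),
                excluded_columns=tuple(), num_workers=None):
    if num_workers is None:
        raise NotImplementedError

    excluded_columns = frozenset(excluded_columns)
    # sqlContext is predefined by the notebook's Spark interpreter
    data_frame = sqlContext.read.format('com.databricks.spark.csv') \
                           .options(header='true', inferSchema='true') \
                           .load(fully_qualified_path)
    # apply renames first, e.g. columns_renamed=[('price', 'label')]
    for (old_name, new_name) in columns_renamed:
        data_frame = data_frame.withColumnRenamed(old_name, new_name)
    # one partition per Dist-Keras worker
    data_frame = data_frame.repartition(num_workers)

    # features: every column except the label and the excluded ones, sorted
    # so that train and test assemble their vectors in the same order
    feature_columns = tuple(sorted(frozenset(data_frame.columns)
                                   - excluded_columns - frozenset(['label'])))
    transformer = feature.VectorAssembler(inputCols=list(feature_columns),
                                          outputCol='features')
    data_frame = transformer.transform(data_frame) \
                            .select('features', 'label')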

    return data_frame
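Because process_csv() renames columns before it computes the feature list, excluded_columns must use the post-rename names. The column bookkeeping can be checked without a cluster; the sketch below uses a hypothetical helper `plan_columns` on plain name lists, assumes the label column is named 'label' and is kept out of the features, and uses made-up column names in the example.

```python
def plan_columns(columns, columns_renamed=(), excluded_columns=(), label_col='label'):
    """Mirror the rename-then-exclude column logic on plain lists (hypothetical)."""
    renamed = list(columns)
    # applied in order, like chained withColumnRenamed calls; missing names are no-ops
    for old_name, new_name in columns_renamed:
        renamed = [new_name if name == old_name else name for name in renamed]
    dropped = frozenset(excluded_columns) | {label_col}
    # sorted for a deterministic feature order
    features = sorted(name for name in renamed if name not in dropped)
    return renamed, features


renamed, features = plan_columns(['id', 'sqft', 'rooms', 'price'],
                                 columns_renamed=[('price', 'label')],
                                 excluded_columns=['id'])
print(renamed)   # ['id', 'sqft', 'rooms', 'label']
print(features)  # ['rooms', 'sqft']
```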

Now you are ready to configure, train, and evaluate any distributed deep learning model described in Keras!

About The Author

Horia Margarit is a career data scientist with industry experience in machine learning across the digital media, consumer search, cloud infrastructure, life sciences, and consumer finance industry verticals [23]. His expertise is in modeling and optimization on internet-scale datasets, specifically leveraging distributed deep learning among other techniques. He earned dual Bachelor's degrees in Cognitive Science and Computer Science from UC Berkeley, and a Master's degree in Statistics from Stanford University.

References

  1. Deep Learning How I Did It: Merck 1st place interview
  2. The Nature Conservancy Fisheries Monitoring Competition, 1st Place Winner’s Interview: Team ‘Towards Robust-Optimal Learning of Learning’
  3. DeepZot on Dark Matter: How we won the Mapping Dark Matter challenge
  4. Planet: Understanding the Amazon from Space, 1st Place Winner’s Interview
  5. Taxi Trajectory Winners’ Interview: 1st place, Team ?
  6. Grasp-and-Lift EEG Detection Winners’ Interview: 3rd place, Team HEDJ
  7. Intel & MobileODT Cervical Cancer Screening Competition, 1st Place Winner’s Interview: Team ‘Towards Empirically Stable Training’
  8. 2017 Data Science Bowl, Predicting Lung Cancer: 2nd Place Solution Write-up, Daniel Hammack and Julian de Wit
  9. Keras Documentation
  10. Dist-Keras Documentation
  11. Dist-Keras Spark ML Pipelines
  12. Compare Keras and Pytorch’s popularity and activity
  13. Compare Keras and Caffe’s popularity and activity
  14. Compare Keras and Theano’s popularity and activity
  15. Keras shoot-out: TensorFlow vs MXNet
  16. The Search for the Fastest Keras Deep Learning Backend
  17. Keras: Switching from one backend to another
  18. Spark ML Pipelines
  19. Introduction To Qubole Notebooks
  20. Configuring A Spark Cluster
  21. The Cloud Advantage: Decoupling Storage And Compute
  22. Creating Customized Plots In Qubole Notebooks
  23. Author’s LinkedIn Profile
