XGBoost Dask API on YARN with hyperparameter optimization
A practitioner’s guide to implementing an XGBoost model with hyperparameter optimization, using the Dask API for distributed computing. This article limits its scope to a working script running on a YARN cluster. For a complete treatment, see the official guide to XGBoost's Dask API. We focus on building the model and finding the best hyperparameters for an XGBoost classification model. The same approach extends to an XGBoost regression model.
The Basics
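We will use the following environment to make this work: dask, dask-yarn, dask-ml, xgboost, optuna (2.3.0 or earlier, see Known Issues), dask-optuna, joblib, scikit-learn and numpy, with conda-pack to package the environment.
The article assumes familiarity with hyperparameter optimization and eXtreme Gradient Boosting (XGBoost).
We start by initializing a Dask client to distribute the workload. The YARN workers need the same Python environment as the client, so we bundle the conda environment into a compressed archive that YARN ships to each worker. I have found conda-pack quite useful for compressing the environment.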
```python
# -*- coding: utf-8 -*-
from dask.distributed import Client
from dask_yarn import YarnCluster

# Create a cluster whose workers run inside the packed conda environment
cluster = YarnCluster(
    environment='<path to archived python environment>',
    worker_vcores=10,
    worker_memory='20GiB',
)
# Scale out to ten such workers
cluster.scale(10)
# Connect to the cluster
client = Client(cluster)
```
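The archive passed to environment= has to exist before the cluster starts. Below is a minimal sketch of producing it with conda-pack's Python API. The helper name pack_environment and the environment name used with it are hypothetical, and the sketch assumes conda-pack is installed.

```python
import os


def pack_environment(env_name: str, output_dir: str = ".") -> str:
    """Hypothetical helper: bundle a named conda environment into a .tar.gz.

    The returned path is what YarnCluster(environment=...) expects.
    Assumes the conda-pack package is installed.
    """
    # Import lazily so this module can be loaded without conda-pack present
    import conda_pack

    output = os.path.join(output_dir, f"{env_name}.tar.gz")
    # Write the compressed archive of the environment to `output`
    conda_pack.pack(name=env_name, output=output)
    return output
```

For example, pack_environment("xgb-dask") would write xgb-dask.tar.gz to the current directory, and that path can replace the placeholder in the YarnCluster call.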
Next, we create a Dask dataframe and split the dataset into training and test sets. dask_ml's train_test_split does not support stratification out of the box, so I have modified the dask_ml module to stratify the split datasets.
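```python
# -*- coding: utf-8 -*-
import dask.dataframe as dd
from dask_ml.model_selection import train_test_split

# Load the spambase dataset; the 'class' column holds the 0/1 label
df = dd.read_csv('dataset_44_spambase.csv')
(X, y) = (df.drop(['class'], axis=1), df['class'])

# NOTE: `classes` is accepted only by the modified, stratifying dask_ml
# described above; stock dask_ml.train_test_split has no such argument.
(X_train, X_test, y_train, y_test) = train_test_split(
    X, y, classes=[0, 1],
)
```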
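The modified dask_ml is not shown here, so the sketch below illustrates only the idea behind a stratified split, on an in-memory list of labels rather than a Dask dataframe: shuffle the row indices of each class separately and send the same fraction of every class to the test set, so both sets keep the original class proportions. The function name is hypothetical, and this is not the author's modification.

```python
import random
from collections import defaultdict


def stratified_split_indices(labels, test_size=0.2, seed=0):
    """Hypothetical helper: return (train_idx, test_idx) with per-class proportions kept."""
    rng = random.Random(seed)

    # Group row indices by their class label
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)

    train_idx, test_idx = [], []
    for idx in by_class.values():
        # Shuffle within the class, then take the same fraction for the test set
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)
```

With 80 rows of class 0, 20 rows of class 1 and test_size=0.2, the test set receives 16 and 4 rows respectively. On a Dask dataframe, one way to approximate this without patching dask_ml is to filter each class, split it with DataFrame.random_split([0.8, 0.2]) and recombine the pieces with dd.concat; random_split assigns rows randomly, so the per-class fractions are only approximate.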
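We will now use Optuna to perform hyperparameter optimization and identify the best parameters for an XGBoost model. Each trial trains and evaluates the model through XGBoost's Dask interface, so the work runs on the YARN workers, and Dask-Optuna's DaskStorage keeps the study on the Dask scheduler.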
```python
# -*- coding: utf-8 -*-
from pprint import pprint

import dask_optuna
import joblib
import numpy as np
import optuna
import sklearn.metrics
import xgboost as xgb

# Build the distributed DMatrix objects once, outside the objective
dtrain = xgb.dask.DaskDMatrix(client, data=X_train, label=y_train)
dtest = xgb.dask.DaskDMatrix(client, data=X_test, label=y_test)
# Materialize the true test labels once instead of in every trial
y_true = y_test.compute()


def objective(trial):
    param = {
        'verbosity': 0,  # replaces the deprecated 'silent' parameter
        'objective': 'binary:logistic',
        'tree_method': 'hist',
        'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
        'lambda': trial.suggest_float('lambda', 1e-8, 1.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 1.0, log=True),
    }
    # Tree parameters only apply to the tree-based boosters
    if param['booster'] in ('gbtree', 'dart'):
        param['max_depth'] = trial.suggest_int('max_depth', 1, 9)
        param['eta'] = trial.suggest_float('eta', 1e-8, 1.0, log=True)
        param['gamma'] = trial.suggest_float('gamma', 1e-8, 1.0, log=True)
        param['grow_policy'] = trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide'])
    # Dropout parameters specific to the DART booster
    if param['booster'] == 'dart':
        param['sample_type'] = trial.suggest_categorical('sample_type', ['uniform', 'weighted'])
        param['normalize_type'] = trial.suggest_categorical('normalize_type', ['tree', 'forest'])
        param['rate_drop'] = trial.suggest_float('rate_drop', 1e-8, 1.0, log=True)
        param['skip_drop'] = trial.suggest_float('skip_drop', 1e-8, 1.0, log=True)

    # xgb.dask.train returns a dict holding the trained 'booster' and its 'history'
    bst = xgb.dask.train(client, param, dtrain)
    # predict returns a lazy Dask array of probabilities; bring it to the client
    preds = xgb.dask.predict(client, bst['booster'], dtest).compute()
    pred_labels = np.rint(preds)  # round probabilities to 0/1 labels
    return sklearn.metrics.accuracy_score(y_true, pred_labels)


# Dask-Optuna storage keeps the study on the Dask scheduler
storage = dask_optuna.DaskStorage()
study = optuna.create_study(storage=storage, direction='maximize')
with joblib.parallel_backend('dask'):
    study.optimize(objective, n_trials=100)

print('Best params:')
pprint(study.best_params)
```
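The objective turns predicted probabilities into labels with np.rint and scores them with accuracy_score. The sketch below spells out the same two steps in plain Python with an explicit 0.5 threshold; the helper names to_labels and accuracy are hypothetical. One detail worth knowing: np.rint, like Python's round, rounds halves to the nearest even number, so a probability of exactly 0.5 becomes class 0, whereas a p >= 0.5 threshold maps it to class 1.

```python
def to_labels(probs, threshold=0.5):
    """Map predicted probabilities of the positive class to 0/1 labels."""
    return [1 if p >= threshold else 0 for p in probs]


def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


probs = [0.1, 0.5, 0.7, 0.4]
y_true = [0, 1, 1, 0]
print(accuracy(y_true, to_labels(probs)))           # threshold: 0.5 -> class 1
print(accuracy(y_true, [round(p) for p in probs]))  # round: 0.5 -> class 0
```

Here the explicit threshold scores 1.0 and half-to-even rounding scores 0.75, because the single 0.5 prediction flips. In practice exact 0.5 outputs are rare, but an explicit threshold makes the decision rule visible and easy to tune.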
Note: This article highlights a single configuration for distributing the computation with Dask. We have tested it on a large dataset (3 million rows x 200 columns) with all-numerical features.
Further Improvements
Optuna is evolving and will have support for Dask in the future; until then, we are limited to Dask-Optuna. The XGBoost model could be improved further with pruning (stopping unpromising trials early) and by optimizing additional parameters. Both are intentionally left out given the scope of this article.
Known Issues
- The example reads the input file on each iteration and then distributes the data for computation. This can be optimized further, for example by persisting the data in worker memory.
- Dask-Optuna does not work well with optuna versions later than 2.3.0 (see the related issue).
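One possible way to address the first known issue, sketched under the assumption that the splits fit in the cluster's memory: persist the training and test collections once, right after the split and before building the DaskDMatrix objects, so later computations reuse the in-memory partitions instead of re-reading the input file. persist_splits is a hypothetical wrapper around dask.distributed's Client.persist.

```python
def persist_splits(client, *collections):
    """Hypothetical helper: materialize Dask collections once in cluster memory.

    Later computations on the returned collections reuse the persisted
    partitions instead of recomputing them from the input file.
    """
    # Client.persist accepts a list of collections and returns a list of
    # persisted collections in the same order.
    return tuple(client.persist(list(collections)))
```

Usage, directly after the train_test_split call: X_train, X_test, y_train, y_test = persist_splits(client, X_train, X_test, y_train, y_test).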