Optimizing XGBoost Hyperparameters with Optuna

Python

This code snippet shows how to use Optuna to tune the hyperparameters of an XGBoost binary classifier. For each trial, the objective function samples a set of hyperparameters, trains an XGBoost model with early stopping, and returns the model's ROC AUC on the validation set. The Optuna study is created with direction='maximize', so over 100 trials it searches for the hyperparameters that give the highest AUC. The best score and hyperparameters are then printed.
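Conceptually, `study.optimize` repeats a sample, evaluate, compare loop: draw a set of hyperparameters, score them with the objective, and keep the best result in the requested direction. The sketch below shows that loop using plain random search, which is a simplification: Optuna's default sampler is TPE (`TPESampler`), which uses the results of earlier trials to choose new values. The `random_search` function, its `(low, high, log)` search-space format and the `toy_objective` stand-in for model training are hypothetical and exist only for this illustration.

```python
import math
import random


def random_search(objective, space, n_trials=100, direction="maximize", seed=0):
    """Sample parameters at random, score each set, and keep the best one.

    `space` maps each parameter name to a (low, high, log) tuple; this format
    is hypothetical and only used for this sketch, not Optuna's API.
    """
    rng = random.Random(seed)
    sign = 1 if direction == "maximize" else -1
    best_value, best_params = None, None
    for _ in range(n_trials):
        params = {}
        for name, (low, high, log) in space.items():
            if log:
                # Uniform in log space: each decade (1e-4..1e-3, 1e-3..1e-2, ...)
                # is equally likely, as with suggest_float(..., log=True).
                params[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
            else:
                params[name] = rng.uniform(low, high)
        value = objective(params)
        # Keep this trial if it beats the best so far in the requested direction.
        if best_value is None or sign * value > sign * best_value:
            best_value, best_params = value, params
    return best_value, best_params


def toy_objective(params):
    # Hypothetical stand-in for training a model: the score peaks at 1.0
    # when subsample == 0.8 and learning_rate == 1e-2.
    return (
        1.0
        - (params["subsample"] - 0.8) ** 2
        - (math.log10(params["learning_rate"]) + 2) ** 2
    )


best_value, best_params = random_search(
    toy_objective,
    {"subsample": (0.1, 1.0, False), "learning_rate": (1e-4, 1e-1, True)},
    n_trials=100,
)
print(f"Best score: {best_value:.5f}")
print(best_params)
```

Sampling the learning rate on a log scale matters because its plausible values span several orders of magnitude; a uniform draw over 1e-4 to 1e-1 would land above 1e-2 about 90% of the time.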

```python
import optuna
import xgboost as xgb
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Example data (assumption: any binary classification data works here).
# Replace with your own training and validation sets.
X, y = load_breast_cancer(return_X_y=True)
train_x, valid_x, train_y, valid_y = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Build the DMatrix objects once instead of on every trial
dtrain = xgb.DMatrix(train_x, label=train_y)
dvalid = xgb.DMatrix(valid_x, label=valid_y)


# Step 1: Define the objective function for Optuna to maximize
def objective(trial):
    # Fixed settings plus the hyperparameters to tune
    params = {
        'objective': 'binary:logistic',
        'eval_metric': 'auc',
        'tree_method': 'hist',
        'max_depth': trial.suggest_int('max_depth', 2, 10),
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
        'subsample': trial.suggest_float('subsample', 0.1, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1.0),
        'min_child_weight': trial.suggest_float('min_child_weight', 1e-5, 1.0, log=True),
    }

    # Train with early stopping: stop once validation AUC has not improved for 20 rounds
    model = xgb.train(
        params,
        dtrain,
        num_boost_round=1000,
        evals=[(dvalid, 'validation')],
        early_stopping_rounds=20,
        verbose_eval=False,
    )

    # xgb.train returns the model from the last round, so predict with the
    # trees up to and including the best iteration only
    y_pred = model.predict(dvalid, iteration_range=(0, model.best_iteration + 1))
    return roc_auc_score(valid_y, y_pred)


# Step 2: Create a study that maximizes validation AUC and run 100 trials
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

# Step 3: Print the best hyperparameters and score found by Optuna
print(f'Best score: {study.best_value:.5f}')
print('Best parameters:')
for key, value in study.best_params.items():
    print(f'    {key}: {value}')
```
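With early_stopping_rounds=20, XGBoost stops training once the validation metric has gone 20 rounds without improving, and records the best round in `best_iteration`. Because `xgb.train` returns the model from the last round rather than the best one, XGBoost's documentation recommends predicting with `iteration_range=(0, bst.best_iteration + 1)`. The function below is a simplified, hypothetical model of that stopping rule applied to a precomputed list of per-round scores; it is not XGBoost's implementation, and the AUC values are made up for illustration.

```python
def early_stopping(scores, patience):
    """Return (best_iteration, best_score, rounds_trained) for per-round
    validation scores where higher is better.

    Training stops once `patience` rounds pass without a strict improvement.
    """
    best_iteration, best_score = 0, scores[0]
    for i, score in enumerate(scores):
        if score > best_score:
            best_iteration, best_score = i, score
        elif i - best_iteration >= patience:
            # No improvement for `patience` rounds: stop after this round.
            return best_iteration, best_score, i + 1
    return best_iteration, best_score, len(scores)


# Hypothetical per-round validation AUCs
aucs = [0.70, 0.75, 0.78, 0.77, 0.78, 0.76, 0.79]
best_iteration, best_score, rounds = early_stopping(aucs, patience=3)
print(best_iteration, best_score, rounds)  # 2 0.78 6
```

Training stops after round index 5, so the later 0.79 is never seen, and the rounds after index 2 add trees that did not improve the validation score. This is why predictions should be limited to the first best_iteration + 1 rounds.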