LightGBM Custom Loss Function

Python

LightGBM lets you define your own custom loss (objective) function. The function must accept two arguments: the predictions currently made by your LightGBM model, and the training Dataset.

Inside the function, you can retrieve the true target values with the get_label() method of the training Dataset passed to the model. Subtracting these from the predictions gives the error for each example. The gradient and Hessian of the loss are then computed from that error.
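To make the two-argument contract concrete without needing LightGBM installed, the sketch below uses a hypothetical StandInDataset class. It only imitates the get_label() method of lgb.Dataset and is not part of LightGBM. The function receives the predictions first and the data second, and it returns the per-example error.

```python
import numpy as np


class StandInDataset:
    """Hypothetical stand-in for lgb.Dataset that only mimics get_label()."""

    def __init__(self, label):
        self._label = np.asarray(label, dtype=np.float64)

    def get_label(self):
        return self._label


def prediction_error(y_pred, data):
    # Same argument order as a LightGBM custom objective: predictions, then data
    y_true = data.get_label()
    return y_pred - y_true


data = StandInDataset([3.0, -1.0, 2.5])
preds = np.array([2.5, 0.0, 2.5])
error = prediction_error(preds, data)
print(error)
```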

A custom loss function must return two values, grad and hess. grad is the first derivative of the loss with respect to the prediction, and hess is the second derivative.
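You can check that grad and hess really are the first and second derivatives by comparing them with finite-difference estimates. The sketch below does this per example for the squared-error loss used in this snippet. The helper names and the step size h are illustrative choices, not LightGBM API.

```python
import numpy as np


def squared_error(y_pred, y_true):
    # Per-example loss: one value per training example
    return (y_pred - y_true) ** 2


def grad_hess(y_pred, y_true):
    # Analytic derivatives with respect to the prediction
    error = y_pred - y_true
    return 2 * error, np.full_like(error, 2.0)


y_true = np.array([3.0, -1.0, 2.5, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])
h = 1e-3  # finite-difference step

# Central differences for the first and second derivatives
loss_plus = squared_error(y_pred + h, y_true)
loss_mid = squared_error(y_pred, y_true)
loss_minus = squared_error(y_pred - h, y_true)
num_grad = (loss_plus - loss_minus) / (2 * h)
num_hess = (loss_plus - 2 * loss_mid + loss_minus) / h ** 2

grad, hess = grad_hess(y_pred, y_true)
print(np.max(np.abs(grad - num_grad)), np.max(np.abs(hess - num_hess)))
```

Because the loss is quadratic, the central differences match the analytic values up to floating-point rounding.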

In the example below, the custom loss function implements plain squared error, the per-example term of mean squared error. The loss is the square of the error, so its first derivative (grad) is the error multiplied by 2, and its second derivative (hess) is the constant 2.
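Written out for a single example with prediction \hat{y} and target y, the derivation behind the code is:

```latex
L(\hat{y}) = (\hat{y} - y)^2, \qquad
\frac{\partial L}{\partial \hat{y}} = 2(\hat{y} - y), \qquad
\frac{\partial^2 L}{\partial \hat{y}^2} = 2
```

For comparison, LightGBM's built-in L2 regression objective uses the halved loss \tfrac{1}{2}(\hat{y} - y)^2. That choice gives grad equal to the error and hess equal to 1.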

Note: for a single-output objective such as this one, grad and hess must both be 1-dimensional arrays with one entry per example in the training dataset. That's why the example below doesn't simply declare hess = 2. Instead, it multiplies the error array by 0 and adds 2, which produces a 1-dimensional array of the right length with every element equal to 2. np.full_like(error, 2.0) builds the same array more explicitly.
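The snippet below checks the shape requirement from the note. The expression 0 * error + 2 yields an array with one entry per example, whereas a bare 2 is only a scalar. np.full_like is shown as an equivalent, more explicit way to build the same constant array. The example arrays are made up.

```python
import numpy as np

y_true = np.array([3.0, -1.0, 2.5, 7.0, 0.5])
y_pred = np.array([2.5, 0.0, 2.0, 8.0, 0.5])
error = y_pred - y_true

hess_trick = 0 * error + 2                # array trick used in the snippet
hess_explicit = np.full_like(error, 2.0)  # same values, stated directly
hess_scalar = np.asarray(2)               # what a bare "hess = 2" would give

print(hess_trick.shape, hess_explicit.shape, hess_scalar.shape)
```

There is one practical difference between the two array versions. If error ever contains inf or NaN, 0 * error + 2 produces NaN in those positions, whereas np.full_like always fills in 2.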

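    import numpy as np
    import lightgbm as lgb

    def custom_loss(y_pred, data):
        # True target values stored in the training Dataset
        y_true = data.get_label()
        error = y_pred - y_true

        # 1st derivative (gradient) of the squared error w.r.t. the prediction
        grad = 2 * error

        # 2nd derivative (Hessian): constant 2, as an array shaped like error
        hess = 0 * error + 2

        return grad, hess

    # Placeholder data so the snippet runs; replace with your own X_train, y_train
    rng = np.random.default_rng(0)
    X_train = rng.normal(size=(500, 5))
    y_train = 3 * X_train[:, 0] + rng.normal(scale=0.1, size=500)

    # LightGBM >= 4.0 takes a callable objective via params["objective"];
    # older versions used lgb.train(..., fobj=custom_loss) instead
    params = {"objective": custom_loss, "learning_rate": 0.1}

    training_data = lgb.Dataset(X_train, label=y_train)

    model = lgb.train(params=params, train_set=training_data)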
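lgb.train uses the returned arrays to grow each tree. The sketch below is a deliberately simplified model of that process: it uses a single leaf, no tree splits, and no regularization such as lambda_l2, so it is not LightGBM's actual code. In this simplified setting, the Newton-style leaf value is -sum(grad) / sum(hess), scaled by the learning rate. Applied to the squared-error derivatives above, repeated steps move a constant prediction toward the mean of the targets. The function names are hypothetical.

```python
import numpy as np


def squared_error_grad_hess(y_pred, y_true):
    # Same derivatives as custom_loss, but taking the labels directly
    error = y_pred - y_true
    return 2 * error, 0 * error + 2


def single_leaf_boosting(y_true, learning_rate=0.1, num_rounds=100):
    # Start from a raw score of 0 for every example
    y_pred = np.zeros_like(y_true, dtype=np.float64)
    for _ in range(num_rounds):
        grad, hess = squared_error_grad_hess(y_pred, y_true)
        # Newton step for one leaf covering all examples (no regularization)
        leaf_value = -grad.sum() / hess.sum()
        y_pred = y_pred + learning_rate * leaf_value
    return y_pred


y_true = np.array([3.0, -1.0, 2.5, 7.0])
final = single_leaf_boosting(y_true)
print(final[0], y_true.mean())
```

With learning_rate=1.0, a single step lands exactly on the mean. This is why the Hessian matters: it sets the step size that the gradient alone does not provide.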