LightGBM Multiclass Classification


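This code snippet consists of three main steps. First, we initialise the LightGBM model and fit it to the training data. Here we set the model's parameters: the number of boosting rounds (n_estimators), the maximum depth of each tree (max_depth), the learning rate, which scales down the contribution of each new tree, and the L2 regularisation strength on leaf values (reg_lambda). Capping tree depth and penalising large leaf values limit how closely each tree can fit the training data, which helps reduce overfitting. Once the model is trained, we save its underlying booster to a file so it can be reloaded later without retraining.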

Second, we plot the feature importances as a horizontal bar chart with one bar per feature. By default, LightGBM's feature_importances_ counts how many times each feature is used to split a tree, so the longest bars mark the features the model relies on most, which makes them easy to spot.

Finally, we make predictions on the test data and evaluate the model's accuracy. The predict method returns a class label for each test sample, and the accuracy_score function computes the fraction of those labels that match the true labels. Because the test data was not used during training, this score estimates how well the model generalises to unseen data.
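To make the definition of accuracy concrete, here is a small sketch with made-up labels for a three-class problem. It computes accuracy by hand, checks it against accuracy_score, and adds a confusion matrix (an extra that is not in the original snippet) to show which classes are being confused.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Made-up labels for illustration only.
y_test = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1])

# Accuracy is the fraction of predictions that match the true label.
manual_accuracy = np.mean(y_pred == y_test)
print(manual_accuracy)                  # 0.8
print(accuracy_score(y_test, y_pred))   # 0.8

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_test, y_pred)
print(cm)
# Correct predictions sit on the diagonal, so accuracy = trace / total.
print(np.trace(cm) / cm.sum())          # 0.8
```

Here the single error is a class-1 sample predicted as class 2, which shows up as the off-diagonal entry in row 1, column 2; the accuracy figure alone does not reveal that.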

```python
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

# Assumes X_train, X_test (pandas DataFrames) and y_train, y_test already exist.

# Step 1: Initialise and fit LightGBM multiclass model
model = lgb.LGBMClassifier(objective='multiclass',
                           n_estimators=1000,
                           max_depth=4,
                           learning_rate=0.1,
                           reg_lambda=1,
                           random_state=101,
                           verbose=-1)  # silence training log (fit() no longer takes verbose)
model.fit(X_train, y_train)

# Save the model (assumes a Kaggle notebook, where /kaggle/working is writable)
model.booster_.save_model('/kaggle/working/lgb_classification.model')

# Step 2: Plot feature importances (split counts by default)
features = X_train.columns
importance_values = model.feature_importances_

plt.barh(y=range(len(features)),
         width=importance_values,
         tick_label=features)
plt.xlabel('Importance (number of splits)')
plt.title('LightGBM feature importances')
plt.tight_layout()
plt.show()

# Step 3: Make predictions for test data & evaluate performance
y_pred = model.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))
```