Last active
February 9, 2024 16:24
-
-
Save jefferythewind/4fb31496e2445c22fcdfc36b3b0feb04 to your computer and use it in GitHub Desktop.
Multi-class softmax log-loss custom objective for the LightGBM classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import lightgbm as lgb | |
from sklearn.datasets import load_iris | |
from sklearn.metrics import confusion_matrix | |
import time | |
# Fetch the Iris data set via scikit-learn's loader
iris = load_iris()

# Pull out the feature matrix (X) and the target vector (y) in one step
X, y = iris.data, iris.target

# Report the dimensions of both arrays
for caption, array in (("Shape of features (X):", X), ("Shape of target (y):", y)):
    print(caption, array.shape)
# Define softmax function | |
def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    The maximum along the last axis is subtracted before exponentiating so
    that large scores do not overflow; this shift cancels out in the ratio
    and leaves the result unchanged.

    Args:
        x: Array-like of raw scores. The last axis is the one normalised,
            so a 2-D input is normalised row by row and a 1-D input is
            treated as a single score vector.

    Returns:
        numpy.ndarray with the same shape as ``x`` whose entries along the
        last axis are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # keepdims=True lets the reductions broadcast back against x directly,
    # which works for any rank (the transpose trick only worked for 2-D).
    shifted = x - np.max(x, axis=-1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=-1, keepdims=True)
# Custom softmax cross-entropy loss | |
def softmax_cross_entropy(preds, train_data):
    """Custom multi-class objective: gradient and Hessian of softmax log loss.

    Args:
        preds: Raw (pre-softmax) scores. Either a 2-D array with one row per
            sample and one column per class, or a flat 1-D array holding the
            same scores grouped by class first and by row second.
        train_data: Object exposing ``get_label()`` that returns one integer
            class index per sample (here an ``lgb.Dataset``).

    Returns:
        Tuple ``(grad, hess)`` laid out like ``preds``: ``grad`` is the
        predicted probability minus the one-hot label, ``hess`` is the
        diagonal approximation ``p * (1 - p)``.
    """
    labels = np.asarray(train_data.get_label()).astype(int)
    preds = np.asarray(preds)

    flat_input = preds.ndim == 1
    if flat_input:
        # NOTE(review): LightGBM releases before 4.0 are documented to pass
        # multi-class scores as a class-major flat vector
        # (score[j * num_data + i]); confirm against the installed version.
        preds = preds.reshape(len(labels), -1, order='F')

    # Take the class count from the score matrix rather than from the labels:
    # the labels seen here need not contain every class, and the one-hot
    # matrix must match the score matrix column for column.
    num_class = preds.shape[1]
    one_hot = np.eye(num_class)[labels]

    probs = softmax(preds)
    grad = probs - one_hot
    hess = probs * (1 - probs)

    if flat_input:
        # Hand the result back in the same class-major flat layout.
        return grad.ravel(order='F'), hess.ravel(order='F')
    return grad, hess
lgb_train = lgb.Dataset(X, y)

# Settings shared by both runs; only the objective differs between them.
base_params = {
    'max_depth': 5,
    'num_leaves': 16,
    'n_estimators': 50,
    'num_class': len(np.unique(y)),
    'verbosity': -1,
}

# Train once with the built-in multiclass objective and once with the custom
# one, printing a confusion matrix on the training data and the fit time.
for title, objective in (
    ('Default', 'multiclass'),
    ('Custom', softmax_cross_entropy),
):
    params = {**base_params, 'objective': objective}

    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be distorted by system clock adjustments the way time.time() can.
    tic = time.perf_counter()
    model = lgb.train(params, lgb_train)
    toc = time.perf_counter()

    print(title)
    # argmax over the class axis picks the predicted class for each sample.
    print(confusion_matrix(y, np.argmax(model.predict(X), axis=1)))
    print("Elapsed Time: ", (toc - tic))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment