XGBoost4j 调用python训练好的模型
XGBoost工程化中的一个问题
概述
使用python xgboost 训练lambda rank 排序模型,线上工程化时使用XGBoost4j 调用python 已经训练好的模型进行线上预测。
python xgb 训练以及保存模型
def train_xgb_model(x_train,
y_train,
group_train,
x_valid,
y_valid,
group_valid):
params = { 'objective' : 'rank:ndcg',
'learning_rate' : 0.1,
'gamma' : 1.0,
'nthread': 60,
'min_child_weight' : 0.1,
'max_depth' : 8,
'n_estimators' : 5,
'reg_alpha' : 0.01,
'reg_lambda' : 0.01
}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x=x_train,
y=y_train,
group=group_train,
eval_set=[(x_valid, y_valid)],
eval_group=[group_valid],
verbose=True)
#保存模型
model.save_model('xgb_model.bin')
return model
python 加载模型及预测
# 加载模型
model = xgb.Booster(model_file='xgb_model.bin')
# 进行预测
test_data = np.asarray(feature_arr)
test_data = xgb.DMatrix(test_data)
rank_score = model.predict(test_data)
XGBoost4j 调用模型及预测
//加载模型
Booster booster = XGBoost.loadModel("xgb_model.bin");
//使用特征数据构造Dmatrix,多行样本flatten 为一维数组,row, col 代表二维数组长宽。第四个参数代表missing value 的初始值,python 默认值为Nan,java默认值是0。这个参数不传的话可能会导致python 与java 预测结果不一致。
DMatrix dMatrix = new DMatrix(test_data, row, col, Float.NaN);
//进行预测
float[][] predicts = booster.predict(dMatrix);