XGBoost - 分类

XGBoost 最常见的用途之一是分类。它根据输入特征预测离散类标签。分类是使用 XGBClassifier 模块进行的,该模块是专门为处理分类任务而创建的。

XGBClassifier 语法

为了提高性能,我们可以调整 XGBoost 中 XGBClassifier 类的超参数。构建 XGBoost 分类器的基本语法如下所示 −

model = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=num_classes,      
    max_depth=max_depth,       
    learning_rate=learning_rate,
    subsample=subsample,        
    colsample_bytree=colsample, 
    n_estimators=num_estimators
)

以下是 XGBClassifier 语法中使用的超参数的描述 −

  • objective='multi:softprob - 它是多类分类的可选客观参数,并返回每个类的概率分数。对于二元分类,默认值为"binary:logistic"。

  • num_class=num_classes - 它是多类分类任务所必需的,并显示数据集中存在的类数。

  • max_depth=max_depth - 它是一个可选参数,显示每个决策树的最大深度。

  • learning_rate=learning_rate - 它是一个可选参数,其中步长收缩可避免过度拟合。

  • subsample=subsample - 它是一个可选参数,显示每棵树使用的样本分数。

  • colsample_bytree=colsample - 它也是一个可选参数,显示每个决策树使用的特征分数树。

  • n_estimators=num_estimators - 这是一个必需参数,用于查找提升迭代次数并处理模型的整体复杂性。

XGBoost 分类示例

鸢尾花数据集是机器学习中非常流行的数据集。它包含 150 个鸢尾花示例,每个示例有四个测量值,需要对三种鸢尾花进行分类。

让我们使用鸢尾花数据集来展示使用 XGBoost 库的分类:

    import xgboost as xgb
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score, classification_report
    
    # 加载 Iris 数据集
    data = load_iris()
    X, y = data.data, data.target
    
    # 将数据拆分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # 创建 XGBoost 分类器
    model = xgb.XGBClassifier()
    
    # 在训练数据上训练模型
    model.fit(X_train, y_train)
    
    # 在测试集上进行预测
    predictions = model.predict(X_test)
    
    # 计算准确率
    accuracy = accuracy_score(y_test, predictions)
    
    print("Model's Accuracy is:", accuracy)
    print("Model's Classification Report is:")
    print(classification_report(y_test, predictions, target_names=data.target_names))

输出

这将导致以下结果 −

Model's Accuracy is: 1.0

Model's Classification Report is:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

总结

XGBoost 是机器学习的强大工具,尤其适用于分类任务。它在许多情况下都表现良好,因为它速度快,并且具有有助于防止过度拟合的功能。例如 - 我们使用 XGBoost 将鸢尾花分类为不同类型,实现了 1.0 的完美准确率。它的灵活性和效率使 XGBoost 成为许多现实生活中分类问题的绝佳选择。