HuggingFace镜像/p3-wine-quality-keras-ensemble-mixed-top5
模型介绍文件和版本分析
下载使用量0

P3 葡萄酒质量预测 - Keras 集成混合 Top5

本仓库包含为葡萄酒质量预测开发的最终选定 Keras 集成模型。

该集成模型通过 11 个物理化学输入特征预测葡萄酒质量。每个成员模型首先生成连续的回归输出。最终的集成预测通过对五个选定成员模型的连续输出取平均值,然后进行基于阈值的整数质量转换来计算。

选定的集成模型

  • 选定集成模型:Ensemble_Mixed_Top5
  • 成员模型数量:5
  • 聚合方法:连续回归输出的平均值
  • 阈值调优:基于验证集的坐标搜索
  • 输出:连续预测值和整数葡萄酒质量标签

成员模型

  • Q_MLP_32_16_Huber
  • Q_MLP_16_8
  • Q_MLP_32
  • Q_MLP_64
  • AEPC_MLP_64

数据集摘要

  • 数据形状:(4898, 12)
  • 输入特征:11
  • 目标变量:quality
  • 理论质量范围:0 至 10
  • 实际观测质量范围:3 至 9

输入特征

  1. fixed acidity
  2. volatile acidity
  3. citric acid
  4. residual sugar
  5. chlorides
  6. free sulfur dioxide
  7. total sulfur dioxide
  8. density
  9. pH
  10. sulphates
  11. alcohol

预处理

集成成员模型采用了两种预处理策略。

分位数 1%~99% 截断

此预处理策略使用从训练集计算的第 1 百分位数和第 99 百分位数对选定特征值进行截断。然后对选定的偏斜特征应用 log1p 变换,并进行 StandardScaler 标准化。

AEPC:自适应保留超额截断

AEPC 是一种项目设计的预处理策略,用于处理上尾极端值。它不是完全截断超过上限的值,而是压缩超额部分,同时保留原始极端值信号的可控部分。

对于超过 AEPC 上限的值,调整后的值计算如下:

adjusted_value = cap + alpha * (original_value - cap)

这里,alpha 控制保留多少过量信号。alpha 的值由上尾特征决定,例如尾比率、目标信号和尾部连续性。这使得 AEPC 不像硬限幅那样激进,同时仍能减少极端值的影响。

对数转换特征

['chlorides', 'volatile acidity', 'free sulfur dioxide', 'citric acid', 'residual sugar']

阈值

所选阈值如下:

[3.5, 4.4, 5.65, 6.55, 7.35, 8.5]

性能

所选集成模型性能

划分RMSE RawMAE RawR2 Raw准确率KappaMAE_int宏F1加权F1
验证集0.6898500.5315970.3915910.5775510.3446540.4653060.2969270.555434
测试集0.6939020.5399580.3866280.5724490.3393060.4693880.2778220.552183

模型对比说明

与最终单一模型相比,所选集成模型提升了整数质量预测性能。最终单一模型的测试集准确率为0.560204,测试集Kappa系数为0.310787,测试集MAE_int为0.492857;而本集成模型的测试集准确率为0.572449,测试集Kappa系数为0.339306,测试集MAE_int为0.469388。

仓库文件

文件描述
models/训练好的Keras成员模型
preprocess/预处理缩放器和转换参数
ensemble_config.json集成模型配置和所选阈值
preprocess_config.json数据集、预处理和训练配置
metrics.json评估指标
member_model_results.csv五个成员模型的原始回归结果
confusion_matrix_selected.csv所选集成模型的混淆矩阵
classification_report_selected.txt所选集成模型的分类报告
inference.py推理示例脚本
requirements.txt所需的Python包

示例用法

from inference import predict_quality

sample = {
    "fixed acidity": 7.0,
    "volatile acidity": 0.27,
    "citric acid": 0.36,
    "residual sugar": 20.7,
    "chlorides": 0.045,
    "free sulfur dioxide": 45.0,
    "total sulfur dioxide": 170.0,
    "density": 1.001,
    "pH": 3.00,
    "sulphates": 0.45,
    "alcohol": 8.8
}

result = predict_quality(sample)
print(result)

局限性

数据集主要集中在中等质量分数附近,尤其是5分和6分。因此,对于罕见的极端质量分数,预测结果的可靠性可能较低。

集成模型通过对多个成员模型的预测结果取平均,可以提高预测的稳定性,但对于罕见的极端质量标签,仍可能难以准确预测。