HuggingFace镜像/structured-data-classification-grn-vsn
模型介绍文件和版本分析
下载使用量0

模型描述

本模型采用了 Bryan Lim 等人在论文《Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting》(https://arxiv.org/abs/1912.09363)中提出的两个重要架构组件——GRN 和 VSN,它们在结构化数据学习任务中非常实用。

  1. 门控残差网络(Gated Residual Networks, GRN):包含跳跃连接和门控层,可高效促进信息流。GRN 能够灵活地仅在需要处应用非线性处理。 GRN 利用门控线性单元(Gated Linear Units,简称 GLU)来抑制与特定任务无关的输入。

    GRN 的工作原理如下:

    • 首先对输入进行非线性 ELU 变换
    • 接着进行线性变换,随后应用 dropout
    • 然后应用 GLU,并将原始输入与 GLU 的输出相加,实现跳跃(残差)连接
    • 最后进行层归一化并生成输出
  2. 变量选择网络(Variable Selection Networks, VSN):有助于从输入中精心筛选出最重要的特征,并去除可能损害模型性能的不必要噪声输入。 VSN 的工作原理如下:

    • 首先,对每个特征单独应用一个门控残差网络(GRN)。
    • 然后将所有特征拼接起来,并对拼接后的特征应用一个 GRN,随后通过 softmax 生成特征权重。
    • 最后计算各个独立 GRN 输出的加权和作为结果。

注意:本模型并非基于上述论文中描述的完整 TFT 模型,而仅使用了其中的 GRN 和 VSN 组件,这表明 GRN 和 VSN 组件本身在结构化数据学习任务中也能发挥重要作用。

预期用途

本模型可用于二分类任务,以判断一个人年收入是否超过 500K 美元。

训练与评估数据

本模型使用 UCI 机器学习仓库提供的美国人口普查收入数据集进行训练。 该数据集包含加权的人口普查数据,其中的人口统计和就业相关变量提取自美国人口普查局 1994 年和 1995 年的当前人口调查。 数据集包含约 29.9K 个样本,具有 41 个输入变量和 1 个目标变量 income_level。 变量 instance_weight 未用作模型输入,因此模型最终使用 40 个输入特征,包括 7 个数值特征和 33 个类别特征:

数值特征类别特征
ageclass of worker
wage per hourindustry code
capital gainsoccupation code
capital lossesadjusted gross income
dividends from stockseducation
num persons worked for employerveterans benefits
weeks worked in yearenrolled in edu inst last wk
marital status
major industry code
major occupation code
mace
hispanic Origin
sex
member of a labor union
reason for unemployment
full or part time employment stat
federal income tax liability
tax filer status
region of previous residence
state of previous residence
detailed household and family stat
detailed household summary in household
migration code-change in msa
migration code-change in reg
migration code-move within reg
live in this house 1 year ago
migration prev res in sunbelt
family members under 18
total person earnings
country of birth father
country of birth mother
country of birth self
citizenship
total person income
own business or self employed
taxable income amount
fill inc questionnaire for veteran's admin

该数据集已预先划分为训练集和测试集两部分。 训练集包含 199523 个样本,测试集包含 99762 个样本。

训练流程

  1. 数据准备:加载训练集和测试集,并将目标列 income_level 从字符串类型转换为整数类型。训练集进一步划分为训练子集和验证子集。最后,将训练子集和验证子集转换为用于模型训练和评估的 tf.data.Dataset。

  2. 定义输入特征编码逻辑:我们对分类特征和数值特征按如下方式进行编码:

    • 分类特征:使用 Keras 提供的 Embedding(嵌入)层进行编码。嵌入层的输出维度等于 encoding_size。

    • 数值特征:通过 Keras 提供的 Dense(全连接)层执行线性变换,将其投影到 encoding_size 维向量。

    因此,所有编码后的特征将具有相同的维度,该维度等于 encoding_size 的值。

  3. 创建模型:

    • 模型将具有与给定数据集的数值特征和分类特征相对应的输入层。
    • 输入层接收的特征随后使用步骤 2 中定义的编码逻辑进行编码,其中 encoding_size 设为 16,指示编码后特征的输出维度。
    • 编码后的特征通过变量选择网络(VSN)。如“模型描述”部分所述,VSN 内部也利用了 GRN。
    • VSN 生成的特征通过一个带有 sigmoid 激活函数的最终 Dense 层,以产生模型的最终输出,该输出表示一个人收入是否大于 500K 的概率。
  4. 编译、训练和评估模型:

    • 由于该模型用于二分类任务,选择的损失函数为二元交叉熵。
    • 用于评估模型性能的指标为 accuracy(准确率)。
    • 选择的优化器为 Adam,学习率为 0.001。
    • GRN 中 Dropout 层的 dropout_rate( dropout 率)为 0.15。
    • 选择的 batch_size(批大小)为 265,模型训练 20 个 epochs(轮次)。
    • 训练过程中使用 Keras 的 EarlyStopping(早停)回调函数,这意味着一旦验证指标停止改善,训练就会中断。
    • 最后,在 test_dataset(测试集)上评估模型性能,准确率达到约 95%。

训练超参数

训练过程中使用了以下超参数:

超参数值
nameAdam
learning_rate0.0010000000474974513
decay0.0
beta_10.8999999761581421
beta_20.9990000128746033
epsilon1e-07
amsgradFalse
training_precisionfloat32

模型图

查看模型图

Model Image

致谢:

  • HF 贡献者:Shivalika Singh
  • 完整归功于 Khalid Salama 的原始 Keras 示例
  • 在此查看演示空间 here