本模型采用了 Bryan Lim 等人在论文《Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting》(https://arxiv.org/abs/1912.09363)中提出的两个重要架构组件——GRN 和 VSN,它们在结构化数据学习任务中非常实用。
门控残差网络(Gated Residual Networks, GRN):包含跳跃连接和门控层,可高效促进信息流。GRN 能够灵活地仅在需要处应用非线性处理。 GRN 利用门控线性单元(Gated Linear Units,简称 GLU)来抑制与特定任务无关的输入。
GRN 的工作原理如下:
变量选择网络(Variable Selection Networks, VSN):有助于从输入中精心筛选出最重要的特征,并去除可能损害模型性能的不必要噪声输入。 VSN 的工作原理如下:
注意:本模型并非基于上述论文中描述的完整 TFT 模型,而仅使用了其中的 GRN 和 VSN 组件,这表明 GRN 和 VSN 组件本身在结构化数据学习任务中也能发挥重要作用。
本模型可用于二分类任务,以判断一个人年收入是否超过 500K 美元。
本模型使用 UCI 机器学习仓库提供的美国人口普查收入数据集进行训练。 该数据集包含加权的人口普查数据,其中的人口统计和就业相关变量提取自美国人口普查局 1994 年和 1995 年的当前人口调查。 数据集包含约 29.9K 个样本,具有 41 个输入变量和 1 个目标变量 income_level。 变量 instance_weight 未用作模型输入,因此模型最终使用 40 个输入特征,包括 7 个数值特征和 33 个类别特征:
| 数值特征 | 类别特征 |
|---|---|
| age | class of worker |
| wage per hour | industry code |
| capital gains | occupation code |
| capital losses | adjusted gross income |
| dividends from stocks | education |
| num persons worked for employer | veterans benefits |
| weeks worked in year | enrolled in edu inst last wk |
| marital status | |
| major industry code | |
| major occupation code | |
| mace | |
| hispanic Origin | |
| sex | |
| member of a labor union | |
| reason for unemployment | |
| full or part time employment stat | |
| federal income tax liability | |
| tax filer status | |
| region of previous residence | |
| state of previous residence | |
| detailed household and family stat | |
| detailed household summary in household | |
| migration code-change in msa | |
| migration code-change in reg | |
| migration code-move within reg | |
| live in this house 1 year ago | |
| migration prev res in sunbelt | |
| family members under 18 | |
| total person earnings | |
| country of birth father | |
| country of birth mother | |
| country of birth self | |
| citizenship | |
| total person income | |
| own business or self employed | |
| taxable income amount | |
| fill inc questionnaire for veteran's admin |
该数据集已预先划分为训练集和测试集两部分。 训练集包含 199523 个样本,测试集包含 99762 个样本。
数据准备:加载训练集和测试集,并将目标列 income_level 从字符串类型转换为整数类型。训练集进一步划分为训练子集和验证子集。最后,将训练子集和验证子集转换为用于模型训练和评估的 tf.data.Dataset。
定义输入特征编码逻辑:我们对分类特征和数值特征按如下方式进行编码:
分类特征:使用 Keras 提供的 Embedding(嵌入)层进行编码。嵌入层的输出维度等于 encoding_size。
数值特征:通过 Keras 提供的 Dense(全连接)层执行线性变换,将其投影到 encoding_size 维向量。
因此,所有编码后的特征将具有相同的维度,该维度等于 encoding_size 的值。
创建模型:
编译、训练和评估模型:
训练过程中使用了以下超参数:
| 超参数 | 值 |
|---|---|
| name | Adam |
| learning_rate | 0.0010000000474974513 |
| decay | 0.0 |
| beta_1 | 0.8999999761581421 |
| beta_2 | 0.9990000128746033 |
| epsilon | 1e-07 |
| amsgrad | False |
| training_precision | float32 |
