[ML]TabPFN: 一种基于因果推理的先验数据拟合分类算法

一 本文简介

本文介绍了一种名为TabPFN的方法,适用于小型表格分类任务。该方法使用Transformer来特征化数据集,并且可以通过网络向前传递一次学习,而无需反向传播。此外,该模型是“先验数据拟合网络”的一个实例,使用丰富的模拟数据进行类似元学习一样的模型预训练。预训练数据由贝叶斯神经网络或结构因果模型生成,并根据需要进行调整,以生成类似于表格数据的数据,每个数据集最多有2000个样本,100个特征和10个不平衡类。作者在30个小数据集上对该方法进行了基准测试,包括使用ligthGBM和XGBoost建立的基线进行比较,结果表明,该方法不需要调参的情况下可以获得更好的性能。


TabPFN通过生成数据预训练一个有先验知识的预训练模型,在使用时快速地计算具体任务的后验概率分布,以达到比树模型更好的分类预测效果。

二 背景知识

表格数据任务是机器学习(ML)应用中最普遍的任务类型。目前,Gradient-Boosted Decision Trees (GBDT)仍然是解决表格数据分类问题的主流算法,主要原因是其训练时间短且鲁棒性强。审计网络在表格数据任务中的表现通常不是最优的,在之前的文章中我们也探讨了这个问题。

三 本文工作

本文提出的新方法TabPFN,用于表格分类。与传统的监督学习方法不同,TabPFN是一个单一的Transformer模型,并在大量的生成数据上进行预先训练。这意味着它已经学会了如何处理各种不同类型的表格数据。通过将所有可能的数据生成机制进行加权,根据它们在给定数据和先验概率下的可能性,TabPFN可以近似计算后验预测分布(PPD),即未知数据的概率分布。学习近似这个复杂的PPD可以获得一个通用的模型,可以适用于各种小型表格分类任务,而不需要重新进行训练或者选择不同的模型。这个方法的突破性在于能够快速且准确地解决小型表格分类问题。
图片
本文用到了基于因果推理的预测方法,这种方法寻找系统组件之间的因果关系来预测未见数据上的观察结果。现有的大多数工作都主要集中在确定单个因果图上用于下游预测,但由于SCMs的非识别性和DAG空间的组合性质,这会存在问题。相比之下,本文介绍了一种基于TabPFN的先验,它考虑了一类广泛的SCMs,跳过了任何显式的图形表示,并直接近似PPD。我们不进行因果推理,而是直接解决下游预测任务。
图片
图片

四 实验结果

在OpenML-CC18基准中的18个数据集上的实验表明,TabPFN明显优于其它的分类方法,其性能与复杂的最先进的AutoML系统不相上下,加速高达230倍。当使用GPU时,这会增加到5700倍的加速。
图片
图片

五 总结展望

本文提供了一种用于表格数据分类的神经网络方法,它在合成数据上进行预训练,然后可以快速应用于未见数据集,并且不需要超参数调整。这个想法是很新颖的,它与现有的标准数据分类方法非常不同。该方法的一个主要缺点是其仅能用于小型数据集。

发布者:股市刺客,转载请注明出处:https://www.95sca.cn/archives/111091
站内所有文章皆来自网络转载或读者投稿,请勿用于商业用途。如有侵权、不妥之处,请联系站长并出示版权证明以便删除。敬请谅解!

(0)
股市刺客的头像股市刺客
上一篇 1天前
下一篇 1天前

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注