Update README.md
README.md
@@ -8,4 +8,106 @@ tags:
- foundation-model
- deep-learning
- in-context
---

# ConTextTab: A Semantics-Aware Tabular In-Context Learner
[Source code on GitHub](https://github.com/SAP-samples/contexttab/)

## Description

This repository contains the implementation of the deep learning model and the inference pipeline described in the paper "ConTextTab: A Semantics-Aware Tabular In-Context Learner".

## Abstract

Tabular in-context learning (ICL) has recently achieved state-of-the-art (SOTA) performance on several tabular prediction tasks. Previously restricted to classification problems on small tables, recent advances such as TabPFN and TabICL have extended its use to larger datasets. While being architecturally efficient and well-adapted to tabular data structures, current table-native ICL architectures, being trained exclusively on synthetic data, do not fully leverage the rich semantics and world knowledge contained in real-world tabular data. On another end of this spectrum, tabular ICL models based on pretrained large language models such as TabuLa-8B integrate deep semantic understanding and world knowledge but are only able to make use of a small amount of context due to inherent architectural limitations. With the aim to combine the best of both these worlds, we introduce **ConTextTab**, integrating semantic understanding and alignment into a table-native ICL framework. By employing specialized embeddings for different data modalities and by training on large-scale real-world tabular data, our model is competitive with SOTA across a broad set of benchmarks while setting a new standard on the semantically rich CARTE benchmark.

## Requirements

This project uses Git LFS to manage model checkpoints. If you have not set up Git LFS yet, install it and then run:
```git lfs install```
and then clone the repository:
```git clone https://github.com/SAP-samples/contexttab.git```
Model checkpoints will be downloaded automatically.

Dependencies are listed in `requirements.txt` and target Python 3.11.

Local development installation (run from the root of the cloned repository):
```pip install -e .```

Installation from source:
```pip install git+https://github.com/SAP-samples/contexttab```

## Basic Usage

The model supports both classification and regression tasks. It accepts input data as a pandas DataFrame or a NumPy array. No preprocessing is required: column names and cell values are embedded automatically by an LLM running in the background, and missing values are handled automatically.

For best performance, use a GPU with at least 80 GB of memory and set the context size to 8192. For large tables, a bagging factor of 8 is recommended.
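
To make the two settings above more concrete, the sketch below shows one common reading of them: each bagging round draws at most `max_context_size` training rows as context, and the final prediction is the average over all rounds. This is an illustration under that assumption, not ConTextTab's actual implementation. The names `sample_context`, `bagged_predict`, and the toy `mean_predictor` are hypothetical and exist only for this example.

```python
import random
from statistics import mean


def sample_context(rows, max_context_size, rng):
    """Return at most max_context_size rows, sampled without replacement."""
    if len(rows) <= max_context_size:
        return list(rows)
    return rng.sample(rows, max_context_size)


def bagged_predict(rows, predict_fn, bagging, max_context_size, seed=0):
    """Average predict_fn over `bagging` independently sampled contexts."""
    rng = random.Random(seed)
    predictions = [
        predict_fn(sample_context(rows, max_context_size, rng))
        for _ in range(bagging)
    ]
    return mean(predictions)


def mean_predictor(context):
    """Toy stand-in for a model: predicts the mean target of its context."""
    return mean(target for _, target in context)


# Toy table of (feature, target) pairs, much larger than the context size.
table = [(i, float(i % 10)) for i in range(1000)]

single = bagged_predict(table, mean_predictor, bagging=1, max_context_size=64)
bagged = bagged_predict(table, mean_predictor, bagging=8, max_context_size=64)
print(f"bagging=1: {single:.3f}  bagging=8: {bagged:.3f}")
```

In this sketch, a table that fits into the context is passed through unchanged in every round, so bagging only has an effect once the table is larger than `max_context_size`.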

### Classification

```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from contexttab import ConTextTabClassifier

# Load sample data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Initialize a classifier
clf = ConTextTabClassifier(bagging=1, max_context_size=2048)

# Provide the training data that serves as context
clf.fit(X_train, y_train)

# Predict probabilities
prediction_probabilities = clf.predict_proba(X_test)
# Predict labels
predictions = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, predictions))
```
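
Assuming `ConTextTabClassifier` follows the usual scikit-learn convention (an assumption, not something this README confirms), each row returned by `predict_proba` holds one probability per class and `predict` returns the class with the highest probability. The standalone sketch below illustrates that relationship with made-up numbers; `labels_from_probabilities` is a hypothetical helper, not part of the package.

```python
def labels_from_probabilities(probabilities, classes):
    """Pick, for every row, the class whose probability is largest."""
    labels = []
    for row in probabilities:
        best_index = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best_index])
    return labels


# Made-up probabilities for three samples and two classes (0 and 1).
example_probabilities = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]
example_labels = labels_from_probabilities(example_probabilities, classes=[0, 1])
print(example_labels)
```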

### Regression
```python
from sklearn.datasets import fetch_openml
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

from contexttab import ConTextTabRegressor


# Load sample data from OpenML (dataset ID 531)
dataset = fetch_openml(data_id=531, as_frame=True)
X = dataset.data
y = dataset.target.astype(float)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Initialize the regressor
regressor = ConTextTabRegressor(bagging=1, max_context_size=2048)

# Provide the training data that serves as context
regressor.fit(X_train, y_train)

# Predict on the test set
predictions = regressor.predict(X_test)

r2 = r2_score(y_test, predictions)
print("R² Score:", r2)
```

## Known Issues
No known issues.

## How to obtain support
[Create an issue](https://github.com/SAP-samples/contexttab/issues) in this repository if you find a bug or have questions about the content.

## Contributing
If you wish to contribute code or offer fixes or improvements, please send a pull request. For legal reasons, contributors will be asked to accept a Developer Certificate of Origin (DCO) when they create their first pull request to this project. This happens automatically during the submission process. SAP uses [the standard DCO text of the Linux Foundation](https://developercertificate.org/).

## License
Copyright (c) 2024 SAP SE or an SAP affiliate company. All rights reserved. This project is licensed under the Apache Software License, version 2.0, except as noted otherwise in the [LICENSE](LICENSE) file.

The model checkpoints were trained on [the T4 dataset](https://huggingface.co/datasets/mlfoundations/t4-full), which is in turn a subset of [the TabLib dataset](https://huggingface.co/datasets/approximatelabs/tablib-v1-full). As such, they inherit the restrictions described there and, in particular, are intended for research purposes only.