Evaluation of TabPFN

Looking at a transformer model for tabular data
tabular classification
probabilistic graph model
Published

January 31, 2023

TabPFN (Hollmann et al. 2022) is a transformer model that can be used to estimate missing tabular data. I’m interested in investigating this, as probabilistic graph models can be useful for working with tabular data and this model appears to overlap quite heavily with my area of interest.

Hollmann, Noah, Samuel Müller, Katharina Eggensperger, and Frank Hutter. 2022. “TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second.” arXiv. https://doi.org/10.48550/ARXIV.2207.01848.

The training of the model is interesting too. It’s described as a prior-data fitted network (PFN) and approximates Bayesian inference over the data. The model comes pre-trained; you then fit it to your dataset by providing a small amount of data, and after this fitting it is able to estimate missing values.

Dataset

The original TabPFN was evaluated on OpenML-CC18. Datasets were selected that had a relatively small number of rows:

OpenML-CC18 that contain up to 2,000 samples, 100 features, 10 classes and no missing values

There is a Hugging Face dataset called inria-soda/tabular-benchmark, which collects OpenML-CC18 datasets that are slightly larger. I am going to use this and select one of the smaller datasets. It’s not ideal to use a dataset from the same overall set that was used to evaluate the model, but given the constraints quoted above we can be sure that this one wasn’t used directly.

Code
from datasets import load_dataset
import pandas as pd

dataset = load_dataset("inria-soda/tabular-benchmark", data_files="reg_num/wine_quality.csv")
df = pd.DataFrame(dataset["train"])
df.head()
Using custom data configuration inria-soda--tabular-benchmark-ee9df09d63c8b659
Found cached dataset csv (/home/matthew/.cache/huggingface/datasets/inria-soda___csv/inria-soda--tabular-benchmark-ee9df09d63c8b659/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
fixed.acidity volatile.acidity citric.acid residual.sugar chlorides free.sulfur.dioxide total.sulfur.dioxide density pH sulphates alcohol quality
0 7.4 0.70 0.00 1.9 0.076 11.0 34.0 0.9978 3.51 0.56 9.4 5
1 7.8 0.88 0.00 2.6 0.098 25.0 67.0 0.9968 3.20 0.68 9.8 5
2 7.8 0.76 0.04 2.3 0.092 15.0 54.0 0.9970 3.26 0.65 9.8 5
3 11.2 0.28 0.56 1.9 0.075 17.0 60.0 0.9980 3.16 0.58 9.8 6
4 7.4 0.70 0.00 1.9 0.076 11.0 34.0 0.9978 3.51 0.56 9.4 5
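As a quick sanity check on that claim, this table is comfortably above the 2,000-sample cutoff quoted above:

Code
# 6,497 rows - well above the 2,000-sample cap used to select the
# original evaluation datasets, so it cannot have been among them
print(len(df))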

Simple Usage

Let’s just apply the model to the dataset. This is the starter code adjusted to use this data, with the aim of predicting the wine quality.

Code
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X_train, X_test, y_train, y_test = train_test_split(
    df[[
        "fixed.acidity",
        "volatile.acidity",
        "citric.acid",
        "residual.sugar",
        "chlorides",
        "free.sulfur.dioxide",
        "total.sulfur.dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ]],
    df.quality,
    train_size=1_000,
    random_state=42,
)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_eval))
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Accuracy 0.5701291613607422

This took about a minute to run and it doesn’t appear to be very accurate. What’s the problem here?

The problem is that the accuracy score only rewards exact matches, while the quality is a value on a linear scale. This is a regression task rather than a classification task.

There are a few metrics for regression tasks - mean squared error and mean absolute error are both appropriate.

Code
mean_square_error = ((y_test - y_eval)**2).mean()
mean_absolute_error = (y_test - y_eval).abs().mean()

print(f"mean square error: {mean_square_error:0.3f}")
print(f"mean absolute error: {mean_absolute_error:0.3f}")
mean square error: 0.565
mean absolute error: 0.473
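As a cross-check, scikit-learn’s built-in regression metrics should agree with the hand-rolled versions above:

Code
from sklearn.metrics import mean_squared_error, mean_absolute_error

# should print the same 0.565 and 0.473 as the manual calculation
print(f"mean square error: {mean_squared_error(y_test, y_eval):0.3f}")
print(f"mean absolute error: {mean_absolute_error(y_test, y_eval):0.3f}")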
Code
(
    (y_test - y_eval)
        .sort_values() 
        .reset_index()
        .rename(columns={"quality": "error", "Unnamed: 0": "count"})
        .groupby("error")
        .agg(len)
        .div(len(y_test))
        .plot(title="TabPFN error")
) ; None

This shows the 57% accuracy as the spike at zero error, while the overall shape is similar to a normal distribution. Such a distribution is to be expected as the model was trained to approximate Bayesian inference.
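To put these numbers in context, here is a quick sketch of a trivial baseline that always predicts the most common training quality (not part of the original analysis):

Code
import numpy as np

# trivial baseline: always predict the most frequent quality in the training set
majority = y_train.mode()[0]
baseline = np.full(len(y_test), majority)

print("baseline accuracy:", (y_test == baseline).mean())
print("baseline mean absolute error:", (y_test - baseline).abs().mean())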

Floating Point Target

I can also change the type of the Y column to see if that gives the model greater freedom. Every value in the column is an integer; if they were floats, that might indicate that the output is continuous.

Code
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X_train, X_test, y_train, y_test = train_test_split(
    df[[
        "fixed.acidity",
        "volatile.acidity",
        "citric.acid",
        "residual.sugar",
        "chlorides",
        "free.sulfur.dioxide",
        "total.sulfur.dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ]],
    df.quality.astype(float), # CHANGED
    train_size=1_000,
    random_state=42,
)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

mean_square_error = ((y_test - y_eval)**2).mean()
mean_absolute_error = (y_test - y_eval).abs().mean()

print(f"mean square error: {mean_square_error:0.3f}")
print(f"mean absolute error: {mean_absolute_error:0.3f}")
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
mean square error: 0.565
mean absolute error: 0.473

This is identical to the previous run, so the model determined that the output was discrete based on the values themselves rather than the column type.
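This matches scikit-learn’s target-type conventions, which TabPFN’s label validation builds on: a float column whose values are all integral is still treated as discrete. The behaviour can be seen with scikit-learn’s type_of_target utility:

Code
import numpy as np
from sklearn.utils.multiclass import type_of_target

print(type_of_target(np.array([5, 6, 7])))        # multiclass
print(type_of_target(np.array([5.0, 6.0, 7.0])))  # multiclass - integral floats are still discrete
print(type_of_target(np.array([5.1, 6.0, 7.0])))  # continuous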

Perturbed Target

The values were collapsed back into a discrete set based on value rather than type. If each value is randomly altered by a tiny amount then they should no longer form clearly separated values, and TabPFN should treat this as a regression task.

Code
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np

from tabpfn import TabPFNClassifier

X_train, X_test, y_train, y_test = train_test_split(
    df[[
        "fixed.acidity",
        "volatile.acidity",
        "citric.acid",
        "residual.sugar",
        "chlorides",
        "free.sulfur.dioxide",
        "total.sulfur.dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ]],
    df.quality.astype(float), # CHANGED
    train_size=1_000,
    random_state=42,
)

# perturb the training target values slightly to make the model think that these are not integers
y_train += (np.random.random(y_train.shape) - 0.5) * 2 / 1_000

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

mean_square_error = ((y_test - y_eval)**2).mean()
mean_absolute_error = (y_test - y_eval).abs().mean()

print(f"mean square error: {mean_square_error:0.3f}")
print(f"mean absolute error: {mean_absolute_error:0.3f}")
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
ValueError: Unknown label type: 'continuous'

This is interesting. It seems that the model is unable to predict continuous values?

This is specifically mentioned in the paper, section 4.5:

we need to transform our scalar labels \(\hat{y}\) to discrete class labels y. We do so by splitting the values of \(\hat{y}\) into intervals that map to class labels

The docstring on check_classification_targets, the scikit-learn utility that raises this error, is:

Ensure that target y is of a non-regression type

So this is an intentional choice by the model.
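Following the paper’s interval-mapping idea, a continuous target could be discretized before fitting so that it passes this check. A minimal sketch, assuming quantile bins (pd.qcut and the bin count are my choices, not the paper’s):

Code
import numpy as np
import pandas as pd

# stand-in for a genuinely continuous target (the perturbed quality from above)
y_continuous = df.quality.astype(float) + (np.random.random(len(df)) - 0.5) / 500

# split the scalar values into intervals that map to class labels 0..4;
# predictions could be mapped back to the bin midpoints for a regression-style estimate
y_binned, bin_edges = pd.qcut(y_continuous, q=5, labels=False, retbins=True, duplicates="drop")
midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2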

Model Structure

How does the fitting work? Why is it restricted in this way?

If we take the model apart we can see the different stages involved. To recap, the model is designed as a scikit-learn estimator, which means fitting to the training data and then predicting over unseen data.

Model Fitting

This is surprisingly simple:

Signature: TabPFNClassifier.fit(self, X, y, overwrite_warning=False)
Docstring: <no docstring>
Source:   
    def fit(self, X, y, overwrite_warning=False):
        # Check that X and y have correct shape
        X, y = check_X_y(X, y, force_all_finite=False)
        # Store the classes seen during fit
        y = self._validate_targets(y)
        self.X_ = X
        self.y_ = y
        if X.shape[1] > self.max_num_features:
            raise ValueError("The number of features for this classifier is restricted to ", self.max_num_features)
        if len(np.unique(y)) > self.max_num_classes:
            raise ValueError("The number of classes for this classifier is restricted to ", self.max_num_classes)
        if X.shape[0] > 1024 and not overwrite_warning:
            raise ValueError("⚠️ WARNING: TabPFN is not made for datasets with a trainingsize > 1024. Prediction might take a while, be less reliable. We advise not to run datasets > 10k samples, which might lead to your machine crashing (due to quadratic memory scaling of TabPFN). Please confirm you want to run by passing overwrite_warning=True to the fit function.")
            
        # Return the classifier
        return self
File:      ~/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/tabpfn/scripts/transformer_prediction_interface.py
Type:      function

The relevant lines are:

self.X_ = X
self.y_ = y

so fitting the model is literally just storing the training data on the estimator. That data must be used in some way when performing predictions, which suggests that inference actually runs over the query row (or rows) plus all of the training data.
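A stripped-down sketch of that pattern (purely illustrative, not the library’s code): fit stores the data, predict concatenates it with the query rows and runs a single forward pass.

Code
import numpy as np

class InContextEstimator:
    """Illustrative skeleton of the PFN fit/predict pattern."""

    def fit(self, X, y):
        # no training happens here - just remember the data
        self.X_, self.y_ = np.asarray(X), np.asarray(y)
        return self

    def predict_proba(self, X):
        # inference sees training rows and query rows in one sequence;
        # the transformer conditions on the labelled prefix
        X_full = np.concatenate([self.X_, np.asarray(X)], axis=0)
        eval_pos = self.X_.shape[0]  # labels are only available up to here
        return self._forward(X_full, self.y_, eval_pos)

    def _forward(self, X_full, y_train, eval_pos):
        raise NotImplementedError  # the 25.82 M parameter transformer goes here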

Model Prediction

So is the predictor running the underlying deep learning model with the rows that were saved from fitting?

Signature:
TabPFNClassifier.predict(
    self,
    X,
    return_winning_probability=False,
    normalize_with_test=False,
)
Docstring: <no docstring>
Source:   
    def predict(self, X, return_winning_probability=False, normalize_with_test=False):
        p = self.predict_proba(X, normalize_with_test=normalize_with_test)
        y = np.argmax(p, axis=-1)
        y = self.classes_.take(np.asarray(y, dtype=np.intp))
        if return_winning_probability:
            return y, p.max(axis=-1)
        return y
File:      ~/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/tabpfn/scripts/transformer_prediction_interface.py
Type:      function

The predict function first calls predict_proba, which combines the provided values with the training set.

Signature: TabPFNClassifier.predict_proba(self, X, normalize_with_test=False)
Docstring: <no docstring>
Source:   
    def predict_proba(self, X, normalize_with_test=False):
        # Check is fit had been called
        check_is_fitted(self)
        # Input validation
        X = check_array(X, force_all_finite=False)
        X_full = np.concatenate([self.X_, X], axis=0)
        X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
        y_full = np.concatenate([self.y_, np.zeros_like(X[:, 0])], axis=0)
        y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)
        eval_pos = self.X_.shape[0]
        prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos,
                                         device=self.device,
                                         style=self.style,
                                         inference_mode=True,
                                         preprocess_transform='none' if self.no_preprocess_mode else 'mix',
                                         normalize_with_test=normalize_with_test,
                                         N_ensemble_configurations=self.N_ensemble_configurations,
                                         softmax_temperature=self.temperature,
                                         combine_preprocessing=self.combine_preprocessing,
                                         multiclass_decoder=self.multiclass_decoder,
                                         feature_shift_decoder=self.feature_shift_decoder,
                                         differentiable_hps_as_style=self.differentiable_hps_as_style,
                                         seed=self.seed,
                                         **get_params_from_config(self.c))
        prediction_, y_ = prediction.squeeze(0), y_full.squeeze(1).long()[eval_pos:]
        return prediction_.detach().cpu().numpy()
File:      ~/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/tabpfn/scripts/transformer_prediction_interface.py
Type:      function

Finally this calls transformer_predict, an extensive function with inner functions of its own. This suggests to me that it might be better structured as a class, and it happens that we have one already - the model that is used for making the prediction.

To better understand the invocation we can break down the argument values according to the current classifier:

transformer_predict(
    self.model[2],
        # self.model is a tuple with inf, inf, TransformerModel
        # the TransformerModel is defined within the tabpfn repo
    X_full,
        # this is the combination of X_train and X
    y_full,
        # this is the y_train with zeros for each row in X
    eval_pos,
        # the number of rows in the training set, shows where predictions should start?
    device=self.device,
        # defaults to cpu
    style=self.style,
        # defaults to None, should dig into this
    inference_mode=True,
        # used to turn on torch.inference_mode,
        # seems partially implemented as the model is unconditionally put into evaluation mode
        # this method is clearly not for training - this parameter should be removed imo
    preprocess_transform='mix',
        # no_preprocess_mode is false
        # the mix preprocess_transform is actually a combination of none and power_all
        # valid values appear to be:
        # * power / power_all (synonyms?)
        # * quantile / quantile_all
        # * robust / robust_all
        # it's not clear what these values represent and they are passed around as strings a lot.
        # it would be better to define an enumeration for them and map strings to those types internally.
    normalize_with_test=False,
        # used in a few places as a sentinel that switches the normalization position
    N_ensemble_configurations=32,
        # explicitly set to 32 in the example, defaults to 3
        # used to select the first N values from a product of ensemble and preprocess configurations
        # this is where the randomness kicks in as the product is over two randomly sorted lists of values
        # note - is randperm used appropriately - would the product cause a limited variance over one value?
    softmax_temperature=None,
        # much like distillation this can apply temperature to the logits before softmax
        # this value is ignored and defaults to log(0.8) unless style is not None
        # the coupling of these two values is non obvious
    combine_preprocessing=False,
        # used once to apply power_all, quantile_all and normalization to the X data
        # otherwise the preprocess_transform is applied
        # the precedence of this over preprocess_transform is non obvious
    multiclass_decoder='permutation',
        # used in the generation of the ensembles,
        # the check is 'is this value permutation', could easily be a boolean
    feature_shift_decoder=True,
        # used in the generation of the ensembles much like multiclass_decoder
        # at least one of these must be set or there will only be a single model in the ensemble
        # these coupled values could be better handled
    differentiable_hps_as_style=False,
        # overrides style, setting it to None
        # why does this exist?
    seed=0,
        # defaults to zero in constructor of TabPFNClassifier
    # **get_params_from_config(self.c)
    max_features=100,
        # the X data is extended to this many features, and normalized based on that
    rescale_features=True,
        # not used?
    normalize_to_ranking=False,
        # used to choose between removing outliers and normalizing a ranked version of the data
        # in preprocessing
    normalize_with_sqrt=False,
        # change the normalization to square root the divisor
        # given that the divisor is (0-1] turning this on would reduce the values
        # the normalization is based on the ratio of used features to max features so is a constant
)

That’s quite a lot of parameters, and it makes me feel this function would be better split up. A configuration object already exists and is partially used; revisiting it would be a chance to reduce the parameter count and more clearly communicate which variables are coupled.
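For example (my sketch, not something from the repo), a dataclass could group the coupled ensemble options and enforce their relationship:

Code
from dataclasses import dataclass

@dataclass
class EnsembleConfig:
    """Hypothetical grouping of transformer_predict's coupled ensemble arguments."""
    n_configurations: int = 3
    feature_shift: bool = True      # replaces feature_shift_decoder
    class_permutation: bool = True  # replaces multiclass_decoder='permutation'
    seed: int = 0

    def __post_init__(self):
        # make the coupling explicit: without any shift there is
        # only a single configuration in the ensemble
        if not (self.feature_shift or self.class_permutation):
            raise ValueError("ensemble would contain a single configuration")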

It’s not a big deal as the code in the method is reasonably straightforward. The method is quite long so I have folded it below, and will lift specific parts out to discuss them further.

Signature:
transformer_predict(
    model,
    eval_xs,
    eval_ys,
    eval_position,
    device='cpu',
    max_features=100,
    style=None,
    inference_mode=False,
    num_classes=2,
    extend_features=True,
    normalize_with_test=False,
    normalize_to_ranking=False,
    softmax_temperature=0.0,
    multiclass_decoder='permutation',
    preprocess_transform='mix',
    categorical_feats=[],
    feature_shift_decoder=False,
    N_ensemble_configurations=10,
    combine_preprocessing=False,
    batch_size_inference=16,
    differentiable_hps_as_style=False,
    average_logits=True,
    fp16_inference=False,
    normalize_with_sqrt=False,
    seed=0,
    **kwargs,
)
Source:   
def transformer_predict(model, eval_xs, eval_ys, eval_position,
                        device='cpu',
                        max_features=100,
                        style=None,
                        inference_mode=False,
                        num_classes=2,
                        extend_features=True,
                        normalize_with_test=False,
                        normalize_to_ranking=False,
                        softmax_temperature=0.0,
                        multiclass_decoder='permutation',
                        preprocess_transform='mix',
                        categorical_feats=[],
                        feature_shift_decoder=False,
                        N_ensemble_configurations=10,
                        combine_preprocessing=False,
                        batch_size_inference=16,
                        differentiable_hps_as_style=False,
                        average_logits=True,
                        fp16_inference=False,
                        normalize_with_sqrt=False,
                        seed=0,
                        **kwargs):
    """
    :param model:
    :param eval_xs:
    :param eval_ys:
    :param eval_position:
    :param rescale_features:
    :param device:
    :param max_features:
    :param style:
    :param inference_mode:
    :param num_classes:
    :param extend_features:
    :param normalize_to_ranking:
    :param softmax_temperature:
    :param multiclass_decoder:
    :param preprocess_transform:
    :param categorical_feats:
    :param feature_shift_decoder:
    :param N_ensemble_configurations:
    :param average_logits:
    :param normalize_with_sqrt:
    :param metric_used:
    :return:
    """
    num_classes = len(torch.unique(eval_ys))
    def predict(eval_xs, eval_ys, used_style, softmax_temperature, return_logits):
        # Initialize results array size S, B, Classes
        inference_mode_call = torch.inference_mode() if inference_mode else NOP()
        with inference_mode_call:
            start = time.time()
            output = model(
                    (used_style.repeat(eval_xs.shape[1], 1) if used_style is not None else None, eval_xs, eval_ys.float()),
                    single_eval_pos=eval_position)[:, :, 0:num_classes]
            output = output[:, :, 0:num_classes] / torch.exp(softmax_temperature)
            if not return_logits:
                output = torch.nn.functional.softmax(output, dim=-1)
            #else:
            #    output[:, :, 1] = model((style.repeat(eval_xs.shape[1], 1) if style is not None else None, eval_xs, eval_ys.float()),
            #               single_eval_pos=eval_position)
            #    output[:, :, 1] = torch.sigmoid(output[:, :, 1]).squeeze(-1)
            #    output[:, :, 0] = 1 - output[:, :, 1]
        #print('RESULTS', eval_ys.shape, torch.unique(eval_ys, return_counts=True), output.mean(axis=0))
        return output
    def preprocess_input(eval_xs, preprocess_transform):
        import warnings
        if eval_xs.shape[1] > 1:
            raise Exception("Transforms only allow one batch dim - TODO")
        if preprocess_transform != 'none':
            if preprocess_transform == 'power' or preprocess_transform == 'power_all':
                pt = PowerTransformer(standardize=True)
            elif preprocess_transform == 'quantile' or preprocess_transform == 'quantile_all':
                pt = QuantileTransformer(output_distribution='normal')
            elif preprocess_transform == 'robust' or preprocess_transform == 'robust_all':
                pt = RobustScaler(unit_variance=True)
        # eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys)
        eval_xs = normalize_data(eval_xs, normalize_positions=-1 if normalize_with_test else eval_position)
        # Removing empty features
        eval_xs = eval_xs[:, 0, :]
        sel = [len(torch.unique(eval_xs[0:eval_ys.shape[0], col])) > 1 for col in range(eval_xs.shape[1])]
        eval_xs = eval_xs[:, sel]
        warnings.simplefilter('error')
        if preprocess_transform != 'none':
            eval_xs = eval_xs.cpu().numpy()
            feats = set(range(eval_xs.shape[1])) if 'all' in preprocess_transform else set(
                range(eval_xs.shape[1])) - set(categorical_feats)
            for col in feats:
                try:
                    pt.fit(eval_xs[0:eval_position, col:col + 1])
                    trans = pt.transform(eval_xs[:, col:col + 1])
                    # print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])]))
                    eval_xs[:, col:col + 1] = trans
                except:
                    pass
            eval_xs = torch.tensor(eval_xs).float()
        warnings.simplefilter('default')
        eval_xs = eval_xs.unsqueeze(1)
        # TODO: Cautian there is information leakage when to_ranking is used, we should not use it
        eval_xs = remove_outliers(eval_xs, normalize_positions=-1 if normalize_with_test else eval_position) if not normalize_to_ranking else normalize_data(to_ranking_low_mem(eval_xs))
        # Rescale X
        eval_xs = normalize_by_used_features_f(eval_xs, eval_xs.shape[-1], max_features,
                                               normalize_with_sqrt=normalize_with_sqrt)
        return eval_xs.detach().to(device)
    eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device)
    eval_ys = eval_ys[:eval_position]
    model.to(device)
    model.eval()
    import itertools
    if not differentiable_hps_as_style:
        style = None
    if style is not None:
        style = style.to(device)
        style = style.unsqueeze(0) if len(style.shape) == 1 else style
        num_styles = style.shape[0]
        softmax_temperature = softmax_temperature if softmax_temperature.shape else softmax_temperature.unsqueeze(
            0).repeat(num_styles)
    else:
        num_styles = 1
        style = None
        softmax_temperature = torch.log(torch.tensor([0.8]))
    styles_configurations = range(0, num_styles)
    def get_preprocess(i):
        if i == 0:
            return 'power_all'
#            if i == 1:
#                return 'robust_all'
        if i == 1:
            return 'none'
    preprocess_transform_configurations = ['none', 'power_all'] if preprocess_transform == 'mix' else [preprocess_transform]
    if seed is not None:
        torch.manual_seed(seed)
    feature_shift_configurations = torch.randperm(eval_xs.shape[2]) if feature_shift_decoder else [0]
    class_shift_configurations = torch.randperm(len(torch.unique(eval_ys))) if multiclass_decoder == 'permutation' else [0]
    ensemble_configurations = list(itertools.product(class_shift_configurations, feature_shift_configurations))
    #default_ensemble_config = ensemble_configurations[0]
    rng = random.Random(seed)
    rng.shuffle(ensemble_configurations)
    ensemble_configurations = list(itertools.product(ensemble_configurations, preprocess_transform_configurations, styles_configurations))
    ensemble_configurations = ensemble_configurations[0:N_ensemble_configurations]
    #if N_ensemble_configurations == 1:
    #    ensemble_configurations = [default_ensemble_config]
    output = None
    eval_xs_transformed = {}
    inputs, labels = [], []
    start = time.time()
    for ensemble_configuration in ensemble_configurations:
        (class_shift_configuration, feature_shift_configuration), preprocess_transform_configuration, styles_configuration = ensemble_configuration
        style_ = style[styles_configuration:styles_configuration+1, :] if style is not None else style
        softmax_temperature_ = softmax_temperature[styles_configuration]
        eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone()
        if preprocess_transform_configuration in eval_xs_transformed:
            eval_xs_ = eval_xs_transformed[preprocess_transform_configuration].clone()
        else:
            if eval_xs_.shape[-1] * 3 < max_features and combine_preprocessing:
                eval_xs_ = torch.cat([preprocess_input(eval_xs_, preprocess_transform='power_all'),
                            preprocess_input(eval_xs_, preprocess_transform='quantile_all')], -1)
                eval_xs_ = normalize_data(eval_xs_, normalize_positions=-1 if normalize_with_test else eval_position)
                #eval_xs_ = torch.stack([preprocess_input(eval_xs_, preprocess_transform='power_all'),
                #                        preprocess_input(eval_xs_, preprocess_transform='robust_all'),
                #                        preprocess_input(eval_xs_, preprocess_transform='none')], -1)
                #eval_xs_ = torch.flatten(torch.swapaxes(eval_xs_, -2, -1), -2)
            else:
                eval_xs_ = preprocess_input(eval_xs_, preprocess_transform=preprocess_transform_configuration)
            eval_xs_transformed[preprocess_transform_configuration] = eval_xs_
        eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float()
        eval_xs_ = torch.cat([eval_xs_[..., feature_shift_configuration:],eval_xs_[..., :feature_shift_configuration]],dim=-1)
        # Extend X
        if extend_features:
            eval_xs_ = torch.cat(
                [eval_xs_,
                 torch.zeros((eval_xs_.shape[0], eval_xs_.shape[1], max_features - eval_xs_.shape[2])).to(device)], -1)
        inputs += [eval_xs_]
        labels += [eval_ys_]
    inputs = torch.cat(inputs, 1)
    inputs = torch.split(inputs, batch_size_inference, dim=1)
    labels = torch.cat(labels, 1)
    labels = torch.split(labels, batch_size_inference, dim=1)
    #print('PREPROCESSING TIME', str(time.time() - start))
    outputs = []
    start = time.time()
    for batch_input, batch_label in zip(inputs, labels):
        #preprocess_transform_ = preprocess_transform if styles_configuration % 2 == 0 else 'none'
        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore",
                                    message="None of the inputs have requires_grad=True. Gradients will be None")
            warnings.filterwarnings("ignore",
                                    message="torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available.  Disabling.")
            if device == 'cpu':
                output_batch = checkpoint(predict, batch_input, batch_label, style_, softmax_temperature_, True)
            else:
                with torch.cuda.amp.autocast(enabled=fp16_inference):
                    output_batch = checkpoint(predict, batch_input, batch_label, style_, softmax_temperature_, True)
        outputs += [output_batch]
    #print('MODEL INFERENCE TIME ('+str(batch_input.device)+' vs '+device+', '+str(fp16_inference)+')', str(time.time()-start))
    outputs = torch.cat(outputs, 1)
    for i, ensemble_configuration in enumerate(ensemble_configurations):
        (class_shift_configuration, feature_shift_configuration), preprocess_transform_configuration, styles_configuration = ensemble_configuration
        output_ = outputs[:, i:i+1, :]
        output_ = torch.cat([output_[..., class_shift_configuration:],output_[..., :class_shift_configuration]],dim=-1)
        #output_ = predict(eval_xs, eval_ys, style_, preprocess_transform_)
        if not average_logits:
            output_ = torch.nn.functional.softmax(output_, dim=-1)
        output = output_ if output is None else output + output_
    output = output / len(ensemble_configurations)
    if average_logits:
        output = torch.nn.functional.softmax(output, dim=-1)
    output = torch.transpose(output, 0, 1)
    return output
File:      ~/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/tabpfn/scripts/transformer_prediction_interface.py
Type:      function

The function is split into four stages:

  • clean the arguments and form the list of ensemble configurations
  • preprocess the data for each ensemble configuration and collect it into an array
  • run the preprocessed data through inference in batches
  • sum the model outputs across the ensemble, applying softmax before or after the sum (after the sum if average_logits is true; the difference is sketched below)
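Whether the softmax is applied before or after the sum matters when ensemble members disagree. A minimal torch sketch of the two orderings (the tensors are invented for illustration):

Code
import torch

logits_a = torch.tensor([2.0, 0.0])  # two ensemble members that disagree
logits_b = torch.tensor([0.0, 1.0])

# average_logits=True: average the logits, then softmax
after_sum = torch.softmax((logits_a + logits_b) / 2, dim=-1)
# average_logits=False: softmax each member, then average the probabilities
before_sum = (torch.softmax(logits_a, dim=-1) + torch.softmax(logits_b, dim=-1)) / 2

print(after_sum)   # ≈ [0.622, 0.378]
print(before_sum)  # ≈ [0.575, 0.425]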

The softmax_temperature is applied to the raw logits: they are divided by exp(softmax_temperature) before the softmax, which forces the produced values into the range 0-1 (and makes them sum to one).
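With style left as None the temperature defaults to log(0.8), so the logits end up divided by exp(log(0.8)) = 0.8, which slightly sharpens the distribution. A small demonstration (values in the comments are approximate):

Code
import torch

logits = torch.tensor([1.0, 0.0, -1.0])
softmax_temperature = torch.log(torch.tensor(0.8))  # the default when style is None

plain = torch.softmax(logits, dim=-1)
sharpened = torch.softmax(logits / torch.exp(softmax_temperature), dim=-1)

print(plain)      # ≈ [0.665, 0.245, 0.090]
print(sharpened)  # ≈ [0.731, 0.209, 0.060] - dividing by 0.8 sharpens the distribution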

This is a long and quite complex function, which makes it difficult to follow what is going on. I’m interested in seeing what the underlying model receives and which parts of this function are on the critical path. To do that I am going to start patching it to show what the values are as each part is called.

Code
classifier
TabPFNClassifier(N_ensemble_configurations=32)
Code
from tabpfn.scripts.transformer_prediction_interface import check_array, check_is_fitted, transformer_predict, get_params_from_config
import torch

X_train, X_test, y_train, y_test = train_test_split(
    df[[
        "fixed.acidity",
        "volatile.acidity",
        "citric.acid",
        "residual.sugar",
        "chlorides",
        "free.sulfur.dioxide",
        "total.sulfur.dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ]],
    df.quality,
    train_size=1_000,
    random_state=42,
)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)

def predict(self, X, return_winning_probability=False, normalize_with_test=False):
    p = predict_proba(self, X, normalize_with_test=normalize_with_test)
    y = np.argmax(p, axis=-1)
    y = self.classes_.take(np.asarray(y, dtype=np.intp))
    if return_winning_probability:
        return y, p.max(axis=-1)
    return y

def predict_proba(self, X, normalize_with_test=False):
    # Check is fit had been called
    check_is_fitted(self)

    # Input validation
    X = check_array(X, force_all_finite=False)
    X_full = np.concatenate([self.X_, X], axis=0)
    X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
    y_full = np.concatenate([self.y_, np.zeros_like(X[:, 0])], axis=0)
    y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)

    eval_pos = self.X_.shape[0]

    def print_args(**kwargs):
        print(kwargs)
    
    print_args(
        eval_xs=X_full.shape,
        eval_ys=y_full.shape,
        eval_position=eval_pos,
        device=self.device,
        style=self.style,
        inference_mode=True,
        preprocess_transform='none' if self.no_preprocess_mode else 'mix',
        normalize_with_test=normalize_with_test,
        N_ensemble_configurations=self.N_ensemble_configurations,
        softmax_temperature=self.temperature,
        combine_preprocessing=self.combine_preprocessing,
        multiclass_decoder=self.multiclass_decoder,
        feature_shift_decoder=self.feature_shift_decoder,
        differentiable_hps_as_style=self.differentiable_hps_as_style,
        seed=self.seed,
        **get_params_from_config(self.c)
    )
    prediction = transformer_predict(
        self.model[2],
        eval_xs=X_full,
        eval_ys=y_full,
        eval_position=eval_pos,
        device=self.device,
        style=self.style,
        inference_mode=True,
        preprocess_transform='none' if self.no_preprocess_mode else 'mix',
        normalize_with_test=normalize_with_test,
        N_ensemble_configurations=self.N_ensemble_configurations,
        softmax_temperature=self.temperature,
        combine_preprocessing=self.combine_preprocessing,
        multiclass_decoder=self.multiclass_decoder,
        feature_shift_decoder=self.feature_shift_decoder,
        differentiable_hps_as_style=self.differentiable_hps_as_style,
        seed=self.seed,
        **get_params_from_config(self.c)
    )
    prediction_, y_ = prediction.squeeze(0), y_full.squeeze(1).long()[eval_pos:]

    return prediction_.detach().cpu().numpy()

y_eval, p_eval = predict(classifier, X_test, return_winning_probability=True)

mean_square_error = ((y_test - y_eval)**2).mean()
mean_absolute_error = (y_test - y_eval).abs().mean()

print(f"mean square error: {mean_square_error:0.3f}")
print(f"mean absolute error: {mean_absolute_error:0.3f}")
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
{'eval_xs': torch.Size([6497, 1, 11]), 'eval_ys': torch.Size([6497, 1]), 'eval_position': 1000, 'device': 'cpu', 'style': None, 'inference_mode': True, 'preprocess_transform': 'mix', 'normalize_with_test': False, 'N_ensemble_configurations': 32, 'softmax_temperature': None, 'combine_preprocessing': False, 'multiclass_decoder': 'permutation', 'feature_shift_decoder': True, 'differentiable_hps_as_style': False, 'seed': 0, 'max_features': 100, 'rescale_features': True, 'normalize_to_ranking': False, 'normalize_with_sqrt': False}
mean square error: 0.565
mean absolute error: 0.473
Code
classifier.classes_
array([3, 4, 5, 6, 7, 8])
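These are the labels that predict maps back onto: argmax over the probability columns gives an index, and classes_.take converts it to the original quality value. A made-up example row:

Code
import numpy as np

classes_ = np.array([3, 4, 5, 6, 7, 8])
p = np.array([[0.05, 0.05, 0.20, 0.50, 0.15, 0.05]])  # invented probabilities

idx = np.argmax(p, axis=-1)   # column 3 wins
print(classes_.take(idx))     # array([6]) - a predicted quality of 6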
Code
from tabpfn.transformer import TransformerModel

TransformerModel??
Code
TabPFNClassifier??