Prompt Training - Centroid Distance

Try to calculate the moving centroid for each class
Published

May 10, 2021

I’ve been trying to train prompts to make classifiers (see here and here). The general principle is that a language model can take some input and a trained prompt to become a task specific classifier. For example the prompt “This person feels” might get a language model to act as an emotion classifier (as it might say good or bad or happy or sad).

Coming up with a suitable prompt is hard so I want to use classical training techniques to produce one. I’ve managed to train a prompt like this.

The next problem is the appropriate selection of words to look for in the output. If I choose a token then I might choose poorly. There could be a better token to choose for a given category. Once again I want to be able to choose a token automatically.

I’m now reasonably confident that the visualization techniques that I have used are working correctly. So I can use these to determine if the method of training is working.


Problem Statement

The principal problem that I have had is that I want the outputs for different classes to be distinguishable without choosing them in advance. When I select two tokens to compare (like good and bad) they are distinguishable but I have chosen them. Instead the output of the model should vary by class and this variation should be consistent.

I first tried to train the prompt by measuring the difference between two inputs. If the inputs have the same label then the distance between them should be small. For inputs with different labels the distance should be large. Doing this let me train the model but when I visualized the results the model was not clearly distinguishing between the input labels.

The problem with the distance measurement is that the points start out randomly distributed. This means that the prompt is guided in arbitrary directions rather than a consistent direction.

The second problem is that the pure distance measurement has two loss functions - one for two inputs with the same class, and one for two inputs with different classes - and the curves of these loss functions are not the same. The loss for the same class is the distance between the points - so it grows linearly. The loss for different classes is the reciprocal of the distance - so it falls off hyperbolically (as 1/d), very steeply at small distances and barely at all at large ones. It is very unlikely that these two loss terms are in balance and so one relationship will tend to dominate the learning.
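To make the imbalance concrete, here is a minimal sketch in plain Python (it is not the training code from this post). It assumes the same-class loss is exactly the distance d and the different-class loss is exactly 1/d, and compares the gradient of each at a few distances. The function names are made up for illustration.

```python
def attraction_loss(d: float) -> float:
    """Loss for two points of the same class: the distance itself."""
    return d


def repulsion_loss(d: float) -> float:
    """Loss for two points of different classes: the reciprocal of the distance."""
    return 1.0 / d


def attraction_gradient(d: float) -> float:
    """The derivative of d with respect to d is always 1."""
    return 1.0


def repulsion_gradient(d: float) -> float:
    """The derivative of 1/d with respect to d is -1/d^2."""
    return -1.0 / (d * d)


for d in (0.1, 1.0, 10.0):
    print(
        f"d={d:>4}: attraction gradient={attraction_gradient(d):.2f}, "
        f"repulsion gradient={repulsion_gradient(d):.2f}"
    )
```

At d = 0.1 the repulsion gradient is about a hundred times larger than the attraction gradient, and at d = 10 it is about a hundred times smaller. The two only match in size at d = 1.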


Proposed Solution

To solve these problems I am going to move to modelling each class as a centroid. This means that there is one ideal point for each class in the output space. The loss for a given output is thus the distance from that output to the centroid. This makes the loss function consistent across classes.

The next part is that the centroid should be discovered. To do this I will be adjusting the centroid based on the observed outputs from the model. So while the current centroid alters the prompt, the current outputs then adjust the centroid. In this way the centroid is being trained in parallel with the prompt.

Finally the centroids must be different. After all, the trained prompt could tell the model to ignore the input. That would make the output of the model extremely consistent, with every class collapsing onto the same point. So when updating the centroids a small impulse must be provided to move the updated centroids away from other centroids.
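As a rough illustration of the idea, here is a sketch in plain Python of one possible update rule for two classes. It assumes each centroid moves a fraction of the way towards the mean of its class's outputs in the batch, and is then nudged a fixed distance away from the other centroid. The names `update_centroids`, `rate` and `impulse` are hypothetical, and this is not the update used in the training code later in the post.

```python
from typing import List

Vector = List[float]


def mean(points: List[Vector]) -> Vector:
    """Component-wise mean of a non-empty list of points."""
    return [sum(axis) / len(points) for axis in zip(*points)]


def move_towards(centroid: Vector, target: Vector, rate: float) -> Vector:
    """Move the centroid a fraction `rate` of the way towards the target."""
    return [c + rate * (t - c) for c, t in zip(centroid, target)]


def push_apart(centroid: Vector, other: Vector, impulse: float) -> Vector:
    """Move the centroid a fixed distance `impulse` directly away from `other`."""
    delta = [c - o for c, o in zip(centroid, other)]
    length = sum(x * x for x in delta) ** 0.5
    if length == 0.0:
        return centroid  # the centroids coincide, so there is no direction to move in
    return [c + impulse * x / length for c, x in zip(centroid, delta)]


def update_centroids(
    centroids: List[Vector],
    outputs: List[Vector],
    labels: List[int],
    rate: float = 0.1,
    impulse: float = 0.01,
) -> List[Vector]:
    """One update of both centroids (two classes only)."""
    attracted = []
    for class_id, centroid in enumerate(centroids):
        members = [o for o, label in zip(outputs, labels) if label == class_id]
        if members:
            centroid = move_towards(centroid, mean(members), rate)
        attracted.append(centroid)
    # the repulsion is applied to both centroids from their attracted positions
    return [
        push_apart(attracted[0], attracted[1], impulse),
        push_apart(attracted[1], attracted[0], impulse),
    ]


print(update_centroids(
    centroids=[[0.0, 0.0], [1.0, 0.0]],
    outputs=[[0.0, 2.0], [1.0, 2.0]],
    labels=[0, 1],
    rate=0.5,
    impulse=0.1,
))
```

In this toy run both centroids move halfway up towards their class's output and then step 0.1 further apart along the line between them.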


Centroid Movement as an Analogy of Optimizers

Optimizers move the weights of the model. Moving the centroid is like this.

A pure movement of the centroid based on the current outputs would be equivalent to SGD. So perhaps we can incorporate some of the advances that have been made in optimizer design? Momentum, dampening, etc.
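To show what I mean, here is a small sketch in plain Python comparing a plain SGD-style centroid step with one that carries momentum. The momentum rule used here (velocity = momentum * velocity + gradient, then position -= lr * velocity) is the usual one with no dampening. The class and function names are made up for illustration.

```python
from typing import List

Vector = List[float]


def sgd_step(centroid: Vector, gradient: Vector, lr: float) -> Vector:
    """Plain update: move against the gradient, scaled by the learning rate."""
    return [c - lr * g for c, g in zip(centroid, gradient)]


class MomentumCentroid:
    """A centroid whose movement accumulates velocity across updates."""

    def __init__(self, position: Vector, lr: float = 0.1, momentum: float = 0.9) -> None:
        self.position = list(position)
        self.velocity = [0.0] * len(position)
        self.lr = lr
        self.momentum = momentum

    def step(self, gradient: Vector) -> None:
        # the velocity remembers the direction of previous updates
        self.velocity = [
            self.momentum * v + g for v, g in zip(self.velocity, gradient)
        ]
        self.position = [
            p - self.lr * v for p, v in zip(self.position, self.velocity)
        ]


plain = [0.0]
with_momentum = MomentumCentroid([0.0], lr=0.1, momentum=0.9)
for _ in range(2):
    # a constant gradient, as if the class outputs kept lying in the same direction
    plain = sgd_step(plain, [1.0], lr=0.1)
    with_momentum.step([1.0])

print(plain, with_momentum.position)
```

With a consistent gradient the momentum version travels further in the same number of steps: after two updates the plain centroid has moved 0.2 while the momentum one has moved 0.29.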

The next thing is that the movement of the centroids apart is like regularization.

Can the centroids be treated as parameters of the model and trained with the same optimizer? I believe so. The loss could incorporate the distance between the centroids as a factor as well and that could use the optimizer for regularization.


Loss Function

This all means that the core problem is the definition of the loss function. We have already defined the loss over a batch of points as:

\[loss = \sum_{n \in batch} distance(point_n, centroid_{C_n})\]

where \(C_n\) is the class of point \(n\).

If given the opportunity to optimize both \(point_n\) and \(centroid_{C_n}\) then the optimizer will act to bring them together. The rate at which it moves the centroid relative to the points can be adjusted by placing them in separate groups and having different learning rates for each group.
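Here is a toy sketch of that effect in plain Python, using a single one-dimensional point and centroid and the loss 0.5 * (point - centroid)^2 (squared rather than plain distance, to keep the gradient simple). Each value gets its own learning rate. The function name and the numbers are illustrative assumptions. In PyTorch the equivalent would be passing the optimizer a list of parameter groups, each with its own `lr`.

```python
from typing import Tuple


def train_together(
    point: float,
    centroid: float,
    lr_point: float,
    lr_centroid: float,
    steps: int,
) -> Tuple[float, float]:
    """Gradient descent on 0.5 * (point - centroid)^2 with one learning rate per value."""
    for _ in range(steps):
        difference = point - centroid
        # the gradient for the point is +difference, for the centroid it is -difference
        point, centroid = (
            point - lr_point * difference,
            centroid + lr_centroid * difference,
        )
    return point, centroid


point, centroid = train_together(
    point=1.0, centroid=0.0, lr_point=0.1, lr_centroid=0.01, steps=500
)
print(f"point={point:.4f}, centroid={centroid:.4f}")
```

The two values meet at about 0.09, so almost all of the movement was done by the point. Swapping the two learning rates makes them meet at about 0.91 instead, with the centroid doing most of the travelling.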

The loss then just needs an extra term related to the distance between the centroids.
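For the two class case, one way to write the combined loss is sketched below. It assumes the extra term is the reciprocal of the distance between the two centroids, and that the per-point distances are averaged over the batch rather than summed. The weights \(\alpha\) and \(\beta\) are symbols introduced here for illustration. They play the role of the `distance_factor` and `repulsion_factor` arguments in the code that follows.

```latex
\[
loss = \alpha \cdot \frac{1}{|batch|} \sum_{n \in batch} distance(point_n, centroid_{C_n})
     + \beta \cdot \frac{1}{distance(centroid_0, centroid_1)}
\]
```

The first term shrinks as the outputs approach their centroids, while the second term grows without bound as the two centroids approach each other.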

Code
import torch

def centroid_loss_fn(
    output: torch.Tensor,
    labels: torch.Tensor,
    centroids: torch.nn.Parameter,
    distance_factor: float = 1,
    repulsion_factor: float = 1,
) -> torch.Tensor:
    # only supports two classes for the moment

    targets = centroids[labels]
    distance = torch.pairwise_distance(output, targets).mean()
    repulsion = 1 / torch.pairwise_distance(centroids[0][None, :], centroids[1][None, :]).sum()
    
    return (distance_factor * distance) + (repulsion_factor * repulsion)
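To check the arithmetic by hand, here is a plain Python restatement of the same calculation on two-dimensional toy data. It is only a sketch: `centroid_loss_plain` is a made-up name, and unlike `torch.pairwise_distance` it does not add the small `eps` to the difference before taking the norm, so the PyTorch values will differ very slightly.

```python
import math
from typing import List, Sequence, Tuple

Vector = Tuple[float, ...]


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Straight-line distance between two points."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def centroid_loss_plain(
    outputs: List[Vector],
    labels: List[int],
    centroids: List[Vector],
    distance_factor: float = 1.0,
    repulsion_factor: float = 1.0,
) -> float:
    """Mean distance from each output to its class centroid, plus the
    reciprocal of the distance between the two centroids."""
    distance = sum(
        euclidean(output, centroids[label])
        for output, label in zip(outputs, labels)
    ) / len(outputs)
    repulsion = 1.0 / euclidean(centroids[0], centroids[1])
    return distance_factor * distance + repulsion_factor * repulsion


outputs = [(0.0, 0.0), (3.0, 4.0)]
labels = [0, 1]
centroids = [(0.0, 0.0), (0.0, 4.0)]

# distances to the centroids are 0 and 3 (mean 1.5), the centroids are 4 apart
print(centroid_loss_plain(outputs, labels, centroids))  # 1.75
print(centroid_loss_plain(outputs, labels, centroids, distance_factor=100))  # 150.25
```

With `distance_factor=100` the attraction term completely dominates this toy example, which is worth keeping in mind when reading the loss values reported during training.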

Data Loader, Training Loop et al

Now we need a chunk of code to load the data appropriately. I’ve taken to collapsing this as these are quite extensive.

Code
#collapse

from typing import Dict, Iterator, Optional, Tuple, Union

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Past = Tuple[Tuple[torch.Tensor, ...], ...]
TextBatch = Dict[str, torch.Tensor]
PastBatch = Dict[str, Union[torch.Tensor, Past]]


class TextDataloader:
    """Provides a dataloader over a text dataframe"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        self.tokenizer = tokenizer
        self.df = df
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[TextBatch]:
        """Returns an iterator that returns batches.
        The final batch can be a partial batch."""
        if self.shuffle:
            df = self.df.sample(frac=1).reset_index(drop=True)
        else:
            df = self.df
        batch_size = self.batch_size

        for i in range(len(self)):
            start = i * batch_size
            end = start + batch_size
            yield self.to_batch(df[start:end])

    def __len__(self) -> int:
        """Returns the total number of batches that can be returned."""
        full_batches = len(self.df) // self.batch_size
        if len(self.df) % self.batch_size:
            return full_batches + 1
        return full_batches

    def to_batch(self, rows: pd.DataFrame) -> TextBatch:
        """Converts the rows into a batch"""
        tokens = self.tokenizer(
            rows.text.tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        labels = torch.tensor(rows.label.tolist(), dtype=torch.long, device=self.device)
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "labels": labels,
        }


class PastDataloader(TextDataloader):  # pylint: disable=too-few-public-methods
    """Provides a dataloader which converts the text into past tensors"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        label_map: Optional[Dict[str, int]] = None,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        if label_map:
            df = df.copy()
            df["label"] = df.label.map(label_map)
        super().__init__(
            df=df,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            device=device,
            shuffle=shuffle,
        )
        model.to(device)
        self.model = model

    @torch.no_grad()
    def to_batch(self, rows: pd.DataFrame) -> PastBatch:
        batch = super().to_batch(rows)
        past_key_values = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask", None),
        ).past_key_values
        return {
            "past_key_values": past_key_values,
            "attention_mask": batch["attention_mask"],
            "labels": batch["labels"],
        }

Then the training loop. We are operating over the raw model output (the final hidden state at the last prompt position) instead of the token predictions.

This has some subtlety as it creates the centroids and returns them after training. I’m also making quite a few adjustments to this during this post so I need to add flexibility for that.

Code
#collapse

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Union

import torch
import numpy as np
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

LossFunction = Callable[[torch.Tensor, torch.Tensor, torch.nn.Parameter], torch.Tensor]
OptimizerFactory = Callable[[torch.nn.Parameter, torch.nn.Parameter], torch.optim.Optimizer]

class CentroidManager(ABC):
    @abstractmethod
    def make(
        self,
        *,
        dl: PastDataloader,
        model: AutoModelForCausalLM,
        prompt: torch.nn.Parameter,
        prompt_attention: torch.Tensor,
        device: torch.device,
    ) -> torch.nn.Parameter:
        """Make the initial centroids"""
        pass

    def update(
        self,
        *,
        centroids: torch.nn.Parameter,
        inputs: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        """In place update to the centroids"""
        pass

class CentroidRandomStart(CentroidManager):
    """Just randomly initializes the centroids"""
    @torch.no_grad()
    def make(
        self,
        *,
        dl: PastDataloader,
        model: AutoModelForCausalLM,
        prompt: torch.nn.Parameter,
        prompt_attention: torch.Tensor,
        device: torch.device,
    ) -> torch.nn.Parameter:
        return torch.nn.Parameter(
            torch.rand(2, model.config.n_embd, device=device)
        )

class CentroidInitialOutput(CentroidManager):
    """Take a single item for each class and use the output as the centroid position"""
    @torch.no_grad()
    def make(
        self,
        *,
        dl: PastDataloader,
        model: AutoModelForCausalLM,
        prompt: torch.nn.Parameter,
        prompt_attention: torch.Tensor,
        device: torch.device,
    ) -> torch.nn.Parameter:
        for batch in dl:
            labels = batch["labels"]
            if 0 not in labels or 1 not in labels:
                continue

            zero_index = (labels == 0).long().argmax().item()
            one_index = (labels == 1).long().argmax().item()
            indexes = [zero_index, one_index]

            logits = _get_output_with_past(
                model=model,
                prompt=prompt,
                attention_mask=prompt_attention,
                past=batch["past_key_values"],
                past_attention_mask=batch["attention_mask"],
            )
            return torch.nn.Parameter(
                logits[indexes].clone()
            )
        raise AssertionError("No batch with both labels found")


@dataclass
class TrainedPrompt:
    prompt: torch.Tensor
    centroids: torch.Tensor

    @staticmethod
    def load(folder: Path) -> TrainedPrompt:
        assert folder.exists()
        return TrainedPrompt(
            prompt=torch.load(folder / "prompt.pt"),
            centroids=torch.load(folder / "centroids.pt"),
        )

    def save(self, folder: Path) -> None:
        folder.mkdir(parents=True, exist_ok=True)
        torch.save(self.prompt, folder / "prompt.pt")
        torch.save(self.centroids, folder / "centroids.pt")

@dataclass
class TrainingStatistics:
    batches: np.ndarray
    labels: np.ndarray
    centroids: np.ndarray


def _make_optimizer(prompt: torch.nn.Parameter, centroids: torch.nn.Parameter) -> torch.optim.Optimizer:
    return torch.optim.Adam([prompt, centroids], lr=1e-3)

def train(
    *,
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    epochs: int,
    loss_fn: LossFunction,
    centroid_manager: CentroidManager,
    optimizer_factory: OptimizerFactory = _make_optimizer,
) -> Tuple[TrainedPrompt, TrainingStatistics]:
    """Train the prompt"""
    prompt, prompt_attention = _make_prompt(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        device=dl.device,
    )
    centroids = centroid_manager.make(
        dl=dl,
        model=model,
        prompt=prompt,
        prompt_attention=prompt_attention,
        device=dl.device
    )
    optimizer = optimizer_factory(prompt, centroids)

    total_loss = 0.0
    current_loss = 0.0
    stats_batches = []
    stats_labels = []
    stats_centroids = []
    bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] - {postfix[0]:>8.4f}"

    with tqdm(
        range(epochs), leave=False, bar_format=bar_format, postfix=[total_loss]
    ) as bar:
        for _epoch in bar:
            with tqdm(
                dl, leave=False, bar_format=bar_format, postfix=[current_loss]
            ) as epoch_bar:
                for batch in epoch_bar:
                    current_loss, current_batch, current_labels, current_centroids = _process(
                        batch=batch,
                        model=model,
                        optimizer=optimizer,
                        prompt=prompt,
                        prompt_attention=prompt_attention,
                        centroids=centroids,
                        loss_fn=loss_fn,
                        centroid_manager=centroid_manager,
                    )
                    total_loss += current_loss
                    if current_batch.shape[0] == dl.batch_size:
                        stats_batches.append(current_batch)
                        stats_labels.append(current_labels)
                        stats_centroids.append(current_centroids)
                    epoch_bar.postfix[0] = current_loss

            average_loss = total_loss / len(dl)
            bar.postfix[0] = average_loss
            print(f"Average loss: {average_loss:0.4f}")
            total_loss = 0.0

    return (
        TrainedPrompt(prompt=prompt.data, centroids=centroids.data),
        TrainingStatistics(
            batches=np.concatenate([
                batch[None, :]
                for batch in stats_batches
            ]),
            labels=np.concatenate([
                labels[None, :]
                for labels in stats_labels
            ]),
            centroids=np.concatenate([
                centroid[None, :]
                for centroid in stats_centroids
            ])
        )
    )


def _make_prompt(
    *,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    device: torch.device,
) -> Tuple[torch.nn.Parameter, torch.Tensor]:
    """Generate the prompt by randomly choosing tokens and then converting to embeddings"""
    prompt_indexes = torch.randint(
        size=(prompt_tokens,), low=0, high=tokenizer.vocab_size, device=device
    )
    prompt_attention = torch.ones(
        size=(1, prompt_tokens), dtype=torch.long, device=device
    )
    prompt = torch.nn.Parameter(
        model.transformer.wte(prompt_indexes).clone()[None, :, :]
    )
    return prompt, prompt_attention

def _process(
    *,
    batch: Dict[str, Union[torch.Tensor, Past]],
    model: AutoModelForCausalLM,
    optimizer: torch.optim.Optimizer,
    prompt: torch.nn.Parameter,
    prompt_attention: torch.Tensor,
    centroids: torch.nn.Parameter, 
    loss_fn: LossFunction,
    centroid_manager: CentroidManager,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    optimizer.zero_grad()
    
    current_centroids = (
        centroids.clone()
            .detach()
            .cpu()
            .numpy()
    )

    logits = _get_output_with_past(
        model=model,
        prompt=prompt,
        attention_mask=prompt_attention,
        past=batch["past_key_values"],
        past_attention_mask=batch["attention_mask"],
    )
    labels = batch["labels"]
    loss = loss_fn(logits, labels, centroids)

    loss.backward()
    optimizer.step()

    centroid_manager.update(
        centroids=centroids,
        inputs=logits,
        labels=labels,
    )
    
    current_logits = (
        logits.detach()
            .cpu()
            .numpy()
    )
    current_labels = (
        labels.detach()
            .cpu()
            .numpy()
    )

    return loss.item(), current_logits, current_labels, current_centroids


def _get_output_with_past(
    *,
    model: AutoModelForCausalLM,
    prompt: torch.nn.Parameter,
    attention_mask: torch.Tensor,
    past: Past,
    past_attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Get the predictions for the next token after the prompt"""
    # concatenate the past attention with the prompt attention
    batch_size = past_attention_mask.shape[0]
    attention_mask = attention_mask.repeat_interleave(batch_size, dim=0)
    attention_mask = torch.cat([past_attention_mask, attention_mask], dim=-1)

    # expand the prompt to match the batch size
    input_ids = prompt.repeat_interleave(batch_size, dim=0)

    state = model.transformer(
        inputs_embeds=input_ids,
        attention_mask=attention_mask,
        past_key_values=past,
    ).last_hidden_state
    return state[:, -1]

Now a way to measure the accuracy is required. Here we can measure the distance from each output to the two centroids and predict the class of the nearer one.

Code
#collapse

from dataclasses import dataclass
from sklearn.metrics import classification_report
from tqdm.auto import tqdm
import numpy as np

@dataclass
class LabelledOutputs:
    outputs: np.ndarray
    labels: np.ndarray
    predictions: np.ndarray

def generate_outputs(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: TrainedPrompt,
) -> LabelledOutputs:
    raw_outputs = []
    raw_predictions = []
    for current_outputs, current_predictions in iterate_outputs(
        dl=dl, model=model, prompt=prompt.prompt, centroids=prompt.centroids
    ):
        raw_outputs.append(
            current_outputs.cpu().numpy(),
        )
        raw_predictions.append(
            current_predictions.cpu().numpy(),
        )

    outputs = np.concatenate(raw_outputs)
    labels = dl.df.label.to_numpy()
    predictions = np.concatenate(raw_predictions)

    return LabelledOutputs(
        outputs=outputs,
        labels=labels,
        predictions=predictions,
    )

@torch.no_grad()
def iterate_outputs(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor,
    centroids: torch.Tensor,
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    prompt_attention = torch.ones((1, prompt.shape[1]), dtype=torch.long, device=dl.device)

    for batch in tqdm(dl):
        output = _get_output_with_past(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
        )
        predicted_labels = predictions(output, centroids)
        yield output, predicted_labels

@torch.no_grad()
def predictions(
    output: torch.Tensor,
    centroids: torch.nn.Parameter,
) -> torch.Tensor:
    distances = torch.cat([
        torch.pairwise_distance(
            output[idx],
            centroids
        )[None, :]
        for idx in range(output.shape[0])
    ])
    return distances.argmin(dim=1)

@torch.no_grad()
def accuracy(outputs: LabelledOutputs) -> None:
    print(classification_report(
        y_true=outputs.labels,
        y_pred=outputs.predictions,
        target_names=["bad", "good"],
        zero_division=0
    ))
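The prediction rule is easier to see without the tensor handling. Here is a plain Python sketch of the same nearest-centroid idea on toy two-dimensional data. The function names are made up for illustration. In PyTorch I believe `torch.cdist(output, centroids).argmin(dim=1)` would compute the same thing in one vectorized call, though I have not swapped it into the code above.

```python
import math
from typing import List, Sequence


def nearest_centroid(output: Sequence[float], centroids: List[Sequence[float]]) -> int:
    """Return the index of the centroid closest to the output."""
    distances = [math.dist(output, centroid) for centroid in centroids]
    return distances.index(min(distances))


def predict_all(outputs: List[Sequence[float]], centroids: List[Sequence[float]]) -> List[int]:
    """Predict a class for every output."""
    return [nearest_centroid(output, centroids) for output in outputs]


centroids = [(0.0, 0.0), (1.0, 0.0)]
outputs = [(0.1, 0.2), (0.9, -0.3), (0.4, 5.0)]
print(predict_all(outputs, centroids))  # [0, 1, 0]
```

Note that the last point is far from both centroids but is still assigned a class. This approach has no notion of an output being too far away to classify.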

Another thing I am going to introduce in this notebook is animating the training progress. To do this I am leaning heavily on this Stack Overflow post. The basic process is to map the centroids and batch points to the same 2 dimensions using PCA:

Code
#collapse

from __future__ import annotations
from dataclasses import dataclass, replace

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

CENTROID_COLORS = np.array([
    [1., 0, 0],
    [0, 0, 1.]
])
POINT_COLORS = np.array([
    [1., 0.5, 0.5],
    [0.5, 0.5, 1.]
])

@dataclass
class Points:
    points: np.ndarray
    colors: np.ndarray
    sizes: np.ndarray

    @staticmethod
    def make(
        points: np.ndarray,
        labels: np.ndarray,
        colors: np.ndarray = POINT_COLORS,
        size: float = 10.
    ) -> Points:
        sizes = np.ones((points.shape[0])) * size
        return Points(
            points=points,
            colors=colors[labels],
            sizes=sizes,
        )

    def decay(self) -> None:
        self.colors += (1 - self.colors) * 0.1
        self.sizes *= 0.9

    @staticmethod
    def combine(*points: Points) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        all_points = np.concatenate([point.points for point in points])
        all_colors = np.concatenate([point.colors for point in points])
        all_sizes = np.concatenate([point.sizes for point in points])
        return all_points, all_colors, all_sizes


class AnimatedScatter(object):
    """An animated scatter plot using matplotlib.animation.FuncAnimation."""

    def __init__(self, centroids: np.ndarray, batches: np.ndarray, labels: np.ndarray) -> None:
        self.centroids = [
            Points.make(
                points=centroid,
                labels=np.array([0, 1]),
                colors=CENTROID_COLORS,
                size=100.
            )
            for centroid in centroids
        ]
        self.batches = [
            Points.make(points=batch, labels=label)
            for batch, label in zip(batches, labels)
        ]
        
        all_points, *_ = Points.combine(*self.centroids, *self.batches)
        self.axis = [
            all_points[:, 0].min() * 1.1, all_points[:, 0].max() * 1.1,
            all_points[:, 1].min() * 1.1, all_points[:, 1].max() * 1.1,
        ]

        # Setup the figure and axes...
        self.fig, self.ax = plt.subplots()
        # Then setup FuncAnimation.
        self.ani = animation.FuncAnimation(
            self.fig,
            self.update,
            interval=50,
            init_func=self.setup_plot,
            blit=False
        )

    def setup_plot(self):
        """Initial drawing of the scatter plot."""
        x, y = self.batches[0].points.T
        self.scat = self.ax.scatter(
            x,
            y,
        )
        self.ax.axis(self.axis)
        # For FuncAnimation's sake, we need to return the artist we'll be using
        # Note that it expects a sequence of artists, thus the trailing comma.
        return self.scat,

    def update(self, i):
        """Update the scatter plot."""
        all_points, all_colors, all_sizes = Points.combine(
            *self.batches[:i],
            self.centroids[i]
        )

        self.scat.set_offsets(all_points)
        self.scat.set_color(all_colors)
        self.scat.set_sizes(all_sizes)

        for batch in self.batches[:i]:
            batch.decay()

        # We need to return the updated artist for FuncAnimation to draw..
        # Note that it expects a sequence of artists, thus the trailing comma.
        return self.scat,
Code
#collapse

from sklearn.decomposition import PCA
import numpy as np
import pandas as pd

def fit(outputs: LabelledOutputs) -> PCA:
    pca = PCA(n_components=2)
    pca.fit(outputs.outputs)
    return pca

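def animate(pca: PCA, training_statistics: TrainingStatistics, outputs: LabelledOutputs) -> str:
    # take the sizes from the recorded statistics instead of hard-coding
    # the embedding size or relying on the global train_dataloader
    batch_size = training_statistics.batches.shape[1]
    embedding_size = training_statistics.centroids.shape[-1]
    transformed_centroids = (
        pca.transform(training_statistics.centroids.reshape(-1, embedding_size))
            .reshape(-1, 2, 2)
    )
    transformed_batches = (
        pca.transform(training_statistics.batches.reshape(-1, embedding_size))
            .reshape(-1, batch_size, 2)
    )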
    
    return AnimatedScatter(
        centroids=transformed_centroids,
        batches=transformed_batches,
        labels=training_statistics.labels
    ).ani.to_jshtml()

def visualize(pca: PCA, prompt: TrainedPrompt, outputs: LabelledOutputs) -> None:
    pca_output = pca.transform(outputs.outputs)
    output_df = pd.DataFrame(pca_output)
    output_df["label"] = outputs.labels
    output_df["color"] = output_df.label.map({0: "pink", 1: "lightblue"})
    output_df["size"] = 0.1

    pca_centroid = pca.transform(prompt.centroids.clone().detach().cpu().numpy())
    centroid_df = pd.DataFrame(pca_centroid)
    centroid_df["color"] = pd.Series(["red", "blue"])
    centroid_df["size"] = 10.

    pca_df = pd.concat([
        output_df,
        centroid_df,
    ]).reset_index(drop=True)

    pca_df.plot.scatter(x=0, y=1, c="color", s="size")

Now we can load the deep learning model and datasets.

Code
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.to("cuda")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token # needed to enable padding
Code
import pandas as pd

train_df = pd.read_parquet("/data/sentiment/imdb-movie-reviews/train.gz.parquet")
validation_df = pd.read_parquet("/data/sentiment/imdb-movie-reviews/validation.gz.parquet")
Code
BATCH_SIZE = 32
MAX_LENGTH = 1_000

train_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=train_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=True,
    label_map={"bad": 0, "good": 1},
)
validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=False,
    label_map={"bad": 0, "good": 1},
)

Now we can try training.


Training

Let’s give it a go. I’m returning essentially all of the outputs and the centroids at each batch to see if I can animate the progress of the prompt as it trains. I think it will be fun, provided it’s not too memory hungry.

Random Centroid Start

In this I am just going to randomly initialize the centroid positions and see how we get on. While the centroids might start badly, hopefully they can be adjusted into a better position as training proceeds.

Code
from functools import partial

trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=partial(centroid_loss_fn, distance_factor=100),
    centroid_manager=CentroidRandomStart(),
)
Average loss: 2848.5156
Average loss: 434.2683
Average loss: 261.2136
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/random-start"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.53      0.44      0.48     12500
        good       0.52      0.61      0.56     12500

    accuracy                           0.52     25000
   macro avg       0.52      0.52      0.52     25000
weighted avg       0.52      0.52      0.52     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

We can now look at them in a few different ways - both statistically and visually.

It is clear from the classification report that this is performing very badly. The classifications that it produces are barely better than random chance. We can use the visualizations to explore why this might be.

The first visualization is the same PCA clustering that was explored before. We can see that the classes are intermingled. The larger blue dot is the position of the “good” centroid, and the red is the “bad” centroid. While the points are clustered around them they are not separated, and the centroids themselves are very close to each other.

The animation shows the centroid positions and the training batches over time. You can see that the training batches appear more clustered than the validation ones, as they consistently lie on the diagonal. The training batches do slightly close in on the centroids, however the effect is not large. Finally the centroids themselves do not visibly move during training.

Part of the problem statement is that the centroid position is not known in advance, so the fact that the centroids do not move is a problem. A bigger problem is that the centroids are not representative of the data - the position of the centroids is nowhere near the actual data. Let’s start by moving the centroids closer to the points.


Representative Centroid Start

This time I am going to take a value for each class, pass it through the model, and use its output as the centroid starting position. Doing this ensures that the centroids start in a position that is representative of at least one entry in the training set.

Would this approach be improved if the average of all entries in the training set were used? I think that might result in both centroids being in a very similar position, as the prompt is totally untrained.

Would it be better to take a moving average of the center of each class, and remove the centroids from optimization? It might be - that is something to explore later.

Code
from functools import partial

trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=partial(centroid_loss_fn, distance_factor=100),
    centroid_manager=CentroidInitialOutput(),
)
Average loss: 1496.6118
Average loss: 963.0601
Average loss: 745.7831
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/representative-start"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.69      0.93      0.79     12500
        good       0.89      0.59      0.71     12500

    accuracy                           0.76     25000
   macro avg       0.79      0.76      0.75     25000
weighted avg       0.79      0.76      0.75     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

So this simple change has vastly improved the performance, and clearly distinguished the two centroids. This is seen in the validation data visualization which shows a strong distinction between red and blue. If anything the red zone (bad) is larger and that may contribute to the misclassifications.

There does seem to be a tiny amount of centroid movement. I wonder if the centroids would move to a better place for the classification if they were able to move more freely?

Only Train Prompt

I think that the centroid approach has value. There are either bugs with it or the implementation needs quite a bit of tuning. To that end I’m going to try to make it simpler.

One way to make it simpler is to just train the prompt, or just train the centroids. Let’s start by just training the prompt.

Code
def make_optimizer_only_prompt(prompt: torch.nn.Parameter, centroids: torch.nn.Parameter) -> torch.optim.Optimizer:
    return torch.optim.Adam([prompt], lr=1e-3)
Code
from functools import partial

trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    # no point in including centroid repulsion in loss
    loss_fn=partial(centroid_loss_fn, distance_factor=100, repulsion_factor=0),
    centroid_manager=CentroidInitialOutput(),
    optimizer_factory=make_optimizer_only_prompt,
)
Average loss: 1214.2718
Average loss: 998.2953
Average loss: 893.0654
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/only-prompt"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.75      0.87      0.80     12500
        good       0.84      0.70      0.77     12500

    accuracy                           0.79     25000
   macro avg       0.80      0.79      0.79     25000
weighted avg       0.80      0.79      0.79     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

This actually performs better than the previous run (0.79 accuracy against 0.76). I think that this shows that the position of the centroid matters a great deal. The fact that the centroid cannot move to a better position is a problem.


Boids

I was thinking about this on my walk, and my concern about the asymmetrical nature of the two distance losses (which I will call \(loss_{attraction}\) and \(loss_{repulsion}\)) may be misplaced. The repulsive force becomes overwhelming if points move too close together, so there is an enforced minimum distance between points. I’m reminded of boids, which are a simulated version of murmurations.

murmuration

The simulation wants each bird to move to the center of the group while maintaining a minimum distance from the other members of the group. It uses a very similar approach to the two proposed loss functions, and the asymmetry does not cause a problem. I wonder how hard it would be to produce a boid simulation using these techniques?

Code
def boid_loss(boids: torch.Tensor) -> torch.Tensor:
    # This compares every boid to every other boid
    # This must not compare a boid to itself
    boid_count = boids.shape[0]

    # This creates the cartesian join of every index
    indexes = torch.arange(boid_count)
    left_indexes = indexes.repeat(boid_count)
    right_indexes = indexes.repeat_interleave(boid_count)
    # This mask filters out the points where an index joins to itself
    not_same_mask = left_indexes != right_indexes

    # These are then the expanded aligned comparisons
    left = boids[left_indexes[not_same_mask]]
    right = boids[right_indexes[not_same_mask]]

    distances = torch.pairwise_distance(left, right)
    want_to_be_near_loss = distances.mean()
    want_to_be_far_loss = (1 / distances).mean()

    return want_to_be_near_loss + want_to_be_far_loss

def boid_train() -> np.ndarray:
    boids = torch.nn.Parameter(torch.rand(10, 2) * 10)
    
    stats = [boids.clone().detach().cpu().numpy()[None, :]]
    optimizer = torch.optim.Adam([boids], lr=1e-1)
    for _ in range(100):
        optimizer.zero_grad()
        loss = boid_loss(boids)
        loss.backward()
        optimizer.step()
        stats.append(boids.clone().detach().cpu().numpy()[None, :])

    return np.concatenate(stats)
Code
boid_positions = boid_train()
Code
#collapse

from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Tuple

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

@dataclass
class SimplePoint:
    points: np.ndarray
    colors: np.ndarray
    sizes: np.ndarray

    @staticmethod
    def make(
        points: np.ndarray,
        colors: np.ndarray = np.array([[.2, .2, .8]]),
        size: float = 10.
    ) -> SimplePoint:
        sizes = np.ones((points.shape[0])) * size
        return SimplePoint(
            points=points,
            colors=colors[np.zeros(points.shape[0], dtype=int)],
            sizes=sizes,
        )

    def decay(self) -> None:
        self.colors += (1 - self.colors) * 0.1
        self.sizes *= 0.9

    @staticmethod
    def combine(*points: SimplePoint) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        all_points = np.concatenate([point.points for point in points])
        all_colors = np.concatenate([point.colors for point in points])
        all_sizes = np.concatenate([point.sizes for point in points])
        return all_points, all_colors, all_sizes


class SimpleAnimation(object):
    """An animated scatter plot using matplotlib.animation.FuncAnimation."""

    def __init__(self, batches: np.ndarray) -> None:
        self.batches = [
            SimplePoint.make(points=batch)
            for batch in batches
        ]
        self.axis = [
            batches[:, :, 0].min(), batches[:, :, 0].max(),
            batches[:, :, 1].min(), batches[:, :, 1].max(),
        ]

        # Setup the figure and axes...
        self.fig, self.ax = plt.subplots()
        # Then setup FuncAnimation.
        self.ani = animation.FuncAnimation(
            self.fig,
            self.update,
            interval=50,
            init_func=self.setup_plot,
            blit=False
        )

    def setup_plot(self):
        """Initial drawing of the scatter plot."""
        x, y = self.batches[0].points[:, 0], self.batches[0].points[:, 1]
        self.scat = self.ax.scatter(
            x,
            y,
        )
        self.ax.axis(self.axis)
        # For FuncAnimation's sake, we need to return the artist we'll be using
        # Note that it expects a sequence of artists, thus the trailing comma.
        return self.scat,

    def update(self, i):
        """Update the scatter plot."""
        all_points, all_colors, all_sizes = SimplePoint.combine(
            *self.batches[:i+1]
        )

        self.scat.set_offsets(all_points)
        self.scat.set_color(all_colors)
        self.scat.set_sizes(all_sizes)

        for batch in self.batches[:i+1]:
            batch.decay()

        # We need to return the updated artist for FuncAnimation to draw..
        # Note that it expects a sequence of artists, thus the trailing comma.
        return self.scat,
Code
#hide_output
from IPython.display import HTML

html = SimpleAnimation(batches=boid_positions).ani.to_jshtml()
# HTML(html)

So this works well. The learning rate for the boid positions needs to be much higher to get them moving like this (1e-1 here, against the 1e-3 used for the prompt). That’s something to tune when training the model.
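A quick way to see why the asymmetry is harmless is to look at just two points a distance \(d\) apart. A minimal sketch, assuming the pair loss is \(a \cdot d + r / d\) with attraction weight \(a\) and repulsion weight \(r\) (these names and the helper functions are illustrative, not from the notebook): the gradient \(a - r / d^2\) is zero at \(d = \sqrt{r / a}\), so gradient descent settles at a finite separation whether the points start too close or too far apart.

```python
import math


def pair_loss_gradient(d: float, a: float = 1.0, r: float = 1.0) -> float:
    """Derivative of a * d + r / d with respect to the distance d."""
    return a - r / (d * d)


def equilibrium_distance(
    d: float, a: float = 1.0, r: float = 1.0, lr: float = 0.01, steps: int = 5000
) -> float:
    """Plain gradient descent on the distance itself (an illustrative helper)."""
    for _ in range(steps):
        d -= lr * pair_loss_gradient(d, a, r)
    return d


# start too close, start too far, and start too far with stronger repulsion
close_start = equilibrium_distance(0.2)
far_start = equilibrium_distance(5.0)
scaled = equilibrium_distance(5.0, a=1.0, r=4.0)

print(close_start, far_start, scaled, math.sqrt(4.0 / 1.0))
```

With more than two boids the pairwise terms compete, so not every pair can sit at exactly this distance, but the same balance is what stops the flock from collapsing to a point.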

I’m going to try moving to this approach.


Attract Me, Repel You

This time I’m going to create the centroid equivalent of the boids training above. Each point will be attracted to the centroid for its class and repelled from the centroid of the other class.

Code
import torch

def attract_repel_loss_fn(
    output: torch.Tensor,
    labels: torch.Tensor,
    centroids: torch.nn.Parameter,
    attraction_factor: float = 1.,
    repulsion_factor: float = 1.,
) -> torch.Tensor:
    # only supports two classes for the moment
    
    attraction_targets = centroids[labels]
    repulsion_targets = centroids[(labels == 0).long()] # assumption: label is 0 or 1, this flips it
    
    attraction_distance = torch.pairwise_distance(output, attraction_targets)
    repulsion_distance = torch.pairwise_distance(output, repulsion_targets)

    attraction_loss = attraction_factor * attraction_distance.mean()
    repulsion_loss = (repulsion_factor / repulsion_distance).mean()
    
    return attraction_loss + repulsion_loss
Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=attract_repel_loss_fn,
    centroid_manager=CentroidInitialOutput(),
)
Average loss: 8.6588
Average loss: 5.5559
Average loss: 4.3190
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/attract-repel"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.60      0.59      0.60     12500
        good       0.60      0.61      0.60     12500

    accuracy                           0.60     25000
   macro avg       0.60      0.60      0.60     25000
weighted avg       0.60      0.60      0.60     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)


Variable Learning Rate

The centroids really do not want to move. It’s very unlikely that the starting position for them is the best possible position for them. I’m going to try to dislodge them by increasing the learning rate for the centroid parameters.

Code
def make_optimizer_variable_lr(prompt: torch.nn.Parameter, centroids: torch.nn.Parameter) -> torch.optim.Optimizer:
    return torch.optim.Adam([
        {"params": prompt, "lr": 1e-3},
        {"params": centroids, "lr": 1e-1}
    ])
Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=attract_repel_loss_fn,
    optimizer_factory=make_optimizer_variable_lr,
    centroid_manager=CentroidInitialOutput(),
)
Average loss: 6.9556
Average loss: 3.0133
Average loss: 2.4371
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/attract-repel-varying-lr"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.55      0.53      0.54     12500
        good       0.55      0.56      0.55     12500

    accuracy                           0.55     25000
   macro avg       0.55      0.55      0.55     25000
weighted avg       0.55      0.55      0.55     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

Well that’s pretty bad. They do move but it’s more of an oscillation and the classes don’t neatly separate. If anything the centroids get closer together, making it harder to separate the classes.


Cross Entropy Centroid Loss

The distance from an output to the two centroids is the main metric that we have. If I make the distance negative then the closest centroid will have the largest value. Then cross entropy loss can be used, as that takes a set of values and a target index, where the index of the highest value is the prediction.

I’m pretty pleased with this revelation. It handles the distance to both centroids in a consistent way and it uses a well established loss metric. Tuning the centroid position may still be required.

Code
import torch

def cross_entropy_centroid_loss(
    output: torch.Tensor,
    labels: torch.Tensor,
    centroids: torch.nn.Parameter,
) -> torch.Tensor:
    # pairwise distance works on 2d tensors so have to iterate
    # see if cdist works?
    distances = torch.cat([
        torch.pairwise_distance(
            output[idx],
            centroids
        )[None, :]
        for idx in range(output.shape[0])
    ])
    return torch.nn.functional.cross_entropy(
        -distances,
        labels
    )
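On the `cdist` question in the comment above: `torch.cdist` takes a `(batch, features)` tensor and a `(classes, features)` tensor and returns the full `(batch, classes)` matrix of Euclidean distances, so it looks like it could replace the loop. A minimal sketch, assuming `output` is 2-D; the names `cross_entropy_centroid_loss_cdist` and `predict` are hypothetical and this has not been run against the training here. `torch.pairwise_distance` adds a small `eps` to the difference before taking the norm, so the two versions can differ very slightly.

```python
import torch
import torch.nn.functional as F


def cross_entropy_centroid_loss_cdist(
    output: torch.Tensor,     # (batch, features)
    labels: torch.Tensor,     # (batch,) integer class indexes
    centroids: torch.Tensor,  # (classes, features)
) -> torch.Tensor:
    # (batch, classes) matrix of Euclidean distances, no Python loop
    distances = torch.cdist(output, centroids)
    # negate so that the nearest centroid has the largest logit
    return F.cross_entropy(-distances, labels)


def predict(output: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    # the predicted class is the index of the nearest centroid
    return torch.cdist(output, centroids).argmin(dim=1)


# toy data: two centroids and three points in a 4 dimensional space
centroids = torch.tensor([[0.0, 0.0, 0.0, 0.0], [10.0, 10.0, 10.0, 10.0]])
output = torch.tensor(
    [[1.0, 0.0, 0.0, 0.0], [9.0, 10.0, 10.0, 10.0], [0.0, 1.0, 0.0, 0.0]]
)
labels = torch.tensor([0, 1, 0])

loss = cross_entropy_centroid_loss_cdist(output, labels, centroids)
predictions = predict(output, centroids)
```

Because the centroids are just rows of a matrix, this form also extends to more than two classes without any label flipping.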
Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cross_entropy_centroid_loss,
    centroid_manager=CentroidInitialOutput(),
)
Average loss: 0.6868
Average loss: 0.4685
Average loss: 0.3410
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/cross-entropy"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.84      0.91      0.88     12500
        good       0.91      0.83      0.87     12500

    accuracy                           0.87     25000
   macro avg       0.87      0.87      0.87     25000
weighted avg       0.87      0.87      0.87     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

This is a really solid result. The metrics are the best yet and it still leaves an obvious avenue for improvement. When I look at the visualization of the points and the centroids I can see that the centroids are not even near the points.

I feel that moving to calculated centroids would improve things substantially. I’m going to remove the centroids from training and start updating them manually.


Calculated Centroids

I’m going to be using a few different inputs to the centroid calculation. The first is that the centroid will be the average position of all of the points in the class. Then there is some kind of momentum that keeps it near the previous centroid position. Finally, adding repulsion from the other centroids would be good.

Code
import torch

def make_optimizer_only_prompt(prompt: torch.nn.Parameter, centroids: torch.nn.Parameter) -> torch.optim.Optimizer:
    return torch.optim.Adam([prompt], lr=1e-3)

@dataclass
class CalculatedCentroid(CentroidInitialOutput):
    repulsion: float = 0.1
    momentum: float = 0.9

    @torch.no_grad()
    def update(
        self,
        *,
        centroids: torch.nn.Parameter,
        inputs: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        """In place update to the centroids"""
        updated_centroids = {}
        for label in labels.unique():
            center = self.centroid_center(
                outputs=inputs, labels=labels, label=label
            )
            repulsion = self.centroid_repulsion(
                centroids=centroids, index=label, factor=self.repulsion
            )
            updated_centroids[label] = center + repulsion

        for label, centroid in updated_centroids.items():
            centroids.data[label] = self.centroid_momentum(
                old=centroids.data[label],
                new=centroid,
                factor=self.momentum
            )

    def centroid_center(self, outputs: torch.Tensor, labels: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        return outputs[labels == label].mean(dim=0)

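    def centroid_repulsion(self, centroids: torch.Tensor, index: int, factor: float) -> torch.Tensor:
        indexes = [idx for idx in range(centroids.shape[0]) if idx != index]
        # note: this is the mean position of the other centroids measured from
        # the origin, not an offset relative to the centroid at `index`
        direction = centroids[indexes].mean(dim=0)
        return factor * direction / torch.norm(direction)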

    def centroid_momentum(self, old: torch.Tensor, new: torch.Tensor, factor: float) -> torch.Tensor:
        return (old * factor) + (new * (1 - factor))
Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cross_entropy_centroid_loss,
    centroid_manager=CalculatedCentroid(repulsion=0.0),
    optimizer_factory=make_optimizer_only_prompt,
)
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/calculated-centroid-no-repulsion"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.53      0.86      0.66     12500
        good       0.64      0.25      0.36     12500

    accuracy                           0.56     25000
   macro avg       0.59      0.56      0.51     25000
weighted avg       0.59      0.56      0.51     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

Without repulsion the two centroids have moved together because the centers of the two sets of outputs overlap. I think this is because the outputs have not yet been separated by the prompt, and then the centroid movements make it harder to separate them. Let’s try with a repulsive force between the centroids.

Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cross_entropy_centroid_loss,
    centroid_manager=CalculatedCentroid(momentum=0.99, repulsion=1.),
    optimizer_factory=make_optimizer_only_prompt,
)
Average loss: 0.9006
Average loss: 0.6347
Average loss: 0.5345
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/calculated-centroid-with-repulsion"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.90      0.58      0.71     12500
        good       0.69      0.94      0.79     12500

    accuracy                           0.76     25000
   macro avg       0.80      0.76      0.75     25000
weighted avg       0.80      0.76      0.75     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

So repulsion helps, but the performance isn’t really on a par with the pure cross entropy loss training from before. I’ve realised that the current implementation doesn’t really do momentum. The momentum parameter is really more of a lethargy (a reluctance to move). Maybe actually adding in momentum will help?
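The difference between the two is easy to see in one dimension. A toy sketch, not the notebook’s code (`lethargy_step` and `momentum_step` are names invented here): an exponential moving average blends the old position with the target and can never pass it, while true momentum keeps a velocity that can carry the point beyond the target before it swings back.

```python
from typing import Tuple


def lethargy_step(position: float, target: float, factor: float) -> float:
    """Exponential moving average: blend the old position with the target."""
    return position * factor + target * (1 - factor)


def momentum_step(
    position: float, velocity: float, target: float, factor: float
) -> Tuple[float, float]:
    """Keep a decaying velocity and add the pull toward the target to it."""
    velocity = velocity * factor + (target - position) * (1 - factor)
    return position + velocity, velocity


target, factor = 10.0, 0.9
lethargy_path, momentum_path = [0.0], [0.0]
velocity = 0.0
for _ in range(100):
    lethargy_path.append(lethargy_step(lethargy_path[-1], target, factor))
    position, velocity = momentum_step(momentum_path[-1], velocity, target, factor)
    momentum_path.append(position)

print(max(lethargy_path), max(momentum_path))
```

In this toy the lethargic point creeps up to the target from one side, while the point with momentum overshoots it and then oscillates with shrinking amplitude.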

Centroid Momentum

Code
from dataclasses import dataclass, field
from typing import Dict

import torch

ZERO = torch.zeros(768, device="cuda")

def make_optimizer_only_prompt(prompt: torch.nn.Parameter, centroids: torch.nn.Parameter) -> torch.optim.Optimizer:
    return torch.optim.Adam([prompt], lr=1e-3)

@dataclass
class CalculatedCentroid(CentroidInitialOutput):
    repulsion: float = 0.1
    momentum: float = 0.9

    impulse: Dict[int, torch.Tensor] = field(default_factory=dict)

    @torch.no_grad()
    def update(
        self,
        *,
        centroids: torch.nn.Parameter,
        inputs: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        """In place update to the centroids, keeping one impulse per label between calls"""
        for label in labels.unique().tolist():
            center = self.centroid_center(
                centroid=centroids[label], outputs=inputs, labels=labels, label=label
            )
            repulsion = self.centroid_repulsion(
                centroids=centroids, index=label, factor=self.repulsion
            )
            impulse = center + repulsion
            self.impulse[label] = self.calculate_impulse(label, impulse)

        for label in labels.unique().tolist():
            centroids.data[label] += self.impulse[label]

    def centroid_center(
        self,
        centroid: torch.Tensor,
        outputs: torch.Tensor,
        labels: torch.Tensor,
        label: torch.Tensor
    ) -> torch.Tensor:
        return outputs[labels == label].mean(dim=0) - centroid

    def centroid_repulsion(self, centroids: torch.Tensor, index: int, factor: float) -> torch.Tensor:
        indexes = [idx for idx in range(centroids.shape[0]) if idx != index]
        # note: this vector points from this centroid toward the mean of the
        # others, so a positive factor moves the centroids closer together
        direction = centroids[indexes].mean(dim=0) - centroids[index]
        return factor * direction
        # return factor * direction / torch.norm(direction)

    def calculate_impulse(self, label: int, impulse: torch.Tensor) -> torch.Tensor:
        momentum = self.impulse.get(label, ZERO) * self.momentum
        impulse *= 1 - self.momentum
        return momentum + impulse
Code
trained_data, training_statistics = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cross_entropy_centroid_loss,
    centroid_manager=CalculatedCentroid(momentum=0.9, repulsion=1.),
    optimizer_factory=make_optimizer_only_prompt,
)
Average loss: 1.3368
Average loss: 0.7696
Average loss: 0.6304
Code
trained_data.save(Path("/data/blog/2021-05-10-prompt-training-centroid/calculated-centroid-with-momentum"))
Code
outputs = generate_outputs(
    dl=validation_dataloader,
    model=model,
    prompt=trained_data
)
Code
accuracy(outputs)
              precision    recall  f1-score   support

         bad       0.85      0.76      0.81     12500
        good       0.79      0.87      0.83     12500

    accuracy                           0.82     25000
   macro avg       0.82      0.82      0.82     25000
weighted avg       0.82      0.82      0.82     25000
Code
pca_output = fit(outputs)
Code
visualize(pca=pca_output, prompt=trained_data, outputs=outputs)

Code
#hide_output
from IPython.display import HTML

html = animate(pca=pca_output, training_statistics=training_statistics, outputs=outputs)
# HTML(html)

I’m going to stop this for now. It’s pretty clear that the centroids are now oscillating and that’s probably the cause of the remaining inaccuracy. If the centroids moved with more purpose then the outputs could settle around them.

Maybe I should only update the centroids every few batches?
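One way to try that without touching the centroid manager itself would be a thin wrapper that forwards only every n-th call. A rough sketch, assuming the training loop calls `update(centroids=..., inputs=..., labels=...)` once per batch as the classes above suggest; `EveryNBatches` and `CountingManager` are hypothetical names, and the skipped batches are simply ignored rather than accumulated.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class EveryNBatches:
    """Forward update() to the wrapped manager on every n-th call only."""

    manager: Any
    every: int = 10
    calls: int = 0

    def update(self, **kwargs: Any) -> None:
        self.calls += 1
        if self.calls % self.every == 0:
            self.manager.update(**kwargs)

    def __getattr__(self, name: str) -> Any:
        # only reached for attributes this wrapper lacks; pass them through
        # to the wrapped manager (assumed to own any other methods needed)
        if name == "manager":
            raise AttributeError(name)
        return getattr(self.manager, name)


class CountingManager:
    """Stand-in for a real centroid manager, it just counts the updates."""

    def __init__(self) -> None:
        self.updates = 0

    def update(self, **kwargs: Any) -> None:
        self.updates += 1


inner = CountingManager()
wrapped = EveryNBatches(manager=inner, every=4)
for _ in range(10):
    wrapped.update(centroids=None, inputs=None, labels=None)
```

A variant that averaged the class centers over the skipped batches would use more of the data, at the cost of holding on to the intermediate outputs.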

Anyway I need to use my GPU for something else now so I’m shutting this down.