Using a Language Model to Unshuffle Words

Can we unshuffle the shuffled words we made?
Published

January 8, 2022

Code
import blog.transformers_logging

Recently I made a shuffler using quantum computing that could shuffle a list. At the end I used it to shuffle a word. If I can shuffle words, can I unshuffle them?

Shuffler

To try out the different approaches to unshuffling, we need a way to shuffle some text. The quantum shuffler from the previous posts is a bit slow for this analysis. Furthermore, the approach should work with any shuffled text, not just quantum-shuffled text.

Code
# from src/main/python/blog/unshuffler/tokenize.py
from typing import Callable, Iterable, Iterator, List, TypeVar

Value = TypeVar("Value")


def to_words(text: str) -> List[str]:
    """
    This tokenizes the text into runs of letters and runs of non-letters.
    """
    letter_groups = split(iter(text), str.isalpha)
    return ["".join(letters) for letters in letter_groups]


def split(
    iterable: Iterable[Value], condition: Callable[[Value], bool]
) -> Iterator[List[Value]]:
    """Group consecutive values by whether they satisfy the condition."""
    state = False
    group = []

    for element in iterable:
        current_state = condition(element)
        if current_state != state:
            # the condition flipped, so the current group is complete
            if group:
                yield group
                group = []
            state = current_state
        group.append(element)
    if group:
        yield group



# from src/main/python/blog/unshuffler/shuffle.py
from random import shuffle
from typing import Iterable


def shuffle_words(words: Iterable[str]) -> str:
    def shuffle_word(word: str) -> str:
        if not word.isalpha():
            return word
        letters = list(word)
        shuffle(letters)
        return "".join(letters)

    return "".join(shuffle_word(word) for word in words)
Code
shuffle_words(to_words("Call me Ishmael. Some years ago- never mind how long precisely-"))
'lalC me maIlhse. oSem yaser ago- renev dmni ohw lgon irseepylc-'

This should be good enough to test the different approaches.

Language Model Unshuffler

I like Language Models so I really just want to use them for this.

Language models have been around for a long time. Recent models are based on neural networks and perform significantly better than what came before.

Given that this task reorders the letters within a word, I want a model that works at the character level. Then it should be possible to restrict the output to only the valid characters and find the most likely word.
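
As a rough sketch of that idea (the scores here are random and the ids made up; the real implementation comes later in this post), restricting a vector of per-character scores to the letters still available might look like this:

Code
import torch

# hypothetical sketch: mask the per-character scores down to an allowed set
# (this ignores the +3 special-token offset that ByT5 uses, covered below)
logits = torch.randn(256)                      # made-up per-character scores
allowed = [ord(letter) for letter in "lalC"]   # letters still available in the word
mask = torch.full_like(logits, float("-inf"))
mask[allowed] = 0.0
most_likely = chr((logits + mask).argmax().item())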

Luckily there is a byte-level model called ByT5 (Xue et al. 2021). I’m hopeful that I can use this model to unshuffle the words. Let’s start by seeing how it works.

Xue, Linting, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, and Colin Raffel. 2021. “ByT5: Towards a Token-Free Future with Pre-Trained Byte-to-Byte Models.” https://arxiv.org/abs/2105.13626.
Code
from transformers import T5ForConditionalGeneration
import torch

model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
model.eval() ; None
Code
text = list("I was walking my dog and".encode("utf-8"))

# +3 to skip the special tokens (ByT5 reserves ids 0-2 for pad/eos/unk)
input_ids = torch.tensor([text]) + 3

# these settings come from https://huggingface.co/blog/how-to-generate
output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    top_p=0.92,
    top_k=50,
)

# this is just the reverse of the encoding step, skipping special tokens
print(
    "".join([
        chr(letter - 3)
        for letter in output[0]
        if letter >= 3
    ])
)
ÿ stitching raved nail and caught better other ve

To be fair, this output is pretty weird. I am guessing that it is a continuation of the text that I provided.

Even so, I can still get this to work as an unshuffler; I just have to fine-tune it a bit. The model already has the encoder-decoder structure that I want, and calculating the loss is easy given the desired labels.

Code
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))

input_ids = torch.tensor([list(shuffled_phrase.encode("utf-8"))]) + 3
labels = torch.tensor([list(original_phrase.encode("utf-8"))]) + 3

model(input_ids=input_ids, labels=labels).loss
tensor(3.4138, grad_fn=<NllLossBackward>)

Once again I want to train this on Moby Dick. This time I will cut the text into overlapping 128-character windows and keep the first section of the book as a test set.
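
To make the windowing concrete, this is the slicing that the loader below performs, shown here with a toy 8-character window:

Code
text = "Call me Ishmael."
[text[index : index + 8] for index in range(0, 4)]
['Call me ', 'all me I', 'll me Is', 'l me Ish']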

Training

Code
from pathlib import Path

MODEL_NAME = "google/byt5-small"
BATCH_SIZE = 32
ROW_LENGTH = 128

MODEL_RUN_FOLDER = Path("/data/blog/2022-01-08-language-model-deshuffler/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)
Code
# from src/main/python/blog/unshuffler/language_model/dataset.py
from pathlib import Path
from typing import Any, Callable, Dict, List

import datasets
import pandas as pd
from transformers import AutoTokenizer


class DatasetLoader:
    def __init__(self, model_name: str, row_length: int) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.row_length = row_length

    def load_dataset(
        self, path: Path, shuffler: Callable[[str], str], start: int, end: int
    ) -> datasets.Dataset:
        text = path.read_text()
        # overlapping windows of row_length characters, one per start offset
        df = pd.DataFrame(
            [text[index : index + self.row_length] for index in range(start, end)],
            columns=["text"],
        )
        df["shuffled_text"] = df.text.apply(shuffler)

        ds = datasets.Dataset.from_pandas(df)
        ds = ds.map(self.encode)
        ds = ds.remove_columns(["text", "shuffled_text"])
        return ds

    def encode(self, row: Dict[str, Any]) -> Dict[str, List[int]]:
        input_ids = self.tokenizer(
            row["shuffled_text"],
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=True,
            max_length=self.row_length,
        ).input_ids
        labels = self.tokenizer(
            row["text"],
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=True,
            max_length=self.row_length,
        ).input_ids
        return {
            "input_ids": input_ids,
            "labels": labels,
        }
Code
moby_dick = Path("/data/gutenberg/moby-dick.txt")

def shuffler(text: str) -> str:
    return shuffle_words(to_words(text))

loader = DatasetLoader(model_name=MODEL_NAME, row_length=ROW_LENGTH)
train_ds = loader.load_dataset(moby_dick, shuffler=shuffler, start=2_000, end=12_000)
test_ds = loader.load_dataset(moby_dick, shuffler=shuffler, start=0, end=1_000)
Code
# from src/main/python/blog/unshuffler/language_model/metrics.py
from typing import Dict

from sklearn.metrics import accuracy_score
from transformers import EvalPrediction


def letter_accuracy(results: EvalPrediction) -> Dict[str, float]:
    # predictions[0] holds the logits, shape (batch, sequence, vocabulary)
    predictions = results.predictions[0].argmax(axis=2).reshape(-1)
    targets = results.label_ids.reshape(-1)

    accuracy = accuracy_score(y_true=targets, y_pred=predictions)

    return {"accuracy": accuracy}
Code
#hide_output
from transformers import T5ForConditionalGeneration, AutoTokenizer

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model.cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments, EvalPrediction

training_args = TrainingArguments(
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,

    report_to=[],  # pass ["wandb"] here to log to Weights & Biases

    evaluation_strategy="steps",
    num_train_epochs=5,
    logging_steps=100,
    eval_steps=100,
    save_steps=100,

    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,

    # output_dir is compulsory
    logging_dir=MODEL_RUN_FOLDER / "output",
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    compute_metrics=letter_accuracy,
)

trainer.train()
[1565/1565 28:54, Epoch 5/5]
Step Training Loss Validation Loss Accuracy
100 2.094200 0.854648 0.755219
200 0.769700 0.457457 0.859805
300 0.432700 0.421420 0.878805
400 0.319200 0.418045 0.883109
500 0.263500 0.444210 0.884617
600 0.220500 0.458677 0.885742
700 0.188900 0.477641 0.886859
800 0.169000 0.490603 0.887414
900 0.152900 0.500098 0.889703
1000 0.138100 0.507994 0.891023
1100 0.126800 0.533735 0.890078
1200 0.118400 0.548336 0.889992
1300 0.111900 0.554877 0.889992
1400 0.109200 0.555312 0.889789
1500 0.106200 0.557491 0.890367

TrainOutput(global_step=1565, training_loss=0.3443936146867161, metrics={'train_runtime': 1734.9896, 'train_samples_per_second': 28.819, 'train_steps_per_second': 0.902, 'total_flos': 1.14843697152e+16, 'train_loss': 0.3443936146867161, 'epoch': 5.0})
Code
input_ids = torch.tensor([list(shuffled_phrase.encode("utf-8"))]) + 3
labels = torch.tensor([list(original_phrase.encode("utf-8"))]) + 3

with torch.no_grad():
    print(model(input_ids=input_ids.cuda(), labels=labels.cuda()).loss)
tensor(1.3793, device='cuda:0')
Code
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))

input_ids = torch.tensor([list(shuffled_phrase.encode("utf-8"))]) + 3
labels = torch.tensor([list(original_phrase.encode("utf-8"))]) + 3

with torch.no_grad():
    logits = model(
        input_ids=input_ids.to(model.device),
        labels=labels.to(model.device)
    ).logits
    tokens = logits.argmax(dim=-1)
    unshuffled_phrase = tokenizer.decode(tokens[0])

print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)
Call me Ishmael. Some years ago- never mind how long precisely-
lalC em mehaIls. omeS ryase gao- vener idmn how glon ierlyepsc-
Aalleae Imleasls Some saars gga- eever dind wow nong prisisiss-

This is terrible! Oh how awful!!

Luckily this is not the final form. We know that the words produced have to be reordered versions of the input, so we can restrict the output of the model to only the letters that are valid for the current word.

Code
model.eval() ; model.cpu() ; None
Code
# from src/main/python/blog/unshuffler/language_model/unshuffle/greedy.py
from typing import List

import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer


def greedy_unshuffle(
    words: List[str], model: AutoModel, tokenizer: AutoTokenizer
) -> str:
    text = "".join(words)
    result = ""
    for shuffled_word in tqdm(words):
        word = greedy_unshuffle_word(
            text=text,
            prefix=result,
            word=shuffled_word,
            model=model,
            tokenizer=tokenizer,
        )
        result += word
    return result


@torch.no_grad()
def greedy_unshuffle_word(
    text: str, prefix: str, word: str, model: AutoModel, tokenizer: AutoTokenizer
) -> str:
    if not word.isalpha():
        return word

    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    label_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(model.device)
    letter_ids = tokenizer(word, add_special_tokens=False).input_ids

    unshuffled_word = ""

    for _ in range(len(letter_ids) - 1):
        logits = model(input_ids=input_ids, labels=label_ids).logits
        # restrict the prediction to the letters remaining in this word
        letter_index = logits[0, -1, letter_ids].argmax().item()

        letter_id = letter_ids.pop(letter_index)
        label_ids = torch.cat(
            [
                label_ids,
                torch.tensor([[letter_id]], dtype=torch.long, device=model.device),
            ],
            dim=-1,
        )

        letter = tokenizer.decode(letter_id)
        unshuffled_word += letter

    # the single remaining letter is forced, so no model call is needed
    unshuffled_word += tokenizer.decode(letter_ids[0])
    return unshuffled_word
Code
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))
Code
unshuffled_phrase = greedy_unshuffle(
    words=to_words(shuffled_phrase),
    model=model,
    tokenizer=tokenizer
)

print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)

Call me Ishmael. Some years ago- never mind how long precisely-
allC em Ihlmsae. eoSm eyrsa goa- revne mdin hwo nlog ilsepcery-
Clal me Ihmaels. Soem saeyr goa- eevnr dinm who goln preiselcy-

This still performs poorly, although better than the attempt to predict everything at once. The greedy approach always takes the best prediction for the current token, which can be worse than first taking a lower-probability token: a path that starts with the better first choice can end up with a lower overall probability, because a different first letter can have a much stronger subsequent prediction.
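
To make that concrete, here is a made-up two-letter example (the probabilities are invented purely for illustration):

Code
# hypothetical probabilities for unshuffling "on" back into "no"
greedy = 0.6 * 0.1  # "o" looks best first (0.6) but "n" follows poorly (0.1) -> 0.06
better = 0.4 * 0.9  # "n" looks worse first (0.4) but "o" follows strongly (0.9) -> 0.36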

To work out the best path through all of the letters, we have to keep track of multiple paths at once. This is called a beam search, as each path is a beam.

Code
# from src/main/python/blog/unshuffler/language_model/unshuffle/beam.py
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer


def beam_unshuffle(
    words: List[str], model: AutoModel, tokenizer: AutoTokenizer, num_beams: int
) -> str:
    text = "".join(words)
    result = ""
    for shuffled_word in tqdm(words):
        word = Beam.unshuffle(
            text=text,
            prefix=result,
            word=shuffled_word,
            model=model,
            tokenizer=tokenizer,
            num_beams=num_beams,
        )
        result += word
    return result


@dataclass
class Beam:
    word: str
    log_probability: float
    letter_ids: List[int]
    input_ids: torch.Tensor
    label_ids: torch.Tensor

    @staticmethod
    def unshuffle(
        *,
        text: str,
        prefix: str,
        word: str,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        num_beams: int,
    ) -> str:
        if not word.isalpha():
            return word

        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
        label_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(model.device)
        letter_ids = tokenizer(word, add_special_tokens=False).input_ids

        beams = [
            Beam(
                word="",
                log_probability=0.0,
                letter_ids=letter_ids,
                input_ids=input_ids,
                label_ids=label_ids,
            )
        ]
        for _ in range(len(letter_ids)):
            beams = sum(
                [beam.step(model=model, tokenizer=tokenizer) for beam in beams],
                start=[],
            )
            beams = sorted(beams, key=Beam.get_probability, reverse=True)[:num_beams]
        return max(beams, key=Beam.get_probability).word

    def get_probability(self) -> float:
        return self.log_probability

    @torch.no_grad()
    def step(self, model: AutoModel, tokenizer: AutoTokenizer) -> List[Beam]:
        unique_letters = sorted(set(self.letter_ids))
        if len(unique_letters) == 1:
            # only one distinct letter remains, so there is no choice to make
            return [
                self.extend(
                    letter_id=unique_letters[0],
                    letter_log_probability=0.0,
                    tokenizer=tokenizer,
                )
            ]

        # score every remaining distinct letter with a single model call
        logits = model(input_ids=self.input_ids, labels=self.label_ids).logits
        letters = logits[0, -1].softmax(dim=-1)[unique_letters]
        letter_probabilities, letter_indices = letters.sort()
        return [
            self.extend(
                letter_id=unique_letters[index],
                letter_log_probability=probability.log().item(),
                tokenizer=tokenizer,
            )
            for index, probability in zip(letter_indices, letter_probabilities)
        ]

    def extend(
        self,
        letter_id: int,
        letter_log_probability: float,
        tokenizer: AutoTokenizer,
    ) -> Beam:
        word = self.word + tokenizer.decode(letter_id)
        letter_ids = self.letter_ids[:]
        letter_ids.remove(letter_id)
        label_ids = torch.cat(
            [
                self.label_ids,
                torch.tensor(
                    [[letter_id]], dtype=torch.long, device=self.label_ids.device
                ),
            ],
            dim=-1,
        )
        log_probability = self.log_probability + letter_log_probability
        return Beam(
            word=word,
            log_probability=log_probability,
            letter_ids=letter_ids,
            input_ids=self.input_ids,
            label_ids=label_ids,
        )
Code
unshuffled_phrase = beam_unshuffle(
    words=to_words(shuffled_phrase),
    model=model,
    tokenizer=tokenizer,
    num_beams=10,
)

print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)

Call me Ishmael. Some years ago- never mind how long precisely-
allC em Ihlmsae. eoSm eyrsa goa- revne mdin hwo nlog ilsepcery-
lCal me Ihmaels. Soem saeyr goa- eevnr idnm who glon peerliysc-