import blog.transformers_logging
January 8, 2022
Recently I made a shuffler using quantum computing that could shuffle a list. At the end I used it to shuffle a word. If I can shuffle words, can I unshuffle them?
To try out the different approaches to unshuffling we must have a way to shuffle some text. The quantum shuffler from the previous posts is a bit slow for this analysis. Furthermore, the approach should work with any shuffled text, not just quantum-shuffled text.
# from src/main/python/blog/unshuffler/tokenize.py
from typing import Callable, Iterable, Iterator, List, TypeVar
Value = TypeVar("Value")
def to_words(text: str) -> List[str]:
"""
This tokenizes the text into sequences of letters and sequences of non letters.
"""
letter_groups = split(iter(text), str.isalpha)
return ["".join(letters) for letters in letter_groups]
def split(
iterable: Iterable[Value], condition: Callable[[Value], bool]
) -> Iterator[List[Value]]:
state = False
group = []
for element in iterable:
current_state = condition(element)
if current_state != state:
if group:
yield group
group = []
state = current_state
group.append(element)
if group:
yield group
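As a quick illustration (this example cell is mine, not from the original notebook), runs of letters and runs of non-letters come out of the tokenizer as separate tokens:

to_words("Call me Ishmael.")
# ['Call', ' ', 'me', ' ', 'Ishmael', '.']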
# from src/main/python/blog/unshuffler/shuffle.py
from random import shuffle
from typing import Iterable
def shuffle_words(words: Iterable[str]) -> str:
def shuffle_word(word: str) -> str:
if not word.isalpha():
return word
letters = list(word)
shuffle(letters)
return "".join(letters)
return "".join(shuffle_word(word) for word in words)
'lalC me maIlhse. oSem yaser ago- renev dmni ohw lgon irseepylc-'
This should be good enough to test the different approaches.
The first approach is based on word frequency. I can work this out by collating all of the words from some source and then calculating how often each one appears.
To use this we need a way to look up a word, so some kind of index is required that normalizes the word so it can be found easily. Sorting the letters is enough, as a shuffled word sorts to the same set of letters as the original. All anagrams of a given set of letters can then be found under that key.
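For example (an illustrative check), sorting the letters of a word and of a shuffled version of it produces the same key:

"".join(sorted("stop".casefold())), "".join(sorted("opts".casefold()))
# ('opst', 'opst')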
# from src/main/python/blog/unshuffler/word_frequency/calculate.py
from typing import Iterable
import pandas as pd
def calculate_word_frequency(words: Iterable[str]) -> pd.DataFrame:
df = pd.DataFrame(
[word.casefold() for word in words if word.isalpha()],
columns=["word"],
)
df["count"] = 1
df = df.groupby("word").agg(len)
df = df.reset_index(drop=False)
df["normalized_word"] = df.word.apply(sorted).apply("".join)
df = df.set_index("word", drop=True)
return df.sort_values(by="count", ascending=False)
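The cell that builds the frequency table is not shown; a minimal sketch, assuming the frequencies are calculated from the same Moby Dick text file that is used later in the post:

from pathlib import Path

word_frequencies = calculate_word_frequency(
    to_words(Path("/data/gutenberg/moby-dick.txt").read_text())
)
word_frequencies.head()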
word | count | normalized_word
---|---|---
the | 14727 | eht
of | 6747 | fo
and | 6515 | adn
a | 4805 | a
to | 4709 | ot
With this index we can look up all words that are an anagram of the sorted list of letters.
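The lookup that produces the following table would be something like this (a sketch, as the original cell is not shown):

word_frequencies[word_frequencies.normalized_word == "".join(sorted("pots"))]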
word | count | normalized_word
---|---|---
stop | 33 | opst
spot | 23 | opst
post | 22 | opst
pots | 18 | opst
tops | 4 | opst
You can see the count of each of these words. If we take the one with the greatest count, we are choosing the most common anagram of the normalized word. We can then unshuffle a text by finding the most common anagram of each token.
# from src/main/python/blog/unshuffler/word_frequency/unshuffle.py
from typing import Iterable
import pandas as pd
def unshuffle_by_anagram_frequency(
words: Iterable[str], word_frequencies: pd.DataFrame
) -> str:
return "".join(
most_common_anagram(word, word_frequencies=word_frequencies) for word in words
)
def most_common_anagram(word: str, word_frequencies: pd.DataFrame) -> str:
if not word.isalpha():
return word
normalized_word = "".join(sorted(word.casefold()))
return (
word_frequencies[word_frequencies.normalized_word == normalized_word]
.sort_values(by="count", ascending=False)
.index[0]
)
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))
unshuffled_phrase = unshuffle_by_anagram_frequency(
to_words(shuffled_phrase), word_frequencies
)
print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)
Call me Ishmael. Some years ago- never mind how long precisely-
lCal me hsImeal. emoS reyas gao- evrne idmn how gonl pleyisecr-
call me ishmael. some years ago- never mind who long precisely-
This uses the word frequencies from Moby Dick itself, so it is reasonable to say this is the best frequency-based unshuffler that could be constructed for this text. It still gets one word wrong (how -> who).
The problem here is that the relationship between the words is missing. In the unshuffled text the word who does not fit the sentence as well as how does. We need a way to couple this decoder with an understanding of the text itself.
I like Language Models so I really just want to use them for this.
Language Models have been around for a long time. Recent models are based on neural networks and perform significantly better than what came before.
Given that this task is reordering the letters within a word, I want a model that works at the character level. Then it should be possible to restrict the output to only the valid characters and find the most likely word.
Luckily there is a byte level model called ByT5 (Xue et al. 2021). I’m hopeful that I can use this model to unscramble the words. Let’s start by seeing how it works.
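The cell that loads the model and tokenizer is not shown; a sketch of the setup that the following cells rely on, where the exact checkpoint name is an assumption:

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

MODEL_NAME = "google/byt5-small"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)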
text = list("I was walking my dog and".encode("utf-8"))
# +3 to handle the special tokens that the model uses
input_ids = torch.tensor([text]) + 3
# these settings come from https://huggingface.co/blog/how-to-generate
output = model.generate(
input_ids,
do_sample=True,
max_length=50,
top_p=0.92,
top_k=50,
)
# this is just the reverse of the encoding step, skipping special tokens
print(
"".join([
chr(letter - 3)
for letter in output[0]
if letter >= 3
])
)
ÿ stitching raved nail and caught better other ve
To be fair this output is pretty weird. I am guessing that this is the continuation of the text that I provided.
Given this I can still get it to work as an unshuffler; I just have to fine-tune it a bit. The model already has the encoder-decoder structure that I need, and calculating the loss is easy given the desired labels.
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))
input_ids = torch.tensor([list(shuffled_phrase.encode("utf-8"))]) + 3
labels = torch.tensor([list(original_phrase.encode("utf-8"))]) + 3
model(input_ids=input_ids, labels=labels).loss
tensor(3.4138, grad_fn=<NllLossBackward>)
Once again I want to train this on Moby Dick. This time I will cut the text into overlapping fixed-length rows and keep the first section as a test set.
# from src/main/python/blog/unshuffler/language_model/dataset.py
from pathlib import Path
from typing import Any, Callable, Dict, List
import datasets
import pandas as pd
from transformers import AutoTokenizer
class DatasetLoader:
def __init__(self, model_name: str, row_length: int) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.row_length = row_length
def load_dataset(
self, path: Path, shuffler: Callable[[str], str], start: int, end: int
) -> datasets.Dataset:
text = path.read_text()
df = pd.DataFrame(
[text[index : index + self.row_length] for index in range(start, end)],
columns=["text"],
)
df["shuffled_text"] = df.text.apply(shuffler)
ds = datasets.Dataset.from_pandas(df)
ds = ds.map(self.encode)
ds = ds.remove_columns(["text", "shuffled_text"])
return ds
def encode(self, row: Dict[str, Any]) -> Dict[str, List[int]]:
input_ids = self.tokenizer(
row["shuffled_text"],
return_attention_mask=False,
return_token_type_ids=False,
truncation=True,
max_length=self.row_length,
).input_ids
labels = self.tokenizer(
row["text"],
return_attention_mask=False,
return_token_type_ids=False,
truncation=True,
max_length=self.row_length,
).input_ids
return {
"input_ids": input_ids,
"labels": labels,
}
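The remaining constants come from setup cells that are not shown; the values here are assumptions, chosen to be consistent with the training output further down (a batch size of 32 matches the reported step count for 10,000 rows over 5 epochs):

from pathlib import Path

ROW_LENGTH = 64                        # assumed length of each text chunk
BATCH_SIZE = 32                        # assumed; consistent with the training output below
MODEL_RUN_FOLDER = Path("model-runs")  # assumed output location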
moby_dick = Path("/data/gutenberg/moby-dick.txt")
shuffler = lambda text: shuffle_words(to_words(text))
loader = DatasetLoader(model_name=MODEL_NAME, row_length=ROW_LENGTH)
train_ds = loader.load_dataset(moby_dick, shuffler=shuffler, start=2_000, end=12_000)
test_ds = loader.load_dataset(moby_dick, shuffler=shuffler, start=0, end=1_000)
# from src/main/python/blog/unshuffler/language_model/metrics.py
from typing import Dict
from sklearn.metrics import accuracy_score
from transformers import EvalPrediction
def letter_accuracy(results: EvalPrediction) -> Dict[str, float]:
predictions = results.predictions[0].argmax(axis=2).reshape(-1)
targets = results.label_ids.reshape(-1)
accuracy = accuracy_score(y_true=targets, y_pred=predictions)
return {"accuracy": accuracy}
from pathlib import Path
from transformers import Trainer, TrainingArguments, EvalPrediction
training_args = TrainingArguments(
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
# learning_rate=1e-6,
learning_rate=5e-5,
warmup_ratio=0.06,
report_to=[], # you'd use wandb for weights and biases
evaluation_strategy="steps",
num_train_epochs=5,
logging_steps=100,
eval_steps=100,
save_steps=100,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
greater_is_better=True,
# output_dir is compulsory
logging_dir=MODEL_RUN_FOLDER / "output",
output_dir=MODEL_RUN_FOLDER / "output",
overwrite_output_dir=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=test_ds,
tokenizer=tokenizer,
compute_metrics=letter_accuracy,
)
trainer.train()
Step | Training Loss | Validation Loss | Accuracy |
---|---|---|---|
100 | 2.094200 | 0.854648 | 0.755219 |
200 | 0.769700 | 0.457457 | 0.859805 |
300 | 0.432700 | 0.421420 | 0.878805 |
400 | 0.319200 | 0.418045 | 0.883109 |
500 | 0.263500 | 0.444210 | 0.884617 |
600 | 0.220500 | 0.458677 | 0.885742 |
700 | 0.188900 | 0.477641 | 0.886859 |
800 | 0.169000 | 0.490603 | 0.887414 |
900 | 0.152900 | 0.500098 | 0.889703 |
1000 | 0.138100 | 0.507994 | 0.891023 |
1100 | 0.126800 | 0.533735 | 0.890078 |
1200 | 0.118400 | 0.548336 | 0.889992 |
1300 | 0.111900 | 0.554877 | 0.889992 |
1400 | 0.109200 | 0.555312 | 0.889789 |
1500 | 0.106200 | 0.557491 | 0.890367 |
TrainOutput(global_step=1565, training_loss=0.3443936146867161, metrics={'train_runtime': 1734.9896, 'train_samples_per_second': 28.819, 'train_steps_per_second': 0.902, 'total_flos': 1.14843697152e+16, 'train_loss': 0.3443936146867161, 'epoch': 5.0})
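The cell that produced the next value is not shown; it is presumably the earlier loss calculation repeated with the fine-tuned model, something like:

with torch.no_grad():
    loss = model(
        input_ids=input_ids.to(model.device),
        labels=labels.to(model.device),
    ).loss
loss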
tensor(1.3793, device='cuda:0')
original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))
input_ids = torch.tensor([list(shuffled_phrase.encode("utf-8"))]) + 3
labels = torch.tensor([list(original_phrase.encode("utf-8"))]) + 3
with torch.no_grad():
logits = model(
input_ids=input_ids.to(model.device),
labels=labels.to(model.device)
).logits
tokens = logits.argmax(dim=-1)
unshuffled_phrase = tokenizer.decode(tokens[0])
print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)
Call me Ishmael. Some years ago- never mind how long precisely-
lalC em mehaIls. omeS ryase gao- vener idmn how glon ierlyepsc-
Aalleae Imleasls Some saars gga- eever dind wow nong prisisiss-
This is terrible! Oh how awful!!
Luckily this is not the final form. We know that the words that are produced have to be reordered versions of the input so we can restrict the output of the model to only the letters that are valid for the current word.
# from src/main/python/blog/unshuffler/language_model/unshuffle/greedy.py
from typing import List
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
def greedy_unshuffle(
words: List[str], model: AutoModel, tokenizer: AutoTokenizer
) -> str:
text = "".join(words)
result = ""
for shuffled_word in tqdm(words):
word = greedy_unshuffle_word(
text=text,
prefix=result,
word=shuffled_word,
model=model,
tokenizer=tokenizer,
)
result += word
return result
@torch.no_grad()
def greedy_unshuffle_word(
text: str, prefix: str, word: str, model: AutoModel, tokenizer: AutoTokenizer
) -> str:
if not word.isalpha():
return word
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
label_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(model.device)
letter_ids = tokenizer(word, add_special_tokens=False).input_ids
unshuffled_word = ""
for _ in range(len(letter_ids) - 1):
logits = model(input_ids=input_ids, labels=label_ids).logits
letter_index = logits[0, -1, letter_ids].argmax().item()
letter_id = letter_ids.pop(letter_index)
label_ids = torch.cat(
[
label_ids,
torch.tensor([[letter_id]], dtype=torch.long, device=model.device),
],
dim=-1,
)
letter = tokenizer.decode(letter_id)
unshuffled_word += letter
unshuffled_word += tokenizer.decode(letter_ids[0])
return unshuffled_word
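The cell that produced the output below is not shown; a sketch, re-shuffling the same phrase and running the greedy unshuffler over it:

original_phrase = "Call me Ishmael. Some years ago- never mind how long precisely-"
shuffled_phrase = shuffle_words(to_words(original_phrase))
unshuffled_phrase = greedy_unshuffle(
    to_words(shuffled_phrase), model=model, tokenizer=tokenizer
)
print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)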
Call me Ishmael. Some years ago- never mind how long precisely-
allC em Ihlmsae. eoSm eyrsa goa- revne mdin hwo nlog ilsepcery-
Clal me Ihmaels. Soem saeyr goa- eevnr dinm who goln preiselcy-
This still performs poorly, but better than the attempt to predict everything at once. The greedy approach always takes the best prediction for the current token, which may be worse than going through a lower probability token first. As an illustrative example (the numbers are made up): a first letter with probability 0.6 followed by a 0.2 prediction gives a path probability of 0.12, while a first letter with probability 0.4 followed by a 0.9 prediction gives 0.36. The path that takes the better first choice ends up with a lower overall probability, because the other first letter leads to a much stronger subsequent prediction.
To work out the best path through all of the letters we have to keep track of multiple paths at once. This is called a beam search, as each path is a beam.
# from src/main/python/blog/unshuffler/language_model/unshuffle/beam.py
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
def beam_unshuffle(
words: List[str], model: AutoModel, tokenizer: AutoTokenizer, num_beams: int
) -> str:
text = "".join(words)
result = ""
for shuffled_word in tqdm(words):
word = Beam.unshuffle(
text=text,
prefix=result,
word=shuffled_word,
model=model,
tokenizer=tokenizer,
num_beams=num_beams,
)
result += word
return result
@dataclass
class Beam:
word: str
log_probability: float
letter_ids: List[int]
input_ids: torch.Tensor
label_ids: torch.Tensor
@staticmethod
def unshuffle(
*,
text: str,
prefix: str,
word: str,
model: AutoModel,
tokenizer: AutoTokenizer,
num_beams: int,
) -> str:
if not word.isalpha():
return word
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
label_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(model.device)
letter_ids = tokenizer(word, add_special_tokens=False).input_ids
beams = [
Beam(
word="",
log_probability=0.0,
letter_ids=letter_ids,
input_ids=input_ids,
label_ids=label_ids,
)
]
for _ in range(len(letter_ids)):
beams = sum(
[beam.step(model=model, tokenizer=tokenizer) for beam in beams],
start=[],
)
beams = sorted(beams, key=Beam.get_probability, reverse=True)[:num_beams]
return max(beams, key=Beam.get_probability).word
def get_probability(self) -> float:
return self.log_probability
@torch.no_grad()
def step(self, model: AutoModel, tokenizer: AutoTokenizer) -> List[Beam]:
unique_letters = sorted(set(self.letter_ids))
if len(unique_letters) == 1:
return [
self.extend(
letter_id=unique_letters[0],
letter_log_probability=0.0,
tokenizer=tokenizer,
)
]
logits = model(input_ids=self.input_ids, labels=self.label_ids).logits
letters = logits[0, -1].softmax(dim=-1)[unique_letters]
letter_probabilities, letter_indices = letters.sort()
return [
self.extend(
letter_id=unique_letters[index],
letter_log_probability=probability.log().item(),
tokenizer=tokenizer,
)
for index, probability in zip(letter_indices, letter_probabilities)
]
def extend(
self,
letter_id: int,
letter_log_probability: float,
tokenizer: AutoTokenizer,
) -> Beam:
word = self.word + tokenizer.decode(letter_id)
letter_ids = self.letter_ids[:]
letter_ids.remove(letter_id)
label_ids = torch.cat(
[
self.label_ids,
torch.tensor(
[[letter_id]], dtype=torch.long, device=self.label_ids.device
),
],
dim=-1,
)
log_probability = self.log_probability + letter_log_probability
return Beam(
word=word,
log_probability=log_probability,
letter_ids=letter_ids,
input_ids=self.input_ids,
label_ids=label_ids,
)
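Again the cell that produced the output below is not shown; a sketch of the beam search version, where the number of beams is an assumption:

unshuffled_phrase = beam_unshuffle(
    to_words(shuffled_phrase),
    model=model,
    tokenizer=tokenizer,
    num_beams=4,  # assumed beam width
)
print(original_phrase)
print(shuffled_phrase)
print(unshuffled_phrase)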
Call me Ishmael. Some years ago- never mind how long precisely-
allC em Ihlmsae. eoSm eyrsa goa- revne mdin hwo nlog ilsepcery-
lCal me Ihmaels. Soem saeyr goa- eevnr idnm who glon peerliysc-