from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
model.eval()
model = model.cuda()  # the timings below were measured on a GPU
November 15, 2021
When we talk, we talk about relationships between things. “The cat sat on the mat” describes the relationship of sitting between the cat and the mat. While this is a very simple view of language, it does show that understanding language requires knowing which things are being discussed and how they relate to other things.
If we wanted to identify the things in a passage of text, we might ask questions about it. This reminds me of the comprehension tests from school, which featured a paragraph of text followed by a few questions.
Those comprehension tests ask questions that are specific to the text. To automate this process we would need general questions that we could ask about anything. Such a general question is called a prompt.
If we can ask about each of the different things in the text, then we can automate the process: the same prompt is asked about each thing, and a language model answers it.
import torch

@torch.no_grad()
def ask_question(text: str, question: str) -> str:
    # Collapse whitespace so the passage reads as a single paragraph
    text = " ".join(text.split())
    document = text + " " + question
    tokens = tokenizer(document, return_tensors="pt").input_ids
    tokens = tokens.to(model.device)
    # Fix the seed so the sampled continuation is reproducible
    torch.random.manual_seed(42)
    output = model.generate(
        tokens,
        min_length=tokens.shape[-1] + 20,  # at least 20 new tokens
        max_length=tokens.shape[-1] + 50,  # at most 50 new tokens
        do_sample=True,
        top_p=0.92,
        temperature=0.7,
        no_repeat_ngram_size=2,
    )
    # Decode only the newly generated tokens rather than slicing the decoded string
    return tokenizer.decode(output[0, tokens.shape[-1]:]).strip()
It's the same thing as the term gin. It is a drink made from a sugar cane. The sugar is extracted from the sugar, and distilled into gin, which is then mixed with water. You can tell from your taste of the drink if
(You can find an introduction to the generate method on the Hugging Face blog.)
This prompt got the model to describe gin as a drink. I don’t think it’s made from sugar cane, though. The problem here is the prompt: a better prompt would make it easier to understand what gin is.
Given that we can ask the model about a term, how can we match the description to an unambiguous thing? To start, we need an ontology (or taxonomy) of things.
We need a way to assign each thing to a specific class. These classes form an ontology of “conversational things”. Creating one would be a major undertaking in itself, so finding one that already exists is essential. Wikipedia can be thought of as such an ontology, as it has pages on most things. This means that the unit of classification is a Wikipedia page.
For gin, we would use the Gin page.
Given this, how can we match the mention of gin to the Wikipedia page? We could take the text of the Wikipedia page and see what we get when we ask the same question. Here is the first paragraph from the Gin page with the same prompt:
print(
    ask_question(
        text="""
Gin originated as a medicinal liquor made by monks and alchemists across Europe,
particularly in southern France, Flanders and the Netherlands, to provide aqua
vita from distillates of grapes and grains. It then became an object of commerce
in the spirits industry. Gin emerged in England after the introduction of jenever,
a Dutch and Belgian liquor that was originally a medicine. Although this development
had been taking place since the early 17th century, gin became widespread after the
William of Orange-led 1688 Glorious Revolution and subsequent import restrictions
on French brandy. 
        """,
        question="What is gin?"
    )
)
It is a simple syrup made of sugar and a salt. When sugar is added to gin, it is distilled to a syrup with the addition of water. As the sugar evaporates, the syrup becomes a glass of wine. The syrup is then used
This is still talking about sugar, so maybe there is something here? To judge whether this would work, we need to see how closely it matches the original output and how likely a false match is. To check for false matches, we can look for some other text that produces similar output.
I would think that maple syrup could be described as a syrup with sugar. Let’s see how that does.
print(
    ask_question(
        text="""
Maple syrup is a syrup usually made from the xylem sap of sugar maple, red maple,
or black maple trees, although it can also be made from other maple species. In
cold climates, these trees store starch in their trunks and roots before winter;
the starch is then converted to sugar that rises in the sap in late winter and
early spring. Maple trees are tapped by drilling holes into their trunks and
collecting the sap, which is processed by heating to evaporate much of the water,
leaving the concentrated syrup. Most trees can produce 20 to 60 litres
(5 to 15 US gallons) of sap per season.
        """,
        question="What is maple syrup?"
    )
)
It is composed of a variety of sugars and a number of aromatic and chemical compounds.
The maple tree contains a small amount of fructose. The fructose is what is used in sugar syrup, and the sugars are made up of different types of starch
That does contain sugar and syrup in the continuation, but it’s still quite different, and it seems to me to be a better description of maple syrup.
How does it handle another clear alcoholic spirit?
print(
    ask_question(
        text="""
Vodka is traditionally drunk "neat" (not mixed with water, ice, or other mixers),
and it is often served freezer chilled in the vodka belt of Belarus, Estonia,
Finland, Iceland, Latvia, Lithuania, Norway, Poland, Russia, Sweden, and Ukraine.
It is also used in cocktails and mixed drinks, such as the vodka martini,
Cosmopolitan, vodka tonic, screwdriver, greyhound, Black or White Russian, Moscow
mule, Bloody Mary, and Caesar. 
        """,
        question="What is vodka?"
    )
)
It was introduced to the world in Russia in 1719 by Russian nobleman Alexander the Great, who claimed to have invented the first vodka, the "Bourbon", which is a mixture of sugar and water. The term was invented to describe the
These have distinct outputs even when they use the same words. Based on this tiny sample, the approach does produce distinct output for similar things, so in theory we could classify them.
It’s quite an unwieldy approach though, and we can cover some of its downsides now.
This approach uses the model to generate text that follows the prompt. There are two downsides to this: it is slow, and it can be inaccurate.
Generating text like this involves running the model repeatedly. Every time we generate a token of output, we rerun the model with that new token added to its input.
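To make the cost concrete, here is a minimal sketch of greedy autoregressive decoding in plain Python. `greedy_generate` and `CountingModel` are hypothetical names, and the stub stands in for GPT-2; the point is only that producing N tokens takes N model calls. The real `generate` method caches earlier attention keys and values so each step only processes the newest token, but it still needs one forward pass per generated token.

```python
from typing import Callable, List


def greedy_generate(
    next_token_fn: Callable[[List[int]], int],
    prompt: List[int],
    max_new_tokens: int,
) -> List[int]:
    """Greedy autoregressive decoding: one model call per generated token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        # Each step runs the model on the sequence so far, including the
        # tokens it has already generated, and appends its prediction.
        tokens.append(next_token_fn(tokens))
    return tokens[len(prompt):]


class CountingModel:
    """Hypothetical stand-in for a language model that counts its invocations."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, tokens: List[int]) -> int:
        self.calls += 1
        # Toy "prediction": the previous token id plus one
        return (tokens[-1] + 1) % 100


stub = CountingModel()
generated = greedy_generate(stub, prompt=[1, 2, 3], max_new_tokens=50)
print(stub.calls)  # 50
```

Fifty new tokens cost fifty model calls, which is why the timing below is so much worse than a single forward pass.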
We can see how slow this is by timing the gin example again:
%%timeit
ask_question(
    text="""
Gin originated as a medicinal liquor made by monks and alchemists across Europe,
particularly in southern France, Flanders and the Netherlands, to provide aqua
vita from distillates of grapes and grains. It then became an object of commerce
in the spirits industry. Gin emerged in England after the introduction of jenever,
a Dutch and Belgian liquor that was originally a medicine. Although this development
had been taking place since the early 17th century, gin became widespread after the
William of Orange-led 1688 Glorious Revolution and subsequent import restrictions
on French brandy. 
    """,
    question="What is gin?"
)
420 ms ± 412 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
It’s difficult to judge how slow this is without context, so here is the same model running over the same text a single time. In the code I’m collecting the top 10 tokens to simulate running beam search, although the actual generation process is slightly different.
from typing import List

import torch

@torch.no_grad()
def single_inference(text: str, question: str) -> List[str]:
    text = " ".join(text.split())
    document = text + " " + question
    tokens = tokenizer(document, return_tensors="pt").input_ids
    tokens = tokens.to(model.device)
    # A single forward pass over the whole prompt
    logits = model(tokens).logits
    # Top 10 next tokens, as if we were doing 10x beam search
    predictions = logits[0, -1].topk(10).indices
    return tokenizer.batch_decode(predictions[:, None])
%%timeit
single_inference(
    text="""
Gin originated as a medicinal liquor made by monks and alchemists across Europe,
particularly in southern France, Flanders and the Netherlands, to provide aqua
vita from distillates of grapes and grains. It then became an object of commerce
in the spirits industry. Gin emerged in England after the introduction of jenever,
a Dutch and Belgian liquor that was originally a medicine. Although this development
had been taking place since the early 17th century, gin became widespread after the
William of Orange-led 1688 Glorious Revolution and subsequent import restrictions
on French brandy. 
    """,
    question="What is gin?"
)
8.28 ms ± 6.26 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
This is running on a GPU, and generating the text takes about 50 times longer than running the model once (420 ms against 8.28 ms). We have not yet explored how to tie the output back to a specific Wikipedia page; even if that part were fast and accurate, the overall process would still be slow.
The more the model generates, the more likely it is to produce junk. This is because the further it goes, the more of its input consists of tokens that it generated itself.
import torch

@torch.no_grad()
def long_inference(text: str, question: str) -> str:
    text = " ".join(text.split())
    document = text + " " + question
    tokens = tokenizer(document, return_tensors="pt").input_ids
    tokens = tokens.to(model.device)
    torch.random.manual_seed(42)
    # Generate all the way to GPT-2's 1024-token context limit
    output = model.generate(
        tokens,
        min_length=1024,
        max_length=1024,
        do_sample=True,
        top_p=0.92,
        temperature=0.7,
        no_repeat_ngram_size=2,
    )
    # Show only the last 50 tokens, far from the prompt
    return tokenizer.decode(output[0, -50:])
to sit down and enjoy is its color. Though, its flavor is very unique to this particular wine because it is also different because its wine color is different to its body color, not the color of your skin. With this, your body has to
In this example the model has drifted into talking about the color of wine and of your skin. This is less of a problem for the current approach, because an excessively long generated answer is unlikely to be useful anyway. Tasks that need long generated output, such as translation, suffer from this more.
If this is so slow and potentially inaccurate then is it actually practical to do?
Facebook's GENRE model uses text generation to classify entities (Cao et al. 2021). It restricts the available output quite heavily, but it still involves running the model at least once for each token in the input.
You can see this process in the following gif from the documentation:

As GENRE directly attempts to resolve entities, we can see that generation is considered a viable means of describing language. The model's output is restricted specifically to keep the copied text accurate, so this is one possible fix for the accuracy problem.
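The GENRE paper describes restricting generation with a prefix tree (trie) of valid entity names during constrained beam search. The sketch below shows the idea in its simplest form, with two loud simplifications: it uses greedy decoding instead of beam search, and tokens are plain strings scored by a hypothetical `score_fn` rather than by a real model. `PrefixTrie` and `constrained_greedy` are illustrative names, not GENRE's API.

```python
from typing import Callable, Dict, List, Sequence

END = "<end>"  # marks the end of a complete entity name


class PrefixTrie:
    """Prefix tree over tokenized entity names (a simplified, hypothetical version)."""

    def __init__(self, names: Sequence[Sequence[str]]) -> None:
        self.root: Dict[str, dict] = {}
        for name in names:
            node = self.root
            for token in name:
                node = node.setdefault(token, {})
            node[END] = {}

    def allowed_next(self, prefix: Sequence[str]) -> List[str]:
        """Tokens that may follow `prefix`; empty if `prefix` is not a valid start."""
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return sorted(node)


def constrained_greedy(
    score_fn: Callable[[List[str], str], float],
    trie: PrefixTrie,
    max_len: int = 10,
) -> List[str]:
    """Greedy decoding that only ever picks tokens the trie allows."""
    output: List[str] = []
    for _ in range(max_len):
        allowed = trie.allowed_next(output)
        if not allowed:
            break
        best = max(allowed, key=lambda token: score_fn(output, token))
        if best == END:
            break
        output.append(best)
    return output


trie = PrefixTrie([["Gin"], ["Gin", "and", "tonic"], ["Porter", "(beer)"]])
# Hypothetical scores: the unconstrained favourite, "sugar", is not an entity name
scores = {"sugar": 5.0, "Gin": 2.0, "Porter": 1.0, END: 0.5, "and": 0.1}
print(constrained_greedy(lambda prefix, token: scores.get(token, 0.0), trie))  # ['Gin']
```

However strongly the model prefers an off-list token like "sugar", the constraint means the output can only ever be a valid entity name.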
If generating long-form responses is too slow, then how can it be made faster? Furthermore, can we make it easier to work out which Wikipedia page is being linked?
To answer that, let’s look back at the GENRE gif. The most important single step in this gif is this prediction:
At this point the model is describing the entity mention Metropolis. The prediction creates a signature of this entity: the ranking of the predicted tokens. This means that finding the Wikipedia page that matches the signature is a matter of matching one vector to another in some way.
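As a rough illustration of matching one vector to another, here is a sketch that turns a ranked token list into a sparse vector and picks the page with the highest cosine similarity. The 1/(rank + 1) weighting, the helper names, and the page signatures are all assumptions made up for illustration, not real model output; comparing the model's full next-token probability distributions would be a natural alternative.

```python
import math
from typing import Dict, List


def rank_vector(tokens: List[str]) -> Dict[str, float]:
    """Sparse vector from a ranked token list; the 1/(rank+1) weight is an assumption."""
    vector: Dict[str, float] = {}
    for rank, token in enumerate(tokens):
        vector.setdefault(token, 1.0 / (rank + 1))
    return vector


def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Cosine similarity between two sparse vectors."""
    dot = sum(weight * b.get(token, 0.0) for token, weight in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def best_page(mention: List[str], pages: Dict[str, List[str]]) -> str:
    """Title of the page whose signature is most similar to the mention's."""
    query = rank_vector(mention)
    return max(pages, key=lambda title: cosine(query, rank_vector(pages[title])))


# Made-up page signatures for illustration only
pages = {
    "Gin": [" gin", " drink", " juniper"],
    "Maple syrup": [" syrup", " sugar", " maple"],
}
print(best_page([" drink", " gin", " cocktail"], pages))  # Gin
```

With real signatures, each Wikipedia page would be run through the same prompt once and its vector stored, so linking a mention becomes a nearest-neighbour lookup.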
This is the core behaviour that we want. Is it possible to get this wikipedia page signature without having to run all of the previous steps of the model?
The prediction can be generated with a single model invocation if we can produce the input that the model would have seen as it generated the output. In the GENRE output, brackets delimit the entity mentions, which is something that is not apparent in the source text. When we first discussed how to predict the Wikipedia page for something mentioned in a passage, we asked a question about it. We could do the same here and then review which token the model wants to produce next.
from typing import List
import torch
@torch.no_grad()
def next_token(text: str, question: str) -> List[str]:
    text = " ".join(text.split())
    document = text + " " + question
    tokens = tokenizer(document, return_tensors="pt").input_ids
    tokens = tokens.to(model.device)
    # One forward pass; the logits at the last position score the next token
    output = model(tokens).logits
    top_tokens = output[0, -1].argsort(descending=True)[:15]
    return tokenizer.batch_decode(top_tokens[:, None])
[' It',
 '\n',
 ' Well',
 ' A',
 ' What',
 ' The',
 ' Why',
 ' I',
 ' And',
 ' Gin',
 ' Is',
 ' You',
 ' How',
 ' We',
 ' That']
These next tokens are good at starting an explanation, however they do not jump to describing the Wikipedia page. If we change the question, we might extract better tokens from the model. Since we want the Wikipedia page, maybe we should ask for it?
[' "',
 ' gin',
 ' the',
 " '",
 ' The',
 ',',
 ' p',
 ' a',
 ':',
 ' P',
 ' Gin',
 ' it',
 ' g',
 ' G',
 '\n']
[' "',
 ' p',
 ' the',
 " '",
 ' P',
 ' gin',
 ' The',
 ',',
 ' a',
 ':',
 ' it',
 ' Water',
 ' St',
 ' water',
 ' Guinness']
These are an improvement over the previous question, as the most salient tokens appear earlier. Gin is described by the Gin page, while porter is described by the Porter (beer) page. As Guinness is a type of porter (it was originally called Extra Superior Porter), its inclusion in the list is very encouraging.
Or we could try filling out the start of the answer to our question?
[' drink',
 ' mixture',
 ' combination',
 ' blend',
 ' beverage',
 ' gin',
 ' kind',
 ' name',
 ' cocktail',
 ' term',
 ' very',
 ' type',
 ' compound',
 ' liquid',
 ' form']
[' gin',
 ' drink',
 ' blend',
 ' very',
 ' small',
 ' kind',
 ' mixture',
 ' beverage',
 ' p',
 ' cocktail',
 ' brand',
 ' type',
 ' fermented',
 ' water',
 ' combination']
This final question has significantly improved the quality of the tokens.
There are still several problems with this technique. We have explored three different questions and found that one performs better than the others. It is reasonable to think that there are other questions that we could ask the model that would produce even better signatures. This is an area of research called prompt engineering, as we are prompting the model to produce output that solves the problem.
The second prompt most strongly recommends the " symbol, and the third prompt recommends generic words like type. Neither of these describes the specific Wikipedia page of interest well. Since we are looking for a distinctive signature of a Wikipedia page, we should remove these low-information tokens. To find them, we could run the prompt over many different passages and entity mentions and work out which tokens are over-predicted.
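One simple way to find over-predicted tokens is to count how often each token appears in the top-k lists across many passages and drop any token that shows up in more than some fraction of them. This is a sketch under assumptions: the 0.5 threshold, the helper names, and the three example lists are made up, and a real run would use far more passages.

```python
from collections import Counter
from typing import Iterable, List, Set


def over_predicted(predictions: Iterable[List[str]], threshold: float = 0.5) -> Set[str]:
    """Tokens appearing in more than `threshold` of the top-k lists."""
    lists = [set(tokens) for tokens in predictions]
    if not lists:
        return set()
    counts = Counter(token for tokens in lists for token in tokens)
    return {token for token, count in counts.items() if count / len(lists) > threshold}


def filter_signature(tokens: List[str], stop_tokens: Set[str]) -> List[str]:
    """Drop low-information tokens while keeping the ranking order."""
    return [token for token in tokens if token not in stop_tokens]


# Made-up top-k lists for three different passages
predictions = [
    [' "', " the", " gin", " drink"],
    [' "', " the", " porter", " beer"],
    [' "', " a", " vodka", " spirit"],
]
stop_tokens = over_predicted(predictions)
print(filter_signature(predictions[0], stop_tokens))  # [' gin', ' drink']
```

With only a handful of passages this would also throw away legitimate tokens that happen to repeat, so the threshold only makes sense over a large and varied sample.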
Since we moved away from the generative approach because it was slow, the speed of this approach is important. How fast can we run this?
%%timeit
next_token(
    text="""
Gin originated as a medicinal liquor made by monks and alchemists across Europe,
particularly in southern France, Flanders and the Netherlands, to provide aqua
vita from distillates of grapes and grains. It then became an object of commerce
in the spirits industry. Gin emerged in England after the introduction of jenever,
a Dutch and Belgian liquor that was originally a medicine. Although this development
had been taking place since the early 17th century, gin became widespread after the
William of Orange-led 1688 Glorious Revolution and subsequent import restrictions
on French brandy. 
    """,
    question="What is gin?"
)
8 ms ± 3.36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
This matches the performance of invoking the model once, which we saw earlier, so it is roughly 50x faster than the generative approach. Is this the fastest that we can manage?
If we can classify an entity mention in a single invocation of the model, then it would seem that we cannot make the process any faster. After all, we can’t invoke the model less than once. There can be more than one entity mention in the text though, so we can improve performance by classifying several entities in a single model invocation.
To do this we have to review the structure of the output. When we invoke the model on the input it produces output for every token in the input:
In a normal language model these per-token outputs are the predicted next token in the sequence. After all, this is what the GPT-2 language model was trained to do.
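The alignment between inputs and training targets is easy to show with a toy token list: the output at each position is trained to predict the token at the next position. (Hugging Face's `GPT2LMHeadModel` performs this shift internally when you pass `labels=input_ids`.) The sentence and the helper name are illustrative only.

```python
from typing import List, Tuple


def next_token_pairs(tokens: List[str]) -> List[Tuple[str, str]]:
    """Pair each input token with the token the model is trained to predict there."""
    # The output at position i is trained to predict the token at position i + 1,
    # so the final position has no target within the sequence.
    return list(zip(tokens[:-1], tokens[1:]))


for seen, target in next_token_pairs(["The", " cat", " likes", " fish"]):
    print(repr(seen), "->", repr(target))
```

So the output at the position of " cat" is a guess at what follows the cat, not a description of the cat itself.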
If we get output for every token then can we get a description of every entity mention in a single pass? I think so:
The highlighted outputs are the model predictions for the entity mention tokens. Could we use these as entity descriptions?
Unfortunately not. The language model was trained to predict the next token in the sequence, so the predictions for cat should include likes. To alter this would require retraining the model.
Retraining the model was done for GENRE so it’s possible to do.
If we retrain the model then it is no longer a language model, as it is not predicting the next token. Instead it has become a language describer as it describes each input token.
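If a model were retrained to describe each input token, one forward pass would yield an output vector per token, and every entity mention could be described from that single pass. The sketch below assumes a hypothetical retrained model whose per-token outputs are already available as plain lists, and it assumes mean pooling over each mention's tokens; using just the first or last token of the mention would be an equally plausible choice.

```python
from typing import List, Sequence, Tuple


def mention_signatures(
    outputs: Sequence[Sequence[float]],
    spans: Sequence[Tuple[int, int]],
) -> List[List[float]]:
    """Mean-pool the per-token outputs of each mention span [start, end)."""
    signatures = []
    for start, end in spans:
        rows = outputs[start:end]
        if len(rows) == 0:
            raise ValueError(f"empty mention span {(start, end)}")
        width = len(rows[0])
        signatures.append([sum(row[i] for row in rows) / len(rows) for i in range(width)])
    return signatures


# Four tokens with 2-dimensional outputs from one (hypothetical) forward pass
outputs = [[1.0, 0.0], [0.0, 2.0], [2.0, 2.0], [4.0, 0.0]]
print(mention_signatures(outputs, spans=[(0, 1), (1, 3)]))  # [[1.0, 0.0], [1.0, 2.0]]
```

Both mentions are described from the same forward pass, so the cost stays at one model invocation no matter how many entities the passage contains.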