Conversation with a House

A chatbot that thinks it is a house
Published

April 2, 2021

My son wants to talk to the house that we live in. He is very interested in it and has made a brain for it out of clay. To see if the brain works, he wants to ask the house questions.

His initial attempt was to make some ears out of clay and then have a radio for the house to speak through. That hasn’t worked, so instead I am going to make a conversational AI to talk to him. Part of that conversation is the AI taking on the persona of a house.

Taking on the persona of a non-human is famously hard (What Is It Like to Be a Bat? discusses this). So if I were to try to shape the model to perform this better, I might have to train it on literature featuring conversations with non-humans (like the Culture series). Ultimately this is for a young child, so there is quite a bit of leeway available.

Let’s start by just getting the conversational AI from Hugging Face working. What is nice about this model is that it can take on a persona as part of the conversation. One of the biggest problems is that the code uses quite an old version of transformers, so I should expect to make a small number of modifications. The original Hugging Face code is available in this GitHub repository.


Initial Code Review

The most important script is interact.py, which allows you to freely interact with a persona. It in turn uses the utils.py and train.py scripts.

I am going to lift out each part in turn and discuss it. The first part of the code parses the command line arguments. These arguments form the settings for the conversation, so we need a way to store them.

Code
from dataclasses import dataclass
import torch

@dataclass
class Settings:
    min_length: int = 1  # minimum response length in tokens
    max_length: int = 20 # maximum response length in tokens
    max_history: int = 2 # number of human utterances remembered
    device: torch.device = torch.device("cpu")

    # temperature, top_k and top_p are used to perform top-k and nucleus (top_p) sampling.
    # This is a successor to beam search which tries to more accurately reflect the variance of actual speech.
    # Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    temperature: float = 0.7
    top_k: int = 0
    top_p: float = 0.9

    no_sample: bool = False # just use greedy decoding instead of sampling

So these settings are pretty straightforward, and the argument parser has provided nice defaults for all of them. Let’s move on to the model.
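
As a rough sketch of how I might construct these settings (picking the GPU when it is available is my own choice here, not necessarily the script’s default), it would look something like this:

Code
# minimal sketch: use the GPU if one is available, otherwise fall back to the CPU
settings = Settings(
    temperature=0.7,
    top_p=0.9,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)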

The repository provides a pretrained OpenAI GPT (not GPT-2) model. This suggests to me that the quality may suffer and that retraining on GPT-2 would be productive. Let’s load the model and tokenizer.

Code
import tarfile
from typing import Optional
from pathlib import Path

from transformers import cached_path, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"
MODEL_CHATBOT_FOLDER = None # can set this to a folder if you want, otherwise will use huggingface cache

def download_model(cache_dir: Optional[Path] = None) -> Path:
    archive = cached_path(HF_FINETUNED_MODEL, cache_dir=cache_dir)
    expanded = Path(archive).parent / "expanded"
    expanded.mkdir(exist_ok=True, parents=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(expanded)
    return expanded


MODEL_CHATBOT = download_model(MODEL_CHATBOT_FOLDER)
tokenizer = OpenAIGPTTokenizer.from_pretrained(MODEL_CHATBOT)
model = OpenAIGPTLMHeadModel.from_pretrained(MODEL_CHATBOT)
Some weights of the model checkpoint at /home/matthew/.cache/huggingface/transformers/expanded were not used when initializing OpenAIGPTLMHeadModel: ['multiple_choice_head.summary.weight', 'multiple_choice_head.summary.bias']
- This IS expected if you are initializing OpenAIGPTLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing OpenAIGPTLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

The warning about unused weights is expected: the checkpoint was fine-tuned with a double-headed model (a language modelling head plus a next-utterance classification head), and here we only need the language modelling head. The conversational model is actually quite interesting. It uses special tokens to encode the different parts of the input.

We need to add support for these to the tokenizer. Unfortunately, mutating the tokenizer like this means we cannot use the …Fast tokenizer.

Code
ATTR_TO_SPECIAL_TOKEN = {
    'bos_token': '<bos>',
    'eos_token': '<eos>',
    'pad_token': '<pad>',
    'additional_special_tokens': ['<speaker1>', '<speaker2>']
}

def add_special_tokens_(model: OpenAIGPTLMHeadModel, tokenizer: OpenAIGPTTokenizer) -> None:
    """ Add special tokens to the tokenizer and the model if they have not already been added. """
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN) # doesn't add if they are already there
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)

add_special_tokens_(model, tokenizer)
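
As a quick illustration (my own check, not part of the original script): the fine-tuned model expects the persona and the alternating conversation turns to be delimited by these tokens, roughly <bos> persona <speaker1>/<speaker2> turn … <eos>, with the token type ids recording who is speaking at each position. We can at least confirm that the new tokens received their own ids:

Code
# sanity check: the added special tokens should map to ids at the end of the enlarged vocabulary
print(tokenizer.convert_tokens_to_ids(["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]))
print(len(tokenizer))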

Now we can get onto the personality. This is defined as a few lines of “I” statements. The first example from the PERSONA-CHAT dataset is as follows:

i like to remodel homes. i like to go hunting. i like to shoot a bow. my favorite holiday is halloween.

So let’s tokenize this.

Code
personality = [
    tokenizer.encode(line)
    for line in [
        "i like to remodel homes.",
        "i like to go hunting.",
        "i like to shoot a bow.",
        "my favorite holiday is halloween."
    ]
]

print("\n".join(
    tokenizer.decode(line)
    for line in personality
))
i like to remodel homes.
i like to go hunting.
i like to shoot a bow.
my favorite holiday is halloween.

The “final” part is just to copy the core conversational code over. I did all this separately and then found that only a single line had to change. I’ll highlight that line here:

In sample_sequence, which is the core part of the code, the model is run over the appropriately prepared input. First the tokenizer is used to encode the input, and then the tokens are passed to the model. The transformers codebase has changed since version 2.x and the model now returns an output object. sample_sequence wants the logits (the raw output) from the model and collects them incorrectly:

input_ids = torch.tensor(
    instance["input_ids"], device=args.device
).unsqueeze(0)
token_type_ids = torch.tensor(
    instance["token_type_ids"], device=args.device
).unsqueeze(0)

logits = model(input_ids, token_type_ids=token_type_ids)
if isinstance(logits, tuple):  # for gpt2 and maybe others
    logits = logits[0]

The fix is simple: just collect the logits with the .logits accessor:

input_ids = torch.tensor(
    instance["input_ids"], device=settings.device
).unsqueeze(0)
token_type_ids = torch.tensor(
    instance["token_type_ids"], device=settings.device
).unsqueeze(0)

logits = model(input_ids, token_type_ids=token_type_ids).logits

So here is all the conversational code in one big block:

Code
import warnings
from itertools import chain
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]

def converse(
    personality: str,
    tokenizer: OpenAIGPTTokenizer,
    model: OpenAIGPTLMHeadModel,
    settings: Settings = Settings(),
) -> None:
    encoded_personality = [
        tokenizer.encode(line.strip().casefold())
        for line in personality.splitlines()
        if line.strip()
    ]

    history = []
    try:
        while True:
            raw_text = input("% ")
            while not raw_text:
                print("Prompt should not be empty!")
                raw_text = input("% ")
            if raw_text.strip() == "quit":
                break
            history.append(tokenizer.encode(raw_text))
            with torch.no_grad():
                out_ids = sample_sequence(
                    personality=encoded_personality,
                    history=history,
                    tokenizer=tokenizer,
                    model=model,
                    settings=settings,
                )
            history.append(out_ids)
            history = history[-(2 * settings.max_history + 1) :]
            out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
            print(out_text)
    except KeyboardInterrupt:
        pass


def sample_sequence(
    *,
    personality: List[List[int]],
    history: List[List[int]],
    tokenizer: OpenAIGPTTokenizer,
    model: OpenAIGPTLMHeadModel,
    settings: Settings = Settings(),
    current_output: Optional[List[int]] = None,
):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(settings.max_length):
        instance = build_input_from_segments(
            personality=personality,
            history=history,
            reply=current_output,
            tokenizer=tokenizer,
            with_eos=False,
        )

        input_ids = torch.tensor(
            instance["input_ids"], device=settings.device
        ).unsqueeze(0)
        token_type_ids = torch.tensor(
            instance["token_type_ids"], device=settings.device
        ).unsqueeze(0)

        logits = model(input_ids, token_type_ids=token_type_ids).logits
        logits = logits[0, -1, :] / settings.temperature
        logits = top_filtering(logits, top_k=settings.top_k, top_p=settings.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = (
            torch.topk(probs, 1)[1]
            if settings.no_sample
            else torch.multinomial(probs, 1)
        )
        if i < settings.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn(
                        "Warning: model generating special token with probability 1."
                    )
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output


def build_input_from_segments(
    *,
    personality: List[List[int]],
    history,
    reply,
    tokenizer: OpenAIGPTTokenizer,
    lm_labels: bool = False,
    with_eos: bool = True,
) -> Dict[str, Any]:
    """ Build a sequence of input from 3 segments: personality, history and last reply. """
    bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
    sequence = (
        [[bos] + list(chain(*personality))]
        + history
        + [reply + ([eos] if with_eos else [])]
    )
    sequence = [sequence[0]] + [
        [speaker2 if (len(sequence) - i) % 2 else speaker1] + s
        for i, s in enumerate(sequence[1:])
    ]
    instance = {}
    instance["input_ids"] = list(chain(*sequence))
    instance["token_type_ids"] = [
        speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s
    ]
    instance["mc_token_ids"] = len(instance["input_ids"]) - 1
    instance["lm_labels"] = [-100] * len(instance["input_ids"])
    if lm_labels:
        instance["lm_labels"] = (
            ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]
        )
    return instance


def top_filtering(
    logits: torch.Tensor,
    top_k: float = 0.0,
    top_p: float = 0.9,
    threshold: float = -float("Inf"),
    filter_value: float = -float("Inf"),
):
    """Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
    Args:
        logits: logits distribution shape (vocabulary size)
        top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
        top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
            whose total probability mass is greater than or equal to the threshold top_p.
            In practice, we select the highest probability tokens whose cumulative probability mass exceeds
            the threshold top_p.
        threshold: a minimal threshold to keep logits
    """
    assert (
        logits.dim() == 1
    )  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits

What a huge block of code. Oh well.
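
Before trying a conversation, here is a tiny sanity check of my own for the nucleus filtering, on a made-up distribution. The expectation is that the low-probability tail gets set to -inf so it can never be sampled.

Code
# toy logits, highest first; with top_p=0.9 the unlikely tail should be filtered out
toy_logits = torch.tensor([3.0, 2.0, 1.0, 0.5, -1.0])
print(top_filtering(toy_logits, top_k=0, top_p=0.9))
# expected: the top few logits survive and the rest become -inf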

Let’s see it in action:

Code
converse("""
I am a house.
I love my son.
I like my organs.
I can hear.
""", tokenizer=tokenizer, model=model)
% hello house
hello! how are you?
%  I'm doing well.
i am fine, just watching some tv.
%  What is on right now?
the news. do you have any kids?
%  quit

I’m quite happy with how this conversation has gone. The quality of the conversation can vary quite widely; this is one of the better ones.

It’s not enough to have a text interface to the model. My son wants to speak to it and hear the response, so I need to investigate the Speech2Text or the recent Wav2Vec2 models to do this. That will be for another post though.