RAG for Graphs

How to perform Retrieval Augmented Generation over graph data structures?
Published

August 1, 2024

I’ve been testing out various kinds of retrieval augmented generation (RAG) recently and I wanted to try applying it to a graph data structure. The aim will be to provide the graph to the language model and then ask it questions about the graph.

If this works then a natural extension will be to test it on much larger graphs where part of the problem becomes identifying what to include. It is also worth considering how to evaluate the accuracy of the process.

Model and Interface

To start with I need a model with some code to make interactions easy. Reusing the code from previous posts provides me with a nice way to render the interactions.

Code
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
import torch

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

quantization_config = HqqConfig(
    nbits=8,
    group_size=64,
    quant_zero=False,
    quant_scale=False,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
Code
# from src/main/python/blog/llm/chat.py
# pylint: disable=import-outside-toplevel, global-statement
from __future__ import annotations

import difflib
from dataclasses import dataclass
from typing import Literal

import pygments
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

_CSS_STYLES_ADDED = False


def _add_css_styles() -> None:
    from IPython.display import HTML, display
    from pygments.formatters.html import HtmlFormatter

    global _CSS_STYLES_ADDED
    if _CSS_STYLES_ADDED:
        return
    css = HtmlFormatter().get_style_defs(".highlight")
    display(HTML(f"<style>{css}</style>"))
    _CSS_STYLES_ADDED = True


@dataclass
class Utterance:
    role: Literal["user", "assistant"]
    content: str

    def _repr_markdown_(self) -> str:
        return self.as_markdown()

    def as_json(self) -> dict[str, str]:
        return {"role": self.role, "content": self.content}

    def as_markdown(self) -> str:
        return "\n".join(
            [
                f"#### {self.role.capitalize()}",
                *self.content.splitlines(),
            ]
        )


@dataclass
class Comparison:
    original: Utterance
    updated: Utterance
    threshold: float = 0.75

    def similarity(self) -> float:
        """
        A number in the range [0,1].
        0 indicates complete dissimilarity.
        1 indicates complete similarity.
        """
        original = self.original.as_markdown()
        updated = self.updated.as_markdown()
        matcher = difflib.SequenceMatcher(None, original, updated)
        ratio = matcher.ratio()
        return ratio

    def _ipython_display_(self) -> None:
        from IPython.display import HTML, Markdown, display

        if self.similarity() >= self.threshold:
            _add_css_styles()
            display(HTML(self.as_difference()))
        else:
            display(Markdown(self.as_updated()))

    def as_difference(self) -> str:
        from pygments.formatters.html import HtmlFormatter
        from pygments.lexers.diff import DiffLexer

        original = self.original.as_markdown()
        updated = self.updated.as_markdown()
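        comparison_lines = difflib.context_diff(
            original.splitlines(),
            updated.splitlines(),
            # the split lines carry no trailing newline, so the diff headers
            # should not either; otherwise joining them adds blank lines
            lineterm="",
        )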
        comparison = "\n".join(comparison_lines)
        html_comparison = pygments.highlight(comparison, DiffLexer(), HtmlFormatter())
        return html_comparison

    def as_updated(self) -> str:
        original_markdown = self.original.as_markdown()
        original_length = len(original_markdown)
        updated_markdown = self.updated.as_markdown()
        updated_length = len(updated_markdown)
        preamble = (
            "<small>"
            f"previous length: {original_length:,} characters"
            "<br/>"
            f"current length: {updated_length:,} characters"
            "</small>"
        )
        markdown = preamble + "\n\n" + updated_markdown
        return markdown


@dataclass
class ChatComparison:
    comparisons: list[Comparison]

    def _ipython_display_(self) -> None:
        from IPython.display import display

        for comparison in self.comparisons:
            display(comparison)


@dataclass
class Chat:
    utterances: list[Utterance]

    def _repr_markdown_(self) -> str:
        return self.as_markdown()

    def as_json(self) -> list[dict[str, str]]:
        return list(map(Utterance.as_json, self.utterances))

    def as_markdown(self) -> str:
        utterances = map(Utterance.as_markdown, self.utterances)
        markdown = "\n\n".join(utterances)
        return markdown

    def compare(self, chat: Chat, threshold: float = 0.75) -> ChatComparison:
        return ChatComparison(
            [
                Comparison(original=original, updated=updated, threshold=threshold)
                for original, updated in zip(chat.utterances, self.utterances)
            ]
        )

    def next(self, utterance: Utterance) -> Chat:
        return Chat(self.utterances + [utterance])

    def assistant(self, content: str) -> Chat:
        return self.next(Utterance(role="assistant", content=content))

    def user(self, content: str) -> Chat:
        return self.next(Utterance(role="user", content=content))

    def __getitem__(self, index: int | slice) -> Chat:
        item = self.utterances[index]
        if not isinstance(item, list):
            item = [item]
        return Chat(item)


@torch.inference_mode()
def generate_chat(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    chat: str | Chat,
    max_new_tokens: int = 100,
    do_sample: bool = False,
    **kwargs,
) -> Chat:
    if isinstance(chat, str):
        chat = Chat([Utterance(role="user", content=chat)])
    if not do_sample:
        # transformers now warns if these are set when do_sample is false
        kwargs["temperature"] = None
        kwargs["top_p"] = None
    chat_input = tokenizer.apply_chat_template(
        chat.as_json(),
        return_tensors="pt",
        padding="longest",
    )
    chat_input = chat_input.to(model.device)
    attention_mask = torch.ones_like(chat_input)
    generated_ids = model.generate(
        chat_input,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        **kwargs,
    )
    output = tokenizer.decode(
        generated_ids[0, chat_input.shape[1] :],
        skip_special_tokens=True,
    )
    output = output.strip()
    response = Utterance(role="assistant", content=output)
    return chat.next(response)

It’s good to check that this all works. I can do this by asking the model for a weather forecast:

Code
generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat="What is the weather like today?",
)

User

What is the weather like today?

Assistant

I’m an AI and I don’t have real-time capabilities or the ability to check the weather. I recommend you to check a reliable weather website or app for the current weather conditions.


The code works, the language model produces a coherent response, and it doesn’t even hallucinate (yet). A good start.

Breakfast Graph

Cooking is one process that can be represented as a graph. It involves different kinds of steps, such as preparation, heating and combining. These steps have an order, and they should be arranged so that the final parts are ready at the same time.

We can start with a very simple cooking graph; that of a breakfast consisting of coffee, buttered toast and a soft boiled egg.

Breakfast Dataset

The first thing to do will be to represent this in code and provide a way to visualize it. Then we can think how to efficiently present this information to the large language model.

Code
from dataclasses import dataclass, field

@dataclass
class BreakfastGraph:
    nodes: set[str] = field(default_factory=set)
    edges: dict[tuple[str, str], float] = field(default_factory=dict)

    def add_edge(
        self,
        source: str,
        destination: str,
        duration: float,
    ) -> None:
        self.nodes = self.nodes | {source, destination}
        self.edges[(source, destination)] = duration

I have chosen a very simple representation of the graph: the nodes are the ingredient or output names, and each edge records the cooking time needed to turn its source into its destination. What I need now is a way to construct the breakfast cooking graph and then render it. Let’s have a look at this breakfast:

Code
breakfast_graph = BreakfastGraph()
breakfast_graph.add_edge(
    "raw egg",
    "soft boiled egg",
    duration=3.0,
)
breakfast_graph.add_edge(
    "beans and water",
    "coffee",
    duration=1.0,
)
breakfast_graph.add_edge(
    "bread",
    "toast",
    duration=1.0,
)
breakfast_graph.add_edge(
    "toast",
    "buttered toast",
    duration=0.5,
)

breakfast_graph.add_edge(
    "soft boiled egg",
    "breakfast",
    0.0,
)
breakfast_graph.add_edge(
    "coffee",
    "breakfast",
    0.0,
)
breakfast_graph.add_edge(
    "buttered toast",
    "breakfast",
    0.0,
)
Code
import graphviz

def render_breakfast_graph(
    graph: BreakfastGraph,
) -> graphviz.Source:
    node_identifiers = {
        node: f'node_{index}'
        for index, node in enumerate(sorted(graph.nodes))
    }
    node_definitions = [
        f'{identifier} [label="{node}"];'
        for node, identifier in node_identifiers.items()
    ]
    nodes = "\n".join(node_definitions)

    def to_edge(
        source: str,
        destination: str,
        weight: float,
    ) -> str:
        source_id = node_identifiers[source]
        destination_id = node_identifiers[destination]
        return f'{source_id} -> {destination_id} [label="{weight} minutes"];'

    edge_definitions = [
        to_edge(source, destination, weight)
        for (source, destination), weight in graph.edges.items()
    ]
    edges = "\n".join(edge_definitions)

    graph_definition = f"""
digraph G {{
node [shape=box];
{nodes}
{edges}
}}
""".strip()
    return graphviz.Source(graph_definition)
Code
render_breakfast_graph(breakfast_graph)

This shows the simple process of making a breakfast. Visualizing the graph can help us understand the correct order to do things. If we were to arrange this into a recipe then it would resemble:

  • boil an egg
  • after \(1 \frac{1}{2}\) minutes put the toast in the toaster
  • after \(\frac{1}{2}\) a minute put the coffee on
  • after \(\frac{1}{2}\) a minute butter the toast
  • after \(\frac{1}{2}\) a minute the breakfast is ready
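The start times in this recipe can be derived mechanically from the graph. The sketch below is a hypothetical helper, not part of the original code: it works on a standalone copy of the breakfast edges (`EDGES`), computes for each item the longest remaining time until breakfast is complete, and schedules every step as late as possible so that everything finishes together. It assumes the graph is acyclic and has a single final node.

```python
from functools import lru_cache

# Hypothetical standalone copy of the breakfast edges:
# (source, destination) -> duration in minutes
EDGES: dict[tuple[str, str], float] = {
    ("raw egg", "soft boiled egg"): 3.0,
    ("beans and water", "coffee"): 1.0,
    ("bread", "toast"): 1.0,
    ("toast", "buttered toast"): 0.5,
    ("soft boiled egg", "breakfast"): 0.0,
    ("coffee", "breakfast"): 0.0,
    ("buttered toast", "breakfast"): 0.0,
}


@lru_cache(maxsize=None)
def remaining_time(node: str) -> float:
    """Longest path (in minutes) from `node` being ready to the final node."""
    outgoing = [
        (destination, duration)
        for (source, destination), duration in EDGES.items()
        if source == node
    ]
    # the final node has no outgoing edges, so nothing remains after it
    return max(
        (duration + remaining_time(destination) for destination, duration in outgoing),
        default=0.0,
    )


def latest_start_times() -> list[tuple[float, str, str]]:
    """(start minute, source, destination) for each step, scheduled as late as possible."""
    sources = {source for source, _ in EDGES}
    # the whole meal takes as long as the longest chain of steps
    total = max(remaining_time(source) for source in sources)
    schedule = [
        (total - (duration + remaining_time(destination)), source, destination)
        for (source, destination), duration in EDGES.items()
    ]
    return sorted(schedule)


for start, source, destination in latest_start_times():
    print(f"at {start} minutes: {source} -> {destination}")
```

This starts the egg at 0.0 minutes, the toast at 1.5, the coffee at 2.0 and the buttering at 2.5, which matches the recipe above. Scheduling each step as late as possible is one choice among several; it keeps the short items (coffee, toast) fresh rather than starting them early.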

The graph needs to be provided to the language model and it should produce a sequence much like the above. If it can do this then that would be one demonstration of graph understanding.

Breakfast Representation

The visualization is convenient for us, but an image is not the best way to provide the data to the model.

There are several formats that we can use to provide the data. We could render the graph in a textual form (e.g. turning bread into toast takes 1 minute), or as a json blob of data (e.g. {"source": "bread", "destination": "toast", "duration": 1.0}). The desired output format is also a choice: do we expect the output to be consumed by another machine, or do we want it to be understood by a human?

Code
from collections import defaultdict

def readable_breakfast_graph(graph: BreakfastGraph) -> str:
    output_sources = defaultdict(set)
    for source, destination in graph.edges:
        output_sources[destination].add(source)

    def to_description(output: str, sources: set[str]) -> str:
        duration = max(
            graph.edges[(source, output)]
            for source in sources
        )
        sources = sorted(sources)
        if len(sources) == 1:
            sources_str = sources[0]
        else:
            sources_str = ", ".join(sources[:-1])
            sources_str = f"{sources_str} and {sources[-1]}"
        return f"making {output} uses {sources_str} and takes {duration} minutes"

    descriptions = [
        to_description(output, output_sources[output])
        for output in sorted(output_sources)
    ]
    descriptions_str = "\n".join(f" * {description}" for description in descriptions)
    return descriptions_str

This is how we can present the data in a readable form:

Code
print(readable_breakfast_graph(breakfast_graph))
 * making breakfast uses buttered toast, coffee and soft boiled egg and takes 0.0 minutes
 * making buttered toast uses toast and takes 0.5 minutes
 * making coffee uses beans and water and takes 1.0 minutes
 * making soft boiled egg uses raw egg and takes 3.0 minutes
 * making toast uses bread and takes 1.0 minutes
Code
from collections import defaultdict
import json

def json_breakfast_graph(graph: BreakfastGraph) -> str:
    # map each output to the set of inputs that it consumes
    output_inputs = defaultdict(set)
    for source, output in graph.edges:
        output_inputs[output].add(source)

    def to_dict(index: int, output: str) -> dict[str, int | str | list[str] | float]:
        inputs = output_inputs[output]
        # the step takes as long as its slowest input edge
        duration = max(
            graph.edges[(source, output)]
            for source in inputs
        )
        return {
            "id": index,
            "inputs": sorted(inputs),
            "output": output,
            "duration": duration,
        }

    data = [
        to_dict(index, output)
        for index, output in enumerate(sorted(output_inputs))
    ]
    return json.dumps(data)

And this is how we can make it machine readable:

Code
json_breakfast_graph(breakfast_graph)
'[{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]'

These are some simple representations of the graph. It is important to note that the breakfast cannot be created by following these steps in the order written; for example, the buttered toast step comes before the toast is made.

Breakfast Recipe

What we want is a model which can take this graph and perform a topological sort over it, turning it into a list of steps: a recipe.
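For comparison, Python's standard library can already perform this sort. A minimal sketch using `graphlib.TopologicalSorter` (Python 3.9+) over the json steps printed above, where each step depends on the steps that produce its inputs. `topological_step_order` is a hypothetical helper name. A plain topological sort only guarantees a valid order; it says nothing about which steps to start first to finish quickly.

```python
import json
from graphlib import TopologicalSorter

# The output of json_breakfast_graph(breakfast_graph), copied from above
# so that this sketch runs on its own.
STEPS_JSON = '[{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]'


def topological_step_order(steps_json: str) -> list[int]:
    """Order step IDs so that each step comes after the steps producing its inputs."""
    steps = json.loads(steps_json)
    # which step produces each intermediate item
    producer = {step["output"]: step["id"] for step in steps}
    sorter = TopologicalSorter()
    for step in steps:
        # raw ingredients such as "bread" have no producing step
        dependencies = [producer[item] for item in step["inputs"] if item in producer]
        sorter.add(step["id"], *dependencies)
    return list(sorter.static_order())


print(topological_step_order(STEPS_JSON))
```

Any order in which step 4 (bread to toast) precedes step 1 (buttering) and step 0 comes last is valid; `static_order` returns one of them. This gives a cheap ground truth to compare the model's answers against.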

This will be a chance to try out the different data formats as well as to do some prompt engineering. We can start with the readable format and a simple prompt:

Code
steps = readable_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

STEPS:
{steps}
""".strip()

chat_1 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_1

User

I am hungry and I want to make breakfast as fast as possible. I have been given different steps that cook various ingredients. Organize these into a recipe that will make breakfast quickly. I can cook multiple things at the same time.

STEPS:
 * making breakfast uses buttered toast, coffee and soft boiled egg and takes 0.0 minutes
 * making buttered toast uses toast and takes 0.5 minutes
 * making coffee uses beans and water and takes 1.0 minutes
 * making soft boiled egg uses raw egg and takes 3.0 minutes
 * making toast uses bread and takes 1.0 minutes

Assistant

To make a quick breakfast, follow these steps:

  1. Prepare the coffee:
    • Place the coffee beans and water in a coffee maker.
    • Start the coffee maker (takes 1.0 minute)
  2. Toast the bread:
    • Place the bread slices in the toaster.
    • Toast the bread (takes 1.0 minute)
  3. Butter the toast:
    • Spread butter on the toasted bread slices (takes 0.0 minute, as it can be done while the coffee is brewing and the egg is cooking)
  4. Cook the soft-boiled egg:
    • Fill a small saucepan with water and bring it to a boil.
    • Reduce the heat to a simmer, then gently lower the raw egg into the water.
    • Cook the egg for 3.0 minutes.
    • Remove the egg from the water and place it in a bowl of ice water to stop the cooking process.
    • Peel the egg (takes 0.0 minute, as it can be done while the coffee is brewing)
  5. Assemble the breakfast:
    • Place the buttered toast on a plate.
    • Carefully place the soft-boiled egg on the toast.
    • Serve the coffee on the side.

Enjoy your quick and delicious breakfast!


This has produced an ordering of the actions; however, the model has felt the need to elaborate extensively on each step. Generating so much text is slow, and we really just want the steps.

We can improve this by adding some simple instructions to the prompt. When performing prompt engineering like this it is very important to keep track of what you have tried. To help with this, I am showing prompt changes in this post as a contextual difference.
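The difference rendering relies on `difflib.context_diff` from the standard library. A minimal sketch of the same idea on two plain prompt strings; `prompt_diff`, `old_prompt` and `new_prompt` are hypothetical names for illustration. Passing `lineterm=""` keeps the diff header lines free of trailing newlines, matching the split prompt lines, so joining the result does not introduce blank lines.

```python
import difflib


def prompt_diff(old_prompt: str, new_prompt: str) -> str:
    """Render the change between two prompts as a context diff."""
    lines = difflib.context_diff(
        old_prompt.splitlines(),
        new_prompt.splitlines(),
        fromfile="old",
        tofile="new",
        lineterm="",  # input lines have no newline, so headers should not either
    )
    return "\n".join(lines)


old_prompt = (
    "Organize these into a recipe.\n"
    "\n"
    "STEPS:\n"
    " * making toast uses bread and takes 1.0 minutes"
)
new_prompt = old_prompt + "\n\nOnly list the steps that I need to perform."
print(prompt_diff(old_prompt, new_prompt))
```

Lines prefixed with `+` were added, `!` lines were changed, and unprefixed lines are context, which is the format used for the prompt comparisons in this post.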

Code
steps = readable_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

STEPS:
{steps}

Only list the steps that I need to perform.
""".strip()

chat_2 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_2.compare(chat_1)
*** 
--- 
***************
*** 10,12 ****
--- 10,14 ----
   * making coffee uses beans and water and takes 1.0 minutes
   * making soft boiled egg uses raw egg and takes 3.0 minutes
   * making toast uses bread and takes 1.0 minutes
+ 
+ Only list the steps that I need to perform.

previous length: 1,046 characters
current length: 360 characters

Assistant

  1. Toast the bread for 1.0 minute to make toast.
  2. Prepare the coffee by brewing the beans for 1.0 minute.
  3. Boil the raw egg for 3.0 minutes to make a soft boiled egg.
  4. Butter the toast that was made in step 1.
  5. Assemble the breakfast by placing the soft boiled egg and buttered toast on a plate, and serve with the coffee made in step 2.

The addition of a single line to the prompt drastically reduced the response size. Prompting well takes care and, as you can see, small changes can produce completely different results.

I would say that, with this prompt, the model wrote a complete recipe.

This is fine for human consumption; however, a machine needs a more regular output to work with. We could use the json format for the graph and ask the model to produce json output.

Code
steps = json_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given json data containing different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

{steps}

Produce a json array containing the ordered step IDs.
Only list the steps that I need to perform.
""".strip()

chat_3 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_3

User

I am hungry and I want to make breakfast as fast as possible. I have been given json data containing different steps that cook various ingredients. Organize these into a recipe that will make breakfast quickly. I can cook multiple things at the same time.

[{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]

Produce a json array containing the ordered step IDs. Only list the steps that I need to perform.

Assistant

Based on the given JSON data and the requirement to make breakfast quickly, the following steps can be performed simultaneously:

  1. Toast the bread (Step 1)
  2. Prepare the soft-boiled egg (Step 3)
  3. Brew the coffee (Step 2)

Since these steps can be performed concurrently, we can buttered the toast (Step 4) while waiting for the other steps to complete.

Here’s the ordered JSON array of the step IDs:

[1, 3, 2, 4]

This means: 1. Toast the bread (Step 1) 2. Prepare the soft-boiled egg (Step 3) 3. Brew the coffee (Step 2) 4. Butter the toast (Step 4)


There is some json data hidden in the middle of that response. Once again the model has been overly descriptive. Let’s try to cut that down.
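One way to recover the answer from a chatty response like this is to search for the first thing that parses as a json array of integers. The helper below is a hypothetical sketch, not from the original code; it uses the standard library's `json.JSONDecoder.raw_decode`, which parses a json value starting at a given offset and ignores whatever text follows it.

```python
from __future__ import annotations

import json


def extract_step_ids(response: str) -> list[int] | None:
    """Return the first json array of integers found in `response`, or None."""
    decoder = json.JSONDecoder()
    for index, character in enumerate(response):
        if character != "[":
            continue
        try:
            value, _ = decoder.raw_decode(response, index)
        except json.JSONDecodeError:
            continue  # a "[" that does not start valid json
        if isinstance(value, list) and all(isinstance(item, int) for item in value):
            return value
    return None


print(extract_step_ids("Here's the ordered JSON array of the step IDs:\n\n[1, 3, 2, 4]"))
```

Lenient extraction like this rescues more answers, but it also hides how often the model ignores the requested format, so it is worth recording separately whether the whole response was valid json.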

Code
steps = json_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given json data containing different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

{steps}

Produce a json array containing the ordered step IDs.
Do not add any additional details or commentary.
""".strip()

chat_4 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_4.compare(chat_3)
*** 
--- 
***************
*** 7,10 ****
  [{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]
  
  Produce a json array containing the ordered step IDs.
! Only list the steps that I need to perform.
--- 7,10 ----
  [{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]
  
  Produce a json array containing the ordered step IDs.
! Do not add any additional details or commentary.

previous length: 572 characters
current length: 30 characters

Assistant

[1, 4, 2, 3, 0]


This is the fastest response yet and it is machine readable. The proposed order is unfortunately incorrect. The steps as listed would be:

  • 1: toast -> buttered toast
  • 4: bread -> toast
  • 2: beans and water -> coffee
  • 3: raw egg -> soft boiled egg
  • 0: buttered toast, coffee, soft boiled egg -> breakfast

In this step order the toast is buttered before the bread is toasted. This is a fixable problem.
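Checking whether an order respects the dependencies can be done mechanically, which matters once we test many graphs. A minimal sketch, assuming the breakfast steps are restated in a compact hypothetical form (`STEPS`, mapping step ID to inputs and output): a step may only run once every step producing one of its inputs has already run, and every step must appear exactly once.

```python
# Hypothetical compact form of the breakfast steps: id -> (inputs, output)
STEPS: dict[int, tuple[list[str], str]] = {
    0: (["buttered toast", "coffee", "soft boiled egg"], "breakfast"),
    1: (["toast"], "buttered toast"),
    2: (["beans and water"], "coffee"),
    3: (["raw egg"], "soft boiled egg"),
    4: (["bread"], "toast"),
}


def is_valid_order(order: list[int]) -> bool:
    """True if every step appears once and only after the steps producing its inputs."""
    if sorted(order) != sorted(STEPS):
        return False  # missing, repeated or unknown step IDs
    producer = {output: step_id for step_id, (_, output) in STEPS.items()}
    completed: set[str] = set()
    for step_id in order:
        inputs, output = STEPS[step_id]
        for item in inputs:
            # raw ingredients have no producer and are always available
            if item in producer and item not in completed:
                return False
        completed.add(output)
    return True


print(is_valid_order([1, 4, 2, 3, 0]))  # the model's answer: toast buttered before it exists
print(is_valid_order([4, 1, 2, 3, 0]))
```

The first call prints `False` and the second `True`. The earlier answer `[1, 3, 2, 4]` would also fail, because it omits step 0.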

Instead of spending more time fixing this, let us think about how to test the model. After all, this is a single recipe; we might fix its problems only to find that other recipes are hopelessly broken.

How to Test a Prompt

Since we have a prompt and model that can generate machine-readable output, we can feed it problems with known solutions and measure the accuracy of its answers. This also lets us measure how frequently the model produces machine-readable output at all, as the response is currently unconstrained.
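Both measurements fit in a small scoring loop. This is a hypothetical sketch rather than code from this post: `score_responses` takes a `parse` function (returning `None` for unreadable output) and an `is_correct` check. The example `is_correct` compares against a single expected order only for brevity; for graphs, a real check should accept any valid topological order.

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Callable, Iterable, Optional, TypeVar

Answer = TypeVar("Answer")


@dataclass(frozen=True)
class PromptScore:
    total: int
    parsed: int  # responses that could be read as an answer
    correct: int  # parsed responses that were also correct

    @property
    def parse_rate(self) -> float:
        return self.parsed / self.total if self.total else 0.0

    @property
    def accuracy(self) -> float:
        return self.correct / self.total if self.total else 0.0


def score_responses(
    responses: Iterable[str],
    parse: Callable[[str], Optional[Answer]],
    is_correct: Callable[[Answer], bool],
) -> PromptScore:
    """Count how many responses were machine readable and how many were right."""
    total = parsed = correct = 0
    for response in responses:
        total += 1
        answer = parse(response)
        if answer is None:
            continue  # unreadable output counts against the prompt
        parsed += 1
        if is_correct(answer):
            correct += 1
    return PromptScore(total=total, parsed=parsed, correct=correct)


def parse_json_list(response: str) -> Optional[list]:
    """Strict parser: the whole response must be a json array."""
    try:
        value = json.loads(response)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, list) else None


score = score_responses(
    ["[4, 1, 2, 3, 0]", "[1, 4, 2, 3, 0]", "Sure! Here is the order you asked for."],
    parse=parse_json_list,
    is_correct=lambda ids: ids == [4, 1, 2, 3, 0],
)
print(f"parse rate {score.parse_rate:.2f}, accuracy {score.accuracy:.2f}")
```

For these three made-up responses this prints `parse rate 0.67, accuracy 0.33`. Keeping the two numbers separate shows whether a prompt change improved the reasoning or merely the formatting.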

Coming up with a recipe dataset that is large enough to perform these tests is difficult. There are books full of recipes that could be used; however, they do not separate the steps into inputs and outputs as clearly as the breakfast graph does. As such, I have decided to work with a different dataset.

Test Dataset

To test a prompt we need a dataset. For the purposes of this blog I can use any dataset, so I want one where I can ask questions and know what the correct answer is.

A family tree can be generated randomly if we have a list of names. We can name individuals, marry them, have them produce children, and so on.

Given such a graph it is then possible to generate questions about it (e.g. how many children does Jane have?) which can have known answers.
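To make this concrete, here is a tiny hand-written family stored as a mapping from each couple to their children. The names, `FAMILIES` and the `children_of` helper are hypothetical illustrations, not the generator used below. Asking about someone with no children, or someone who is not in the tree at all, still has a known answer that can be checked.

```python
from __future__ import annotations

# Hypothetical toy family tree: each couple maps to their children.
FAMILIES: dict[tuple[str, str], list[str]] = {
    ("Jane", "Liam"): ["Olivia", "Noah"],
    ("Olivia", "James"): ["Emma"],
}


def children_of(name: str) -> list[str] | None:
    """Known answer for 'who are the children of <name>?', or None if <name> is unknown."""
    people = {
        person
        for couple, children in FAMILIES.items()
        for person in (*couple, *children)
    }
    if name not in people:
        return None  # a question with no valid answer
    return [
        child
        for couple, children in FAMILIES.items()
        if name in couple
        for child in children
    ]


for name in ["Jane", "Emma", "Zelda"]:
    children = children_of(name)
    answer = "not in the family tree" if children is None else len(children)
    print(f"How many children does {name} have? {answer}")
```

This prints 2 for Jane, 0 for Emma and "not in the family tree" for Zelda, giving three question and answer pairs of different kinds.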

We can start with a list of baby names and use them to generate our fictitious family trees. I have found a GitHub repository of baby names. Using the top 1,000 names from the year 2022 will be sufficient for our purposes.

Code
import pandas as pd

names_df = pd.read_csv("/data/blog/2024/08/01/rag-for-graphs/popular-baby-names/2022/girl_boy_names_2022.csv")
names_df = names_df.rename(
    columns={
        "Girl Name": "female",
        "Boy Name": "male",
    }
)
names_df = names_df[["female", "male"]]
names_df = names_df.copy()
names_df
        female    male
0       Olivia    Liam
1         Emma    Noah
2    Charlotte  Oliver
3       Amelia   James
4       Sophia  Elijah
...        ...     ...
995     Luella   Imran
996      Nancy   Ivaan
997      Cielo   Kanan
998    Madalyn   Kalel
999    Kahlani  London

1000 rows × 2 columns

With these names we can generate family trees. I’m going to use a very simple approach which isn’t ideal but does generate varied and ragged graphs.

The generation does not guarantee that every couple has children, or that every person is married. This will allow us to ask questions which have no valid answer, a valuable test for the prompt and model.

I’ll start by generating a small tree with some sample questions and answers.

Code
# from src/main/python/blog/graph/genealogy.py
from __future__ import annotations

import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Iterator, Optional

import numpy as np
import pandas as pd


class Gender(Enum):
    MALE = "male"
    FEMALE = "female"
    OTHER = "other"

    @staticmethod
    def choice() -> Gender:
        return random.choice(
            [
                Gender.MALE,
                Gender.FEMALE,
                Gender.OTHER,
            ]
        )


@dataclass(frozen=True)
class Person:
    name: str
    gender: Gender

    @staticmethod
    def make(name: str) -> Person:
        gender = Gender.choice()
        return Person(
            name=name,
            gender=gender,
        )

    def to_json(self) -> dict[str, str]:
        return {
            "name": self.name,
            "gender": self.gender.value,
        }


@dataclass(frozen=True)
class Family:
    parents: tuple[Person, ...]
    children: tuple[Person, ...] = field(default_factory=tuple)

    @staticmethod
    def make(
        left: Person,
        right: Person,
        names: Iterator[str],
        min_children: int,
        max_children: int,
    ) -> Family:
        children_count = random.randint(min_children, max_children)
        children = [Person.make(name) for _, name in zip(range(children_count), names)]
        return Family(
            parents=(left, right),
            children=tuple(children),
        )


@dataclass
class Generation:
    adults: list[Person]
    children: list[Person]
    families: list[Family]

    @staticmethod
    def make(
        previous: Generation,
        names: Iterator[str],
        min_children: int,
        max_children: int,
        marriage_ratio: float = 0.5,
    ) -> Generation:
        eligible_adults = previous.marriage_candidates()
        total_couples = len(eligible_adults) // 2
        if total_couples < 1:
            return Generation(
                adults=previous.children,
                children=[],
                families=[],
            )

        marriage_count = int(total_couples * marriage_ratio)
        families = []
        child_count = 0
        for index in range(marriage_count):
            left = random.choice(eligible_adults)
            eligible_adults.remove(left)
            right = random.choice(eligible_adults)
            eligible_adults.remove(right)

            desired_min = max(min_children - child_count, 0)
            desired_max = max(max_children - child_count, 0)
            # off by one intentional,
            # include current marriage
            remaining_marriages = marriage_count - index
            family = Family.make(
                left,
                right,
                names=names,
                min_children=desired_min // remaining_marriages,
                max_children=desired_max // remaining_marriages,
            )
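            families.append(family)
            # track the children created so far so that the remaining
            # marriages share what is left of the generation's budget
            child_count += len(family.children)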

        children = [child for family in families for child in family.children]
        return Generation(
            adults=previous.children,
            children=children,
            families=families,
        )

    def marriage_candidates(self) -> list[Person]:
        # These are the marriage candidates for
        # the next generation. The previous
        # generation of unmarried can marry the
        # current, but no further
        married_adults = {
            person for family in self.families for person in family.parents
        }
        unmarried_adults = [
            person for person in self.adults if person not in married_adults
        ]
        eligible_adults = unmarried_adults + self.children
        return eligible_adults

    def __len__(self) -> int:
        return len(self.children)

    @property
    def empty(self) -> bool:
        return not (self.adults or self.children)

    @property
    def people(self) -> list[Person]:
        return self.adults + self.children


@dataclass
class PersonDetails:
    person: Person
    parents: list[Person]
    siblings: list[Person]
    partner: Optional[Person]
    children: list[Person]

    def to_json(self) -> dict[str, str]:
        person = self.person.to_json()
        parents = list(map(Person.to_json, self.parents))
        siblings = list(map(Person.to_json, self.siblings))
        if self.partner is not None:
            partner = self.partner.to_json()
        else:
            partner = None
        children = list(map(Person.to_json, self.children))
        data = {
            "person": person,
            "parents": parents,
            "siblings": siblings,
            "partner": partner,
            "children": children,
        }
        return data


@dataclass
class Genealogy:
    people: list[Person]
    families: list[Family] = field(default_factory=list)

    @staticmethod
    def make(generations: list[Generation]) -> Genealogy:
        people = {person for generation in generations for person in generation.people}
        families = [
            family for generation in generations for family in generation.families
        ]
        return Genealogy(
            people=list(people),
            families=families,
        )

    def find_by_name(self, name: str) -> Optional[PersonDetails]:
        name = name.casefold()
        matching_people = [
            person for person in self.people if person.name.casefold() == name
        ]
        if not matching_people:
            return None
        assert (
            len(matching_people) == 1
        ), f"found multiple people for {name}: {matching_people}"
        person = matching_people[0]
        return self.find(person)

    def find(self, person: Person) -> Optional[PersonDetails]:
        assert person in self.people, f"provided person not found: {person}"

        parental_families = [
            family for family in self.families if person in family.children
        ]
        assert (
            len(parental_families) < 2
        ), f"found multiple parental families for {person}: {parental_families}"
        if not parental_families:
            parents = []
            siblings = []
        else:
            parental_family = parental_families[0]
            parents = parental_family.parents
            siblings = [
                sibling for sibling in parental_family.children if sibling != person
            ]

        marriages = [family for family in self.families if person in family.parents]
        assert len(marriages) < 2, f"found multiple marriages for {person}: {marriages}"
        if not marriages:
            partner = None
            children = []
        else:
            marriage = marriages[0]
            partners = [member for member in marriage.parents if member != person]
            assert len(partners) == 1, f"found polygamy for {person}: {partners}"
            partner = partners[0]
            children = marriage.children

        return PersonDetails(
            person=person,
            parents=parents,
            siblings=siblings,
            partner=partner,
            children=children,
        )


def generate_genealogy(  # pylint: disable=too-many-arguments
    df: pd.DataFrame,
    *,
    starting_population: int = 10,
    population_limit: int = 100,
    generation_limit: int = 10,
    marriage_ratio: float = 0.5,
    min_children: int = 1,
    max_children: int = 10,
    random_seed: int = 42,
) -> Genealogy:
    random.seed(random_seed)
    np.random.seed(random_seed)

    names = _name_iterator(df, random_seed=random_seed)

    starting_people = [
        Person.make(name) for _, name in zip(range(starting_population), names)
    ]
    starting_generation = Generation(
        adults=[],
        families=[],
        children=starting_people,
    )
    generations = [starting_generation]

    for _ in range(generation_limit):
        if sum(map(len, generations)) >= population_limit:
            break
        current_generation = generations[-1]
        if current_generation.empty:
            break
        # TNG
        next_generation = Generation.make(
            previous=current_generation,
            names=names,
            marriage_ratio=marriage_ratio,
            min_children=min_children,
            max_children=max_children,
        )
        generations.append(next_generation)

    genealogy = Genealogy.make(generations)
    return genealogy


def _name_iterator(df: pd.DataFrame, random_seed: int) -> Iterator[str]:
    names = pd.concat(
        [
            df.male,
            df.female,
        ]
    )
    names = names.sample(
        n=len(names),
        random_state=random_seed,
    )
    names = names.tolist()
    return iter(names)


def generate_simple_questions(genealogy: Genealogy) -> pd.DataFrame:
    # the questions take the following form:
    # what is the gender of X
    # who are the father(s)/mother(s)/parents of X
    # is X married
    # who is X married to
    # how many sons/daughters/children does X have
    # who are the sons/daughters/children of X
    def filter_by_gender(people: list[Person], gender: Gender) -> list[str]:
        return [person.name for person in people if person.gender == gender]

    questions = []
    for person in genealogy.people:
        details = genealogy.find(person)
        details_json = details.to_json()
        name = person.name

        questions.append(
            {
                "context": details_json,
                "question": f"What is the gender of {name}?",
                "answer": person.gender.value,
            }
        )

        questions.append(
            {
                "context": details_json,
                "question": f"Who are the father(s) of {name}?",
                "answer": filter_by_gender(details.parents, Gender.MALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the mother(s) of {name}?",
                "answer": filter_by_gender(details.parents, Gender.FEMALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the parents of {name}?",
                "answer": [parent.name for parent in details.parents],
            }
        )

        questions.append(
            {
                "context": details_json,
                "question": f"Is {name} married?",
                "answer": bool(details.partner),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who is {name} married to?",
                "answer": (
                    details.partner.name if details.partner is not None else None
                ),
            }
        )

        questions.append(
            {
                "context": details_json,
                "question": f"How many sons does {name} have?",
                "answer": len(filter_by_gender(details.children, Gender.MALE)),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"How many daughters does {name} have?",
                "answer": len(filter_by_gender(details.children, Gender.FEMALE)),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"How many children does {name} have?",
                "answer": len(details.children),
            }
        )

        questions.append(
            {
                "context": details_json,
                "question": f"Who are the son(s) of {name}?",
                "answer": filter_by_gender(details.children, Gender.MALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the daughter(s) of {name}?",
                "answer": filter_by_gender(details.children, Gender.FEMALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the children of {name}?",
                "answer": [child.name for child in details.children],
            }
        )

    return pd.DataFrame(questions)
Code
import graphviz

def render_geneoalgy(genealogy: Genealogy) -> graphviz.Source:
    # each person becomes a node labelled with their name and gender
    person_identifiers = {
        person: f"person_{index}"
        for index, person in enumerate(genealogy.people)
    }
    person_definitions = [
        f'{identifier} [label="{person.name} ({person.gender.value})"];'
        for person, identifier in person_identifiers.items()
    ]
    people = "\n".join(person_definitions)

    # each family (marriage) becomes a box-shaped node
    family_identifiers = {
        family: f"family_{index}"
        for index, family in enumerate(genealogy.families)
    }
    family_definitions = [
        f'{identifier} [label="married"];'
        for identifier in family_identifiers.values()
    ]
    families = "\n".join(family_definitions)

    # edges run from each parent into the family, then from the family to each child
    marriage_edges = [
        f"{person_identifiers[person]} -> {family_identifiers[family]};"
        for family in genealogy.families
        for person in family.parents
    ]
    child_edges = [
        f"{family_identifiers[family]} -> {person_identifiers[person]};"
        for family in genealogy.families
        for person in family.children
    ]
    edges = "\n".join(marriage_edges + child_edges)

    graph_definition = f"""
digraph G {{
{people}
node [shape=box];
{families}
{edges}
}}
""".strip()
    # graphviz.Source wraps raw DOT text; it displays inline in a notebook
    return graphviz.Source(graph_definition)
Code
family_tree = generate_genealogy(
    names_df,
    starting_population=4,
    population_limit=10,
    generation_limit=3,
    marriage_ratio=0.75,
    min_children=1,
    max_children=4,
    random_seed=42,
)
render_geneoalgy(family_tree)

Code
questions_df = generate_simple_questions(family_tree)
questions_df[["question", "answer"]]
    question                              answer
0   What is the gender of Jane?           male
1   Who are the father(s) of Jane?        [Elianna]
2   Who are the mother(s) of Jane?        []
3   Who are the parents of Jane?          [Elianna, Robin]
4   Is Jane married?                      False
... ...                                   ...
79  How many daughters does Robin have?   0
80  How many children does Robin have?    2
81  Who are the son(s) of Robin?          [Jane]
82  Who are the daughter(s) of Robin?     []
83  Who are the children of Robin?        [Jane, Tatum]

84 rows × 2 columns

That gives 84 questions about this family tree: 12 for each of its 7 people. Each question comes with its expected answer and the context that we would provide to the model, which describes the subject of the question and their immediate family.

We could evaluate the model by asking each of these questions in turn and checking whether each response is correct. The resulting accuracy would be a reasonable metric for gauging the quality of the prompt.
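As a sketch of that metric, assuming the model replies with a JSON value, a hypothetical score helper could parse each raw reply and count the matches. One design choice differs from the evaluation later in this post, which compares answers as-is: here lists of names are sorted before comparing, so a correct answer given in a different order is not marked wrong. The model outputs below are made up.

```python
import json
from typing import Any


def normalise(answer: Any) -> Any:
    # Treat lists of names as unordered: sort them before comparing,
    # so ["Lee", "Kim"] matches ["Kim", "Lee"].
    if isinstance(answer, list):
        return sorted(answer)
    return answer


def score(raw_answers: list[str], expected: list[Any]) -> float:
    """Fraction of raw model outputs that parse as JSON and match the expected answer."""
    correct = 0
    for raw, truth in zip(raw_answers, expected):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue  # output that is not valid JSON counts as wrong
        if normalise(parsed) == normalise(truth):
            correct += 1
    return correct / len(expected)


# Made-up model outputs: the last one is not valid JSON (the quotes are missing).
raw_answers = ['"male"', '["Lee", "Kim"]', "Hector"]
expected = ["male", ["Kim", "Lee"], "Hector"]
print(round(score(raw_answers, expected), 3))  # prints 0.667
```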

There is a problem, though: some of the questions have answers that require specific phrasing, such as these:

Code
questions_df[
    questions_df.question.isin({
        "Is Jane married?",
        "Who is Jane married to?",
    })
][["question", "answer"]]
   question                  answer
4  Is Jane married?          False
5  Who is Jane married to?   None

Here the question "Who is Jane married to?" cannot be meaningfully answered, because Jane is not married. The expected answer, converted to JSON, is null, but is it really so wrong for the model to respond that Jane is not married?

What the model needs is some example answers, so that it can see how questions like these should be answered. Adding them takes the prompt from zero-shot (no examples given) to few-shot (a handful of worked examples). Choosing the right examples to provide is an important part of prompting.

This also gives us an opportunity to move to a more efficient means of querying the model. Instead of using the chat interface that we have seen so far, we can use the ability of the language model to predict the next word and get it to continue our prompt.

The prompt with the examples is followed by the question that we want to ask, and the model then generates an answer in the same style as the examples. This neat technique informs the model about the expected output format in a more concrete way, and it also cuts down on invocation time. The following prompt should give you an idea of how this works:

You are a math student taking a test.
Answer every question correctly and concisely.

Question:
1 + 1
Answer:
2
Question:
6 * 9
Answer:
54
Question:
54 * 45
Answer:

The model would then complete this with an answer and go on to generate more questions of its own. If we watch for the "Question:" marker during generation, we know that the model has answered the question that we asked, and we can stop.
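At the level of plain strings, this stopping rule amounts to cutting the continuation at the first stop sequence. The sketch below is a hypothetical helper that works on text that has already been generated; the implementation used in this post instead checks the token ids during generation, so the model stops early rather than spending tokens on questions that will be thrown away.

```python
def truncate_at_stop(generated: str, stop: str) -> str:
    """Keep only the text before the first occurrence of stop.

    If the stop sequence never appears (for example because generation
    hit its token limit first), the whole generation is returned.
    """
    index = generated.find(stop)
    if index != -1:
        generated = generated[:index]
    return generated.strip()


# A made-up continuation: the model answers, then starts a new question.
raw = "2430\nQuestion:\n7 * 8\nAnswer:\n56"
print(truncate_at_stop(raw, "\nQuestion:"))  # prints 2430
```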

It might be easier to understand if we see this in action:

Code
# from src/main/python/blog/llm/continuation.py
from __future__ import annotations

import re
from typing import Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


@torch.inference_mode()
def generate_continuation(
    *,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    stopping: Optional[str | list[str]] = None,
    max_new_tokens: int = 100,
    additional_stopping: Optional[list[StoppingCriteria]] = None,
    do_sample: bool = False,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> str:
    stopping_criteria = []
    token_sequence_stopping = []
    if stopping:
        if not isinstance(stopping, list):
            stopping = [stopping]
        token_sequence_stopping = [
            TokenSequenceStoppingCriteria.make(
                tokenizer=tokenizer,
                sequence=sequence,
                device=model.device,
            )
            for sequence in stopping
        ]
        stopping_criteria.extend(token_sequence_stopping)
    if additional_stopping:
        stopping_criteria.extend(additional_stopping)

    if stopping_criteria:
        stopping_criteria_object = StoppingCriteriaList(stopping_criteria)
    else:
        stopping_criteria_object = None

    model_input = tokenizer(
        prompt,
        return_tensors="pt",
        padding="longest",
    )
    input_tokens = model_input.input_ids.shape[1]
    model_input = model_input.to(model.device)
    generated_ids = model.generate(
        **model_input,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        stopping_criteria=stopping_criteria_object,
    )
    filtered_ids = generated_ids[0, input_tokens:]
    if stopping:
        for criteria in token_sequence_stopping:
            if not criteria.is_end(filtered_ids):
                continue
            filtered_ids = criteria.truncate(filtered_ids)
            break
    output = tokenizer.decode(
        filtered_ids,
        skip_special_tokens=True,
    )
    return output.strip()


class TokenSequenceStoppingCriteria(StoppingCriteria):
    @staticmethod
    def make(
        tokenizer: AutoTokenizer,
        sequence: str,
        device: Optional[str | torch.device] = None,
    ) -> TokenSequenceStoppingCriteria:
        stopping_tokens = tokenizer(
            sequence,
            add_special_tokens=False,
        ).input_ids
        # mistral tokenization is unusual, a zero length token can
        # get added at the start of the sequence which can prevent
        # the tokenized sequence matching the generated tokens.
        # this filter drops any zero length tokens.
        stopping_tokens = [
            token for token in stopping_tokens if len(tokenizer.decode(token)) > 0
        ]
        return TokenSequenceStoppingCriteria(stopping_tokens, device=device)

    def __init__(
        self,
        sequence: list[int] | torch.Tensor,
        device: Optional[str | torch.device] = None,
    ) -> None:
        super().__init__()
        if isinstance(sequence, list):
            # token ids are integers; torch.Tensor(...) would build a float tensor
            sequence = torch.tensor(sequence, dtype=torch.long)
        if device is not None:
            sequence = sequence.to(device)
        self.sequence = sequence

    def to(self, device: str | torch.device) -> TokenSequenceStoppingCriteria:
        self.sequence = self.sequence.to(device)
        return self

    def as_list(self) -> StoppingCriteriaList:
        return StoppingCriteriaList([self])

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        # this assumes only a single sequence is being generated
        return self.is_end(input_ids[0])

    def is_end(self, tokens: torch.Tensor) -> bool:
        assert len(tokens.shape) == 1
        if len(tokens) < len(self.sequence):
            return False
        end = tokens[-len(self.sequence) :]
        per_token_matches = end == self.sequence
        return bool(per_token_matches.all())

    def truncate(self, tokens: torch.Tensor) -> torch.Tensor:
        if self.is_end(tokens):
            return tokens[: -len(self.sequence)]
        return tokens


class RepeatedStringStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer: AutoTokenizer, start_length: int) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.start_length = start_length
        self.pattern = re.compile(r'"([^"]*)"')

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        generated_ids = input_ids[0, self.start_length :]
        generated_text = self.tokenizer.decode(generated_ids)
        generated_strings = self.pattern.findall(generated_text)
        return len(set(generated_strings)) != len(generated_strings)
generate_continuation(
    model=model,
    tokenizer=tokenizer,
    prompt="""
You are a math student taking a test.
Answer every question correctly and concisely.

Question:
1 + 1
Answer:
2
Question:
6 * 9
Answer:
54
Question:
54 * 45
Answer:
""".lstrip(),
    stopping="\nQuestion:",
    max_new_tokens=100,
)
'2370'

The model has generated 2,370 as the answer to 54 * 45. Unfortunately this is incorrect: the correct answer is 2,430. Accuracy aside, this shows the new mode of model invocation in action.
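Splitting 45 into 40 + 5 confirms the correct value:

\[
54 \times 45 = 54 \times 40 + 54 \times 5 = 2160 + 270 = 2430
\]

so the model's answer is 60 too low.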

Let’s generate a separate family tree for the examples, so that they feature people who do not appear in the tree under test.

Code
used_names = {
    person.name
    for person in family_tree.people
}
filtered_names_df = names_df[
    ~(
        names_df.male.isin(used_names) |
        names_df.female.isin(used_names)
    )
]
example_tree = generate_genealogy(
    filtered_names_df,
    starting_population=4,
    population_limit=10,
    generation_limit=3,
    marriage_ratio=0.75,
    min_children=1,
    max_children=4,
    random_seed=12345,
)
example_questions = generate_simple_questions(example_tree)
render_geneoalgy(example_tree)

This is a good tree for the purpose, as it contains several useful edge cases to draw examples from. Let’s select them and then create the overall prompt.

Code
import json
import pandas as pd

def create_prompt(
    introduction: str,
    examples: str,
    question: str,
    context: dict,
) -> str:
    context_str = json.dumps(context)
    question_str = f"""
Context:
{context_str}
Question:
{question}
Answer:
    """.strip()

    return f"""
{introduction}

{examples}
{question_str}
""".lstrip()

def stringify_examples(df: pd.DataFrame) -> str:
    def format_example(row: pd.Series) -> str:
        context = json.dumps(row.context)
        question = row.question
        answer = json.dumps(row.answer)
        example = f"""
Context:
{context}
Question:
{question}
Answer:
{answer}
        """.strip()
        return example

    # iterrows yields (index, row) pairs, one row Series per example
    prompt_examples = [format_example(row) for _, row in df.iterrows()]
    return "\n\n".join(prompt_examples)


few_shot_examples = example_questions[
    example_questions.question.isin({
        "Who is Joyce married to?",
        "Who are the children of Cleo?",
        "How many sons does Niklaus have?",
    })
]
examples_str = stringify_examples(few_shot_examples)
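The three examples above were picked by hand to show a null answer, a list and a count. A hypothetical helper could make that choice more systematic by keeping one example per JSON answer type. The names answer_kind and pick_one_per_kind are illustrative and not part of the code above, and the candidates are made-up, simplified rows without the context.

```python
from typing import Any


def answer_kind(answer: Any) -> str:
    """Name the JSON type that json.dumps would produce for this answer."""
    if answer is None:
        return "null"
    if isinstance(answer, bool):  # must come before int: bool subclasses int
        return "boolean"
    if isinstance(answer, (int, float)):
        return "number"
    if isinstance(answer, str):
        return "string"
    if isinstance(answer, list):
        return "array"
    return "other"


def pick_one_per_kind(candidates: list[dict]) -> list[dict]:
    """Keep the first candidate seen for each kind of answer."""
    chosen: dict[str, dict] = {}
    for candidate in candidates:
        chosen.setdefault(answer_kind(candidate["answer"]), candidate)
    return list(chosen.values())


# Made-up candidate examples.
candidates = [
    {"question": "Who is Alex married to?", "answer": None},
    {"question": "Is Alex married?", "answer": False},
    {"question": "Who are the children of Sam?", "answer": ["Kim", "Lee"]},
    {"question": "How many sons does Sam have?", "answer": 1},
    {"question": "How many daughters does Sam have?", "answer": 1},
]
for example in pick_one_per_kind(candidates):
    print(example["question"])
```

Only the first count question survives, since both counts serialise to a JSON number.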

A quick test of the model and prompt will let us know if everything is working:

prompt = f"""
You are a genealogy expert.
I have questions about familial relationships.
I will provide you with some context and ask you a question.
Answer using only the data that I provide
""".strip()

row = questions_df.iloc[0]

print(f"question is: {row.question}")
print(f"correct answer is: {row.answer}")
model_answer = generate_continuation(
    model=model,
    tokenizer=tokenizer,
    prompt=create_prompt(
        introduction=prompt,
        examples=examples_str,
        question=row.question,
        context=row.context,
    ),
    stopping="\nContext:",
    max_new_tokens=100,
)
print(f"model answer is: {json.loads(model_answer)}")
question is: What is the gender of Jane?
correct answer is: male
model answer is: male

Testing this on a single question has worked well. How well does the model do when it is run against all 84 questions?

Code
import pandas as pd
from tqdm.auto import tqdm
import json
import numpy as np

tqdm.pandas()

def get_answer(row: pd.Series) -> str:
    return generate_continuation(
        model=model,
        tokenizer=tokenizer,
        prompt=create_prompt(
            introduction=prompt,
            examples=examples_str,
            question=row.question,
            context=row.context,
        ),
        stopping="\nContext:",
        max_new_tokens=100,
    )

questions_df["raw_answer"] = questions_df.progress_apply(get_answer, axis="columns")

def _parse_answer(answer: str):
    try:
        return json.loads(answer.strip())
    except json.JSONDecodeError:
        print(f"unable to parse {answer}")
        return np.nan  # NaN compares unequal to everything, so this counts as wrong

questions_df["model_answer"] = questions_df.raw_answer.apply(_parse_answer)
correct_answers = (questions_df.model_answer == questions_df.answer).sum()
accuracy = correct_answers / len(questions_df)
print(f"the model is: {accuracy*100:0.2f}% accurate")
unable to parse Hector
unable to parse male
unable to parse Elianna
the model is: 67.86% accurate

The very first prompt that I tried is accurate about two thirds of the time (57 of the 84 answers were correct). The model also produced invalid JSON three times, answering with a bare value such as Hector rather than the quoted JSON string "Hector". This is not good performance.

A systematic evaluation of the prompt like this lets me inspect the patterns in the mistakes and then correct them. In this way I could build a better prompt.
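One way to look for those patterns is to group the results by question template, so that, for example, every "Who is X married to?" question is scored together. This is a hypothetical sketch: it assumes the subject's name is kept alongside each question (generate_simple_questions knows it as name but does not store it in the DataFrame), and the correctness flags below are made up.

```python
import re
from collections import defaultdict


def template_of(question: str, name: str) -> str:
    # Replace the subject's name with X so that the same question about
    # different people falls into one group.
    return re.sub(rf"\b{re.escape(name)}\b", "X", question)


def accuracy_by_template(rows: list[tuple[str, str, bool]]) -> dict[str, float]:
    """Each row is (question, subject name, whether the model was correct)."""
    totals: dict[str, int] = defaultdict(int)
    correct: dict[str, int] = defaultdict(int)
    for question, name, is_correct in rows:
        template = template_of(question, name)
        totals[template] += 1
        correct[template] += int(is_correct)
    return {template: correct[template] / totals[template] for template in totals}


# Made-up results for illustration only.
rows = [
    ("Who is Alex married to?", "Alex", False),
    ("Who is Sam married to?", "Sam", True),
    ("What is the gender of Alex?", "Alex", True),
]
for template, accuracy in sorted(accuracy_by_template(rows).items()):
    print(f"{accuracy:.0%}  {template}")
```

A template with low accuracy points at the kind of question, rather than the individual, that the prompt handles badly.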

Taking this further

The test so far has relied on providing the appropriate context, taken from a larger graph. Identifying this context in the graph is a large part of answering the question.

If we were to create a fully automated system, how would it handle this? A family tree can easily be large enough to overwhelm the large language model with irrelevant data and slow it down.

The natural answer is to use the retrieval part of RAG to select the most relevant part of the graph. This is a good start.
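For questions like the ones in this post, a first retrieval step could be as simple as finding the names mentioned in the question and returning each named person's immediate family. The sketch below assumes a toy adjacency mapping with made-up names; retrieve_context is a hypothetical helper, and Genealogy.find_by_name in the code above is the closest existing equivalent of its lookup.

```python
import re


def retrieve_context(
    question: str,
    relatives: dict[str, dict[str, list[str]]],
) -> dict[str, dict[str, list[str]]]:
    """Return the immediate family of every known person named in the question."""
    return {
        name: family
        for name, family in relatives.items()
        # whole-word match, so a short name cannot match inside a longer one
        if re.search(rf"\b{re.escape(name)}\b", question)
    }


# A toy graph: each person maps to their one-hop neighbourhood.
relatives = {
    "Alex": {"partner": ["Sam"], "parents": [], "children": ["Kim"]},
    "Sam": {"partner": ["Alex"], "parents": [], "children": ["Kim"]},
    "Kim": {"partner": [], "parents": ["Alex", "Sam"], "children": []},
}
print(retrieve_context("Who are the children of Alex?", relatives))
```

Exact name matching is the weak point: real family trees repeat names, which is why find_by_name asserts that a name is unique.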

RAG to Agents

If we change the question to ask about more distant relationships, such as how many grandchildren someone has, then the details of any single person no longer contain the information the model needs. At this point the solution moves from pure RAG to multi-hop question answering. Here the large language model acts as an agent: it can invoke external tools (like selecting the details for an individual) and then integrate the results into an answer. For the grandchildren question, the model could first look up the grandparent and then query the details of each of their children.

This post is already very long, so instead of implementing that here I will point to the Haystack tutorial on it.

Model Choice

Finally, a note on the choice of model. In this post I have used Mistral 7B Instruct, quantized to 8 bits (quantization stores the weights at reduced precision to make the model cheaper to run, at some cost to accuracy). It struggled with some of the tasks, and the final accuracy clearly needs work.
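To make the nbits=8, group_size=64 settings in the HqqConfig at the top of the post more concrete, here is a minimal pure-Python sketch of group-wise affine (min-max) quantization: each group of 64 weights is stored as 8-bit integer codes plus a per-group scale and offset. This is a simpler technique than HQQ, which additionally optimizes the quantization parameters to reduce the error; it only shows where the accuracy cost comes from, since every weight is snapped to one of 256 levels.

```python
import random


def quantize_group(values: list[float], nbits: int = 8) -> tuple[list[int], float, float]:
    """Affine (min-max) quantization of one group to integer codes in [0, 2**nbits - 1]."""
    levels = 2**nbits - 1
    low, high = min(values), max(values)
    scale = (high - low) / levels or 1.0  # a constant group would give scale 0
    codes = [round((value - low) / scale) for value in values]
    return codes, scale, low


def dequantize_group(codes: list[int], scale: float, low: float) -> list[float]:
    return [code * scale + low for code in codes]


def round_trip(weights: list[float], group_size: int = 64, nbits: int = 8) -> list[float]:
    """Quantize each group of group_size weights separately, then restore them."""
    restored: list[float] = []
    for start in range(0, len(weights), group_size):
        codes, scale, low = quantize_group(weights[start : start + group_size], nbits)
        restored.extend(dequantize_group(codes, scale, low))
    return restored


random.seed(0)
weights = [random.uniform(-1.0, 1.0) for _ in range(256)]
restored = round_trip(weights)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(f"largest rounding error: {max_error:.5f}")
```

The error is at most half a quantization step per group, which is why smaller groups (a tighter min-max range) generally lose less accuracy.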

Would I recommend this model for all RAG use cases? No.

When demonstrating how to work with and improve a RAG system, it is good to have something that can be improved. I chose this model because I can run it quickly, easily and locally. It shows some flaws, but it is a capable model.

For a lot of use cases, running a small model like this is more expensive and slower than using ChatGPT, and ChatGPT or another commercial API would comfortably outperform the Mistral model that I selected.

I should note that prompts do not reliably transfer between models: prompt tuning must be done on the target model.

Final Words

At this point you should know what RAG is, how you can implement it, and how you can evaluate it. Hopefully this has been helpful!