How to perform Retrieval Augmented Generation over graph data structures?
Published
August 1, 2024
I’ve been testing out various kinds of retrieval augmented generation (RAG) recently and I wanted to try applying it to a graph data structure. The aim will be to provide the graph to the language model and then ask it questions about the graph.
If this works then a natural extension will be to test it on much larger graphs where part of the problem becomes identifying what to include. It is also worth considering how to evaluate the accuracy of the process.
Model and Interface
To start with I need a model and some code to make interacting with it easy. Reusing the code from previous posts gives me a nice way to render the interactions.
It’s good to check that this all works. I can do this by asking the model for a weather forecast:
Code
generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat="What is the weather like today?",
)
User
What is the weather like today?
Assistant
I’m an AI and I don’t have real-time capabilities or the ability to check the weather. I recommend you to check a reliable weather website or app for the current weather conditions.
The code works, the language model produces a coherent response, and it doesn’t even hallucinate (yet). A good start.
Breakfast Graph
Cooking is one activity that naturally forms a graph. There are different parts to cooking, such as preparation, heating and combining. These parts have an order, and they should be arranged so that the final components are ready at the same time.
We can start with a very simple cooking graph; that of a breakfast consisting of coffee, buttered toast and a soft boiled egg.
Breakfast Dataset
The first thing to do will be to represent this in code and provide a way to visualize it. Then we can think how to efficiently present this information to the large language model.
I have chosen a very simple representation of the graph: the nodes are the ingredient or output names and the edges carry the cooking times. What I need is a way to construct the breakfast cooking graph and then render it. Let’s have a look at this breakfast:
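The helper code that builds the graph is not reproduced in this post, so here is a minimal sketch of one possible representation. The Step dataclass and the readable_breakfast_graph helper are illustrative assumptions, reconstructed from the readable output shown further down, rather than the exact code used.

from dataclasses import dataclass


@dataclass(frozen=True)
class Step:
    # a step consumes some inputs, produces a single output,
    # and takes a number of minutes
    inputs: tuple[str, ...]
    output: str
    duration: float


# the breakfast graph: each edge is a cooking step,
# the nodes are ingredient or output names
breakfast_graph = [
    Step(("buttered toast", "coffee", "soft boiled egg"), "breakfast", 0.0),
    Step(("toast",), "buttered toast", 0.5),
    Step(("beans and water",), "coffee", 1.0),
    Step(("raw egg",), "soft boiled egg", 3.0),
    Step(("bread",), "toast", 1.0),
]


def _join(names: tuple[str, ...]) -> str:
    # "toast", "beans and water", "buttered toast, coffee and soft boiled egg"
    if len(names) == 1:
        return names[0]
    return ", ".join(names[:-1]) + " and " + names[-1]


def readable_breakfast_graph(graph: list[Step]) -> str:
    # one human readable bullet point per step
    return "\n".join(
        f"* making {step.output} uses {_join(step.inputs)}"
        f" and takes {step.duration} minutes"
        for step in sorted(graph, key=lambda step: step.output)
    )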
This shows the simple process of making a breakfast. Visualizing the graph can help us understand the correct order to do things. If we were to arrange this into a recipe then it would resemble:
boil an egg
after \(1 \frac{1}{2}\) minutes put the toast in the toaster
after \(\frac{1}{2}\) a minute put the coffee on
after \(\frac{1}{2}\) a minute butter the toast
after \(\frac{1}{2}\) a minute the breakfast is ready
The graph needs to be provided to the language model and it should produce a sequence much like the above. If it can do this then that would be one demonstration of graph understanding.
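For reference, those timings can be derived mechanically from the graph by working backwards from the finished breakfast. The following is a small sketch, reusing the illustrative Step representation from above (so it is an assumption about the data layout, not the post’s actual code), that computes “just in time” start times:

def start_times(graph: list[Step]) -> dict[str, float]:
    # "just in time" start times so that every chain finishes at the same moment;
    # assumes each intermediate product is used by a single later step (true here)
    by_output = {step.output: step for step in graph}

    def total_time(name: str) -> float:
        # time needed to produce `name` from raw ingredients
        step = by_output.get(name)
        if step is None:  # raw ingredients are already available
            return 0.0
        return step.duration + max(total_time(inp) for inp in step.inputs)

    finish_by = {"breakfast": total_time("breakfast")}
    starts: dict[str, float] = {}
    # walk the graph from the final dish back towards the raw ingredients
    pending = ["breakfast"]
    while pending:
        name = pending.pop()
        step = by_output[name]
        starts[name] = finish_by[name] - step.duration
        for inp in step.inputs:
            if inp in by_output:  # inputs must be ready when this step starts
                finish_by[inp] = starts[name]
                pending.append(inp)
    return starts


print(start_times(breakfast_graph))
# {'breakfast': 3.0, 'soft boiled egg': 0.0, 'coffee': 2.0,
#  'buttered toast': 2.5, 'toast': 1.5}

Every chain then finishes at the 3 minute mark, matching the recipe above.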
Breakfast Representation
The visual representation of the graph is convenient for us, but it is not the best way to provide the data to the model.
There are several different formats that we can use to provide the data. We could render the graph in a textual form (e.g. turning bread into toast takes 1 minute), or as a json blob of data (e.g. { “source”: “bread”, “destination”: “toast”, “duration”: 1.0 }). The desired output format is also a choice: do we expect the output to be consumed by another machine, or do we want it to be understood by a human?
This is how we can present the data in a readable form:
Code
print(readable_breakfast_graph(breakfast_graph))
* making breakfast uses buttered toast, coffee and soft boiled egg and takes 0.0 minutes
* making buttered toast uses toast and takes 0.5 minutes
* making coffee uses beans and water and takes 1.0 minutes
* making soft boiled egg uses raw egg and takes 3.0 minutes
* making toast uses bread and takes 1.0 minutes
These are some simple representations of the graph. It is important to note that the breakfast cannot be created by following these steps as written; one problem is that the buttered toast step comes before the toast is made.
Breakfast Recipe
What we want is a model which can take this graph and perform a topological sort over it, turning it into a list of steps: a recipe.
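For reference, such an ordering can be computed directly: Python’s standard library ships graphlib.TopologicalSorter. A minimal sketch, again using the illustrative Step representation from earlier (an assumption about the data layout, not the post’s actual code):

from graphlib import TopologicalSorter


def recipe_order(graph: list[Step]) -> list[str]:
    # dependencies maps each output to the intermediate products it needs;
    # raw ingredients (bread, beans, ...) are not steps so they are dropped
    outputs = {step.output for step in graph}
    dependencies = {
        step.output: {inp for inp in step.inputs if inp in outputs}
        for step in graph
    }
    # TopologicalSorter yields the steps in an order that respects dependencies
    return list(TopologicalSorter(dependencies).static_order())


print(recipe_order(breakfast_graph))
# e.g. ['coffee', 'soft boiled egg', 'toast', 'buttered toast', 'breakfast']

Any ordering the model produces can then be compared against an ordering like this, or simply checked for validity.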
This will be a chance to try out the different data formats as well as doing some prompt engineering. We can start with the readable format with a simple prompt:
Code
steps = readable_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

STEPS:
{steps}
""".strip()

chat_1 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_1
User
I am hungry and I want to make breakfast as fast as possible. I have been given different steps that cook various ingredients. Organize these into a recipe that will make breakfast quickly. I can cook multiple things at the same time.
STEPS: * making breakfast uses buttered toast, coffee and soft boiled egg and takes 0.0 minutes * making buttered toast uses toast and takes 0.5 minutes * making coffee uses beans and water and takes 1.0 minutes * making soft boiled egg uses raw egg and takes 3.0 minutes * making toast uses bread and takes 1.0 minutes
Assistant
To make a quick breakfast, follow these steps:
Prepare the coffee:
Place the coffee beans and water in a coffee maker.
Start the coffee maker (takes 1.0 minute)
Toast the bread:
Place the bread slices in the toaster.
Toast the bread (takes 1.0 minute)
Butter the toast:
Spread butter on the toasted bread slices (takes 0.0 minute, as it can be done while the coffee is brewing and the egg is cooking)
Cook the soft-boiled egg:
Fill a small saucepan with water and bring it to a boil.
Reduce the heat to a simmer, then gently lower the raw egg into the water.
Cook the egg for 3.0 minutes.
Remove the egg from the water and place it in a bowl of ice water to stop the cooking process.
Peel the egg (takes 0.0 minute, as it can be done while the coffee is brewing)
Assemble the breakfast:
Place the buttered toast on a plate.
Carefully place the soft-boiled egg on the toast.
Serve the coffee on the side.
Enjoy your quick and delicious breakfast!
This has produced an ordering of the actions, however the model has felt the need to elaborate extensively on the steps. Generating so much text is slow and we really just want the steps.
We can improve this by adding some simple additional instructions to the prompt. When performing prompt engineering like this it is very important to keep track of what you have tried; to help with this, for the rest of the post I am showing prompt changes as a contextual diff.
Code
steps = readable_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

STEPS:
{steps}

Only list the steps that I need to perform.
""".strip()

chat_2 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_2.compare(chat_1)
***
---
***************
*** 10,12 ****
--- 10,14 ----
  * making coffee uses beans and water and takes 1.0 minutes
  * making soft boiled egg uses raw egg and takes 3.0 minutes
  * making toast uses bread and takes 1.0 minutes
+ 
+ Only list the steps that I need to perform.

previous length: 1,046 characters
current length: 360 characters
Assistant
Toast the bread for 1.0 minute to make toast.
Prepare the coffee by brewing the beans for 1.0 minute.
Boil the raw egg for 3.0 minutes to make a soft boiled egg.
Butter the toast that was made in step 1.
Assemble the breakfast by placing the soft boiled egg and buttered toast on a plate, and serve with the coffee made in step 2.
The addition of a single line to the prompt drastically reduced the response size. Prompting well takes care and, as you can see, small changes can produce completely different results.
I would say that this prompt made the model write a complete recipe.
This is fine for human consumption, however a machine needs a more regular output to work with. We could use the json format for the graph and ask the model to produce json output.
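The json_breakfast_graph helper is not shown in the post either; a minimal sketch that would produce the json structure seen in the prompt below (again assuming the illustrative Step representation from earlier) could be:

import json


def json_breakfast_graph(graph: list[Step]) -> str:
    # one object per step, with an id that the model can refer back to
    steps = [
        {
            "id": index,
            "inputs": list(step.inputs),
            "output": step.output,
            "duration": step.duration,
        }
        for index, step in enumerate(sorted(graph, key=lambda step: step.output))
    ]
    return json.dumps(steps)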
Code
steps = json_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given json data containing different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

{steps}

Produce a json array containing the ordered step IDs.
Only list the steps that I need to perform.
""".strip()

chat_3 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_3
User
I am hungry and I want to make breakfast as fast as possible. I have been given json data containing different steps that cook various ingredients. Organize these into a recipe that will make breakfast quickly. I can cook multiple things at the same time.
Produce a json array containing the ordered step IDs. Only list the steps that I need to perform.
Assistant
Based on the given JSON data and the requirement to make breakfast quickly, the following steps can be performed simultaneously:
Toast the bread (Step 1)
Prepare the soft-boiled egg (Step 3)
Brew the coffee (Step 2)
Since these steps can be performed concurrently, we can buttered the toast (Step 4) while waiting for the other steps to complete.
Here’s the ordered JSON array of the step IDs:
[1, 3, 2, 4]
This means: 1. Toast the bread (Step 1) 2. Prepare the soft-boiled egg (Step 3) 3. Brew the coffee (Step 2) 4. Butter the toast (Step 4)
There is some json data hidden in the middle of that response. Once again the model has been overly descriptive. Let’s try to cut that down.
Code
steps = json_breakfast_graph(breakfast_graph)

prompt = f"""
I am hungry and I want to make breakfast as fast as possible.
I have been given json data containing different steps that cook various ingredients.
Organize these into a recipe that will make breakfast quickly.
I can cook multiple things at the same time.

{steps}

Produce a json array containing the ordered step IDs.
Do not add any additional details or commentary.
""".strip()

chat_4 = generate_chat(
    model=model,
    tokenizer=tokenizer,
    chat=prompt,
    max_new_tokens=1_000,
)
chat_4.compare(chat_3)
***
---
***************
*** 7,10 ****
  [{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]
  
  Produce a json array containing the ordered step IDs.
! Only list the steps that I need to perform.
--- 7,10 ----
  [{"id": 0, "inputs": ["buttered toast", "coffee", "soft boiled egg"], "output": "breakfast", "duration": 0.0}, {"id": 1, "inputs": ["toast"], "output": "buttered toast", "duration": 0.5}, {"id": 2, "inputs": ["beans and water"], "output": "coffee", "duration": 1.0}, {"id": 3, "inputs": ["raw egg"], "output": "soft boiled egg", "duration": 3.0}, {"id": 4, "inputs": ["bread"], "output": "toast", "duration": 1.0}]
  
  Produce a json array containing the ordered step IDs.
! Do not add any additional details or commentary.

previous length: 572 characters
current length: 30 characters
Assistant
[1, 4, 2, 3, 0]
This is the fastest response yet and it is machine readable. The proposed order is unfortunately incorrect. Mapping the IDs back to the steps, the order as listed would be: butter the toast, toast the bread, make the coffee, boil the egg, then assemble the breakfast.
In this step order the toast is buttered before the bread is toasted. This is a fixable problem.
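Since the output is machine readable we can also check a proposed ordering programmatically rather than by eye. Here is a minimal sketch of such a check, assuming the id numbering used in the json representation above:

def is_valid_order(graph: list[Step], step_ids: list[int]) -> bool:
    # a proposed ordering is valid when every step's intermediate inputs
    # have already been produced by an earlier step
    steps = sorted(graph, key=lambda step: step.output)  # matches the json ids
    outputs = {step.output for step in graph}
    produced: set[str] = set()
    for step_id in step_ids:
        step = steps[step_id]
        needed = {inp for inp in step.inputs if inp in outputs}
        if not needed <= produced:
            return False
        produced.add(step.output)
    return True


print(is_valid_order(breakfast_graph, [1, 4, 2, 3, 0]))  # False: butter before toast
print(is_valid_order(breakfast_graph, [4, 1, 2, 3, 0]))  # True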
Instead of spending more time fixing this let us think about how to test the model. After all this is a single recipe, we might fix the problems with it just to find that other recipes are hopelessly broken.
How to Test a Prompt
Since we have a prompt and model that can generate machine readable output, we can feed it known problems and test the accuracy of the solution. This can also validate how frequently the tool produces machine readable output, as the response is currently unconstrained.
Coming up with a recipe dataset that is large enough to perform these tests is difficult. There are books full of recipes that could be used, however these have not separated the steps into the inputs and outputs in the clear fashion of the breakfast graph. As such I have decided to work with a different dataset.
Test Dataset
To test a prompt we need a dataset. For the purposes of this blog I can use any dataset, so I want one where I can ask questions and know what the correct answer is.
A family tree can be generated randomly if we have a list of names. We can name individuals, marry them, have them produce children, and so on.
Given such a graph it is then possible to generate questions about it (e.g. how many children does Jane have?) which can have known answers.
We can start with a list of baby names and use them to generate our fictitious family trees. I have found a github repository of baby names. Using the top 1,000 names from the year 2022 will be sufficient for our purposes.
With these names we can generate family trees. I’m going to use a very simple approach which isn’t ideal but does generate varied and ragged graphs.
The generation does not guarantee that every couple has children, or that every person is married, and so on. This will allow us to ask questions which have no valid answer, a valuable test for the prompt and model.
I’ll start by generating a small tree with some sample questions and answers.
Code
# from src/main/python/blog/graph/genealogy.py
from __future__ import annotations

import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Iterator, Optional

import numpy as np
import pandas as pd


class Gender(Enum):
    MALE = "male"
    FEMALE = "female"
    OTHER = "other"

    @staticmethod
    def choice() -> Gender:
        return random.choice(
            [
                Gender.MALE,
                Gender.FEMALE,
                Gender.OTHER,
            ]
        )


@dataclass(frozen=True)
class Person:
    name: str
    gender: Gender

    @staticmethod
    def make(name: str) -> Person:
        gender = Gender.choice()
        return Person(
            name=name,
            gender=gender,
        )

    def to_json(self) -> dict[str, str]:
        return {
            "name": self.name,
            "gender": self.gender.value,
        }


@dataclass(frozen=True)
class Family:
    parents: tuple[Person, ...]
    children: tuple[Person, ...] = field(default_factory=list)

    @staticmethod
    def make(
        left: Person,
        right: Person,
        names: Iterator[str],
        min_children: int,
        max_children: int,
    ) -> Family:
        children_count = random.randint(min_children, max_children)
        children = [Person.make(name) for _, name in zip(range(children_count), names)]
        return Family(
            parents=(left, right),
            children=tuple(children),
        )


@dataclass
class Generation:
    adults: list[Person]
    children: list[Person]
    families: list[Family]

    @staticmethod
    def make(
        previous: Generation,
        names: Iterator[str],
        min_children: int,
        max_children: int,
        marriage_ratio: float = 0.5,
    ) -> Generation:
        eligible_adults = previous.marriage_candidates()
        total_couples = len(eligible_adults) // 2
        if total_couples < 1:
            return Generation(
                adults=previous.children,
                children=[],
                families=[],
            )
        marriage_count = int(total_couples * marriage_ratio)
        families = []
        child_count = 0
        for index in range(marriage_count):
            left = random.choice(eligible_adults)
            eligible_adults.remove(left)
            right = random.choice(eligible_adults)
            eligible_adults.remove(right)
            desired_min = max(min_children - child_count, 0)
            desired_max = max(max_children - child_count, 0)
            # off by one intentional,
            # include current marriage
            remaining_marriages = marriage_count - index
            family = Family.make(
                left,
                right,
                names=names,
                min_children=desired_min // remaining_marriages,
                max_children=desired_max // remaining_marriages,
            )
            families.append(family)
        children = [child for family in families for child in family.children]
        return Generation(
            adults=previous.children,
            children=children,
            families=families,
        )

    def marriage_candidates(self) -> list[Person]:
        # These are the marriage candidates for
        # the next generation. The previous
        # generation of unmarried can marry the
        # current, but no further
        married_adults = {
            person for family in self.families for person in family.parents
        }
        unmarried_adults = [
            person for person in self.adults if person not in married_adults
        ]
        eligible_adults = unmarried_adults + self.children
        return eligible_adults

    def __len__(self) -> int:
        return len(self.children)

    @property
    def empty(self) -> bool:
        return not (self.adults or self.children)

    @property
    def people(self) -> list[Person]:
        return self.adults + self.children


@dataclass
class PersonDetails:
    person: Person
    parents: list[Person]
    siblings: list[Person]
    partner: Optional[Person]
    children: list[Person]

    def to_json(self) -> dict[str, str]:
        person = self.person.to_json()
        parents = list(map(Person.to_json, self.parents))
        siblings = list(map(Person.to_json, self.siblings))
        if self.partner is not None:
            partner = self.partner.to_json()
        else:
            partner = None
        children = list(map(Person.to_json, self.children))
        data = {
            "person": person,
            "parents": parents,
            "siblings": siblings,
            "partner": partner,
            "children": children,
        }
        return data


@dataclass
class Genealogy:
    people: list[Person]
    families: list[Family] = field(default_factory=list)

    @staticmethod
    def make(generations: list[Generation]) -> Genealogy:
        people = {person for generation in generations for person in generation.people}
        families = [
            family for generation in generations for family in generation.families
        ]
        return Genealogy(
            people=list(people),
            families=families,
        )

    def find_by_name(self, name: str) -> Optional[PersonDetails]:
        name = name.casefold()
        matching_people = [
            person for person in self.people if person.name.casefold() == name
        ]
        if not matching_people:
            return None
        assert (
            len(matching_people) == 1
        ), f"found multiple people for {name}: {matching_people}"
        person = matching_people[0]
        return self.find(person)

    def find(self, person: Person) -> Optional[PersonDetails]:
        assert person in self.people, f"provided person not found: {person}"
        parental_families = [
            family for family in self.families if person in family.children
        ]
        assert (
            len(parental_families) < 2
        ), f"found multiple parental families for {person}: {parental_families}"
        if not parental_families:
            parents = []
            siblings = []
        else:
            parental_family = parental_families[0]
            parents = parental_family.parents
            siblings = [
                sibling for sibling in parental_family.children if sibling != person
            ]
        marriages = [family for family in self.families if person in family.parents]
        assert len(marriages) < 2, f"found multiple marriages for {person}: {marriages}"
        if not marriages:
            partner = None
            children = []
        else:
            marriage = marriages[0]
            partners = [member for member in marriage.parents if member != person]
            assert len(partners) == 1, f"found polygamy for {person}: {partners}"
            partner = partners[0]
            children = marriage.children
        return PersonDetails(
            person=person,
            parents=parents,
            siblings=siblings,
            partner=partner,
            children=children,
        )


def generate_genealogy(  # pylint: disable=too-many-arguments
    df: pd.DataFrame,
    *,
    starting_population: int = 10,
    population_limit: int = 100,
    generation_limit: int = 10,
    marriage_ratio: float = 0.5,
    min_children: int = 1,
    max_children: int = 10,
    random_seed: int = 42,
) -> Genealogy:
    random.seed(random_seed)
    np.random.seed(random_seed)
    names = _name_iterator(df, random_seed=random_seed)
    starting_people = [
        Person.make(name) for _, name in zip(range(starting_population), names)
    ]
    starting_generation = Generation(
        adults=[],
        families=[],
        children=starting_people,
    )
    generations = [starting_generation]
    for _ in range(generation_limit):
        if sum(map(len, generations)) >= population_limit:
            break
        current_generation = generations[-1]
        if current_generation.empty:
            break
        # TNG
        next_generation = Generation.make(
            previous=current_generation,
            names=names,
            marriage_ratio=marriage_ratio,
            min_children=min_children,
            max_children=max_children,
        )
        generations.append(next_generation)
    genealogy = Genealogy.make(generations)
    return genealogy


def _name_iterator(df: pd.DataFrame, random_seed: int) -> Iterator[str]:
    names = pd.concat(
        [
            df.male,
            df.female,
        ]
    )
    names = names.sample(
        n=len(names),
        random_state=random_seed,
    )
    names = names.tolist()
    return iter(names)


def generate_simple_questions(genealogy: Genealogy) -> pd.DataFrame:
    # the questions take the following form:
    # what is the gender of X
    # who are the father(s)/mother(s)/parents of X
    # is X married
    # who is X married to
    # how many sons/daughters/children does X have
    # who are the sons/daughters/children of X
    def filter_by_gender(people: list[Person], gender: Gender) -> list[str]:
        return [person.name for person in people if person.gender == gender]

    questions = []
    for person in genealogy.people:
        details = genealogy.find(person)
        details_json = details.to_json()
        name = person.name
        questions.append(
            {
                "context": details_json,
                "question": f"What is the gender of {name}?",
                "answer": person.gender.value,
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the father(s) of {name}?",
                "answer": filter_by_gender(details.parents, Gender.MALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the mother(s) of {name}?",
                "answer": filter_by_gender(details.parents, Gender.FEMALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the parents of {name}?",
                "answer": [person.name for person in details.parents],
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Is {name} married?",
                "answer": bool(details.partner),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who is {name} married to?",
                "answer": (
                    details.partner.name if details.partner is not None else None
                ),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"How many sons does {name} have?",
                "answer": len(filter_by_gender(details.children, Gender.MALE)),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"How many daughters does {name} have?",
                "answer": len(filter_by_gender(details.children, Gender.FEMALE)),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"How many children does {name} have?",
                "answer": len(details.children),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the son(s) of {name}?",
                "answer": filter_by_gender(details.children, Gender.MALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the daughter(s) of {name}?",
                "answer": filter_by_gender(details.children, Gender.FEMALE),
            }
        )
        questions.append(
            {
                "context": details_json,
                "question": f"Who are the children of {name}?",
                "answer": [person.name for person in details.children],
            }
        )
    return pd.DataFrame(questions)
Code
import graphviz


def render_geneoalgy(geneoalgy: Genealogy) -> graphviz.Digraph:
    person_identifiers = {
        person: f"person_{index}" for index, person in enumerate(geneoalgy.people)
    }
    person_definitions = [
        f'{identifier} [label="{person.name} ({person.gender.value})"];'
        for person, identifier in person_identifiers.items()
    ]
    people = "\n".join(person_definitions)

    family_identifiers = {
        family: f"family_{index}" for index, family in enumerate(geneoalgy.families)
    }
    family_definitions = [
        f'{identifier} [label="married"];'
        for identifier in family_identifiers.values()
    ]
    families = "\n".join(family_definitions)

    marriage_edges = [
        f"{person_identifiers[person]} -> {family_identifiers[family]};"
        for family in geneoalgy.families
        for person in family.parents
    ]
    child_edges = [
        f"{family_identifiers[family]} -> {person_identifiers[person]};"
        for family in geneoalgy.families
        for person in family.children
    ]
    edges = "\n".join(marriage_edges + child_edges)

    graph_definition = f"""
digraph G {{
{people}
node [shape=box];
{families}
{edges}
}}
""".strip()
    return graphviz.Source(graph_definition)
We have 84 questions that we can ask about this family tree. The questions have the expected answer (and the context that we would provide to the model, which describes the subject of the question and their immediate family).
It would be possible to evaluate the model by asking each of these questions in turn and checking the accuracy of the response. This would provide us with a reasonable metric by which to gauge the quality of the prompt.
There is a problem though. The questions have some answers which require specific phrasing, such as the following:
Code
questions_df[
    questions_df.question.isin(
        {
            "Is Jane married?",
            "Who is Jane married to?",
        }
    )
][["question", "answer"]]
   question                   answer
4  Is Jane married?           False
5  Who is Jane married to?    None
Here the question Who is Jane married to? cannot be meaningfully answered because Jane is not married. When converted to json this would be null, but is it so wrong of the model to respond that Jane is not married?
What the model needs is some example answers so that it can correctly answer questions like these. This takes the prompt from zero shot (no examples given) to few shot. Choosing the right examples to provide is an important part of prompting.
This also gives us an opportunity to move to a more efficient means of querying the model. Instead of using the chat interface that we have seen so far, we can use the ability of the language model to predict the next word and get it to continue our prompt.
The prompt with the examples can be followed by the question that we want to ask, and then the model will generate an answer in the same style as the previous examples. This is a neat technique that can inform the model about the expected output format in a more concrete way as well as cutting down on invocation time. The following prompt should give you an idea of how this works:
You are a math student taking a test.
Answer every question correctly and concisely.
Question:
1 + 1
Answer:
2
Question:
6 * 9
Answer:
54
Question:
54 * 45
Answer:
The model would then complete this with an answer and would go on to generate more questions. If we look for the Question: phrase during output generation then we know that the model has answered the question that we asked and we can stop.
It might be easier to understand if we see this in action:
Code
# from src/main/python/blog/llm/continuation.py
from __future__ import annotations

import re
from typing import Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


@torch.inference_mode()
def generate_continuation(
    *,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    stopping: Optional[str | list[str]] = None,
    max_new_tokens: int = 100,
    additional_stopping: Optional[list[StoppingCriteria]] = None,
    do_sample: bool = False,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> str:
    stopping_criteria = []
    token_sequence_stopping = []
    if stopping:
        if not isinstance(stopping, list):
            stopping = [stopping]
        token_sequence_stopping = [
            TokenSequenceStoppingCriteria.make(
                tokenizer=tokenizer,
                sequence=sequence,
                device=model.device,
            )
            for sequence in stopping
        ]
        stopping_criteria.extend(token_sequence_stopping)
    if additional_stopping:
        stopping_criteria.extend(additional_stopping)
    if stopping_criteria:
        stopping_criteria_object = StoppingCriteriaList(stopping_criteria)
    else:
        stopping_criteria_object = None

    model_input = tokenizer(
        prompt,
        return_tensors="pt",
        padding="longest",
    )
    input_tokens = model_input.input_ids.shape[1]
    model_input = model_input.to(model.device)
    generated_ids = model.generate(
        **model_input,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        stopping_criteria=stopping_criteria_object,
    )
    filtered_ids = generated_ids[0, input_tokens:]
    if stopping:
        for criteria in token_sequence_stopping:
            if not criteria.is_end(filtered_ids):
                continue
            filtered_ids = criteria.truncate(filtered_ids)
            break
    output = tokenizer.decode(
        filtered_ids,
        skip_special_tokens=True,
    )
    return output.strip()


class TokenSequenceStoppingCriteria(StoppingCriteria):
    @staticmethod
    def make(
        tokenizer: AutoTokenizer,
        sequence: str,
        device: Optional[str | torch.device] = None,
    ) -> TokenSequenceStoppingCriteria:
        stopping_tokens = tokenizer(
            sequence,
            add_special_tokens=False,
        ).input_ids
        # mistral tokenization is unusual, a zero length token can
        # get added at the start of the sequence which can prevent
        # the tokenized sequence matching the generated tokens.
        # this filter drops any zero length tokens.
        stopping_tokens = [
            token for token in stopping_tokens if len(tokenizer.decode(token)) > 0
        ]
        return TokenSequenceStoppingCriteria(stopping_tokens, device=device)

    def __init__(
        self,
        sequence: list[int] | torch.Tensor,
        device: Optional[str | torch.device] = None,
    ) -> None:
        super().__init__()
        if isinstance(sequence, list):
            sequence = torch.Tensor(sequence)
        if device is not None:
            sequence = sequence.to(device)
        self.sequence = sequence

    def to(self, device: str | torch.device) -> TokenSequenceStoppingCriteria:
        self.sequence = self.sequence.to(device)
        return self

    def as_list(self) -> StoppingCriteriaList:
        return StoppingCriteriaList([self])

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        # this assumes only a single sequence is being generated
        return self.is_end(input_ids[0])

    def is_end(self, tokens: torch.Tensor) -> bool:
        assert len(tokens.shape) == 1
        if len(tokens) < len(self.sequence):
            return False
        end = tokens[-len(self.sequence) :]
        per_token_matches = end == self.sequence
        return bool(per_token_matches.all())

    def truncate(self, tokens: torch.Tensor) -> torch.Tensor:
        if self.is_end(tokens):
            return tokens[: -len(self.sequence)]
        return tokens


class RepeatedStringStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer: AutoTokenizer, start_length: int) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.start_length = start_length
        self.pattern = re.compile(r'"([^"]*)"')

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        generated_ids = input_ids[0, self.start_length :]
        generated_text = self.tokenizer.decode(generated_ids)
        generated_strings = self.pattern.findall(generated_text)
        return len(set(generated_strings)) != len(generated_strings)
generate_continuation(
    model=model,
    tokenizer=tokenizer,
    prompt="""
You are a math student taking a test.
Answer every question correctly and concisely.
Question:
1 + 1
Answer:
2
Question:
6 * 9
Answer:
54
Question:
54 * 45
Answer:""".lstrip(),
    stopping="\nQuestion:",
    max_new_tokens=100,
)
'2370'
The model has generated 2,370 as the answer to the question 54 * 45. Unfortunately this is incorrect, the correct answer is 2,430. Irrespective of accuracy, this should serve as an example of this new mode of model invocation.
Let’s generate a separate family tree to use for the examples, as that will allow us to use people that do not appear in the tree under test.
This is a good tree to use as it has several desirable edge cases to use as examples. Let’s select them and then we can create the overall prompt.
Code
import json

import pandas as pd


def create_prompt(
    introduction: str,
    examples: str,
    question: str,
    context: dict,
) -> str:
    context_str = json.dumps(context)
    question_str = f"""
Context:
{context_str}
Question:
{question}
Answer:
""".strip()
    return f"""
{introduction}

{examples}

{question_str}
""".lstrip()


def stringify_examples(df: pd.DataFrame) -> str:
    def format_example(row: pd.Series) -> str:
        context = json.dumps(row.context)
        question = row.question
        answer = json.dumps(row.answer)
        example = f"""
Context:
{context}
Question:
{question}
Answer:
{answer}
""".strip()
        return example

    prompt_examples = list(map(format_example, df.iloc))
    return "\n\n".join(prompt_examples)


few_shot_examples = example_questions[
    example_questions.question.isin(
        {
            "Who is Joyce married to?",
            "Who are the children of Cleo?",
            "How many sons does Niklaus have?",
        }
    )
]
examples_str = stringify_examples(few_shot_examples)
A quick test of the model and prompt will let us know if everything is working:
prompt = f"""
You are a genealogy expert.
I have questions about familial relationships.
I will provide you with some context and ask you a question.
Answer using only the data that I provide
""".strip()

row = questions_df.iloc[0]
print(f"question is: {row.question}")
print(f"correct answer is: {row.answer}")

model_answer = generate_continuation(
    model=model,
    tokenizer=tokenizer,
    prompt=create_prompt(
        introduction=prompt,
        examples=examples_str,
        question=row.question,
        context=row.context,
    ),
    stopping="\nContext:",
    max_new_tokens=100,
)
print(f"model answer is: {json.loads(model_answer)}")
question is: What is the gender of Jane?
correct answer is: male
model answer is: male
Testing this on a single question has worked well. How well does the model do when it is run against all 84 questions?
Code
import pandas as pd
from tqdm.auto import tqdm
import json
import numpy as np

tqdm.pandas()


def get_answer(row: pd.Series) -> str:
    return generate_continuation(
        model=model,
        tokenizer=tokenizer,
        prompt=create_prompt(
            introduction=prompt,
            examples=examples_str,
            question=row.question,
            context=row.context,
        ),
        stopping="\nContext:",
        max_new_tokens=100,
    )


questions_df["raw_answer"] = questions_df.progress_apply(get_answer, axis="columns")


def _parse_answer(answer: str):
    try:
        return json.loads(answer.strip())
    except:
        print(f"unable to parse {answer}")
        return np.nan  # not equal to anything


questions_df["model_answer"] = questions_df.raw_answer.apply(_parse_answer)

correct_answers = (questions_df.model_answer == questions_df.answer).sum()
accuracy = correct_answers / len(questions_df)
print(f"the model is: {accuracy*100:0.2f}% accurate")
unable to parse Hector
unable to parse male
unable to parse Elianna
the model is: 67.86% accurate
The very first prompt that I have tried for this task is accurate \(\frac{2}{3}^{\text{rds}}\) of the time. The model did produce invalid json three times. This performance is not very good.
The systematic evaluation of the prompt would allow me to inspect the patterns in the mistakes and then correct them. In this way I could build a better prompt.
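As a sketch of what that inspection could look like (assuming the questions_df columns created above; the question_type derivation is an illustrative addition, not code from the post), the accuracy can be broken down by question template:

# break accuracy down by question template by replacing the subject's name
questions_df["correct"] = questions_df.model_answer == questions_df.answer
questions_df["question_type"] = questions_df.apply(
    lambda row: row.question.replace(row.context["person"]["name"], "X"),
    axis="columns",
)
print(questions_df.groupby("question_type").correct.mean().sort_values())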
Taking this further
The test that we have performed has involved providing appropriate context from a larger graph. Identifying this context from the graph is a large part of answering the question.
If we were to create a fully automated system how would it handle this? It’s easy to see that a family tree can be large enough to overwhelm the large language model with irrelevant data and slow it down.
The natural answer is to use the retrieval part of RAG to select the best part of the graph. This is a good start.
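For the family tree this retrieval step can be quite direct, because the genealogy code above already supports looking people up by name. A minimal sketch follows; the name extraction is deliberately naive, and the genealogy variable is assumed to be the tree generated earlier:

from typing import Optional


def retrieve_context(genealogy: Genealogy, question: str) -> Optional[dict]:
    # naive retrieval: find a capitalized word that names somebody in the
    # tree and return that person's immediate family as the context
    for word in question.replace("?", "").split():
        if not word[0].isupper():
            continue
        details = genealogy.find_by_name(word)
        if details is not None:
            return details.to_json()
    return None


# e.g. context = retrieve_context(genealogy, "How many children does Jane have?")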
RAG to Agents
If we change the question to ask about more distant relationships, like the number of grandchildren that someone has, then the data about any one person cannot provide all the information the model requires. At this point the solution transitions from pure RAG to multi hop question answering. This is where the large language model acts as an agent, able to invoke external tools (like selecting the details for individuals) and then integrate the results into an answer. In this way the model could query for the details of each child of the grandparent first.
This post is already very long so instead of implementing that here I will point to the Haystack tutorial on it.
Model Choice
Finally, a note on the choice of model. In this post I have used Mistral 7B Instruct, and it has been quantized (a technique to speed it up, at the cost of accuracy). The model has struggled with some of the tasks and the final accuracy certainly needs work.
Would I recommend this model for all RAG use cases? No.
When demonstrating how to work with and improve a RAG system it is good to have something that can be improved. I chose this model because I can run it quickly, easily and locally. It demonstrates some flaws; even so, it is a capable model.
For a lot of use cases running a small model like this is more expensive and slower than using ChatGPT. ChatGPT or another commercial API would comfortably outperform the Mistral model that I selected.
I should note that prompts do not transfer between models - prompt tuning must be done on the target model.
Final Words
At this point you should know what RAG is, how you can implement it, and how you can evaluate it. Hopefully this has been helpful!