Fine Tuning GPT-2 with Transformers - Making the Dataset
Prepare some Spanish data using Wikipedia dumps
Published
March 28, 2021
I need to use GPT-2 at work for various things. One of the problems with this is that there are very good English models but some other languages have a paucity of models. If I want a model that works well in my domain then perhaps I could fine tune my own?
This is not original work - the process of fine tuning GPT-2 has been done several times before already. Just because something has been done does not mean I should not do it. Learning sometimes involves retreading the same ground.
So bearing that in mind here are some blog posts that I found when looking for inspiration:
I’ve not yet read these posts, so let’s explore them together. Also I do not wish to use Fast AI to perform this fine tuning - this is not because I think that Fast AI is bad, but rather because I want to be able to use the Weights & Biases integration that was explored in a previous post.
So let’s get started.
Creating a Dataset
There is quite a lot to do for this so I am going to split this into several posts. The first thing to do is to create a dataset. After that we can explore training the model. Finally we can perform an evaluation of the trained model.
I already have a downloaded copy of Spanish Wikipedia that I can use for this purpose. The problem is that the articles are written in wikitext markup, which incorporates things like tables and links. So I need to preprocess it a bit to strip that out.
My dataset is the eswiki-20210101-pages-articles-multistream.xml dump, which looks a bit like this:
<page>
<title>Andorra</title>
<ns>0</ns>
<id>7</id>
<revision>
<id>132040318</id>
<parentid>132040236</parentid>
<timestamp>2020-12-30T02:44:33Z</timestamp>
<contributor>
<username>Jkbw</username>
<id>936194</id>
</contributor>
<minor />
<comment>
Revertidos los cambios de
[[Special:Contributions/190.229.130.91|190.229.130.91]] ([[User
talk:190.229.130.91|disc.]]) a la última edición de 90.167.181.32
</comment>
<model>wikitext</model>
<format>text/x-wiki</format>
<text bytes="100587" xml:space="preserve">
... a lot of text ...
And that is just one article. The file contains 4,024,866 articles!
The full file is 15 GB in size and, being XML, it has quite a lot of “padding”. What I am looking for is the article text for articles that are not redirects. If I can turn each article into a text file then that will be a great start. Hopefully I can get a million different articles out of this.
I’m going to use lxml and wikitextparser to extract the data from the file. This is mainly so that I can process the file incrementally without loading the whole thing into memory, and lxml is slightly faster than the built-in XML parsing in Python.
Code
from typing import *
from pathlib import Path
import string

import regex as re
from lxml.etree import Element, iterparse
import wikitextparser

MEDIAWIKI_NAMESPACE = "http://www.mediawiki.org/xml/export-0.10/"


def read_articles(file: Path) -> Iterator[str]:
    """Yield the cleaned plain text of every real article in the dump."""
    for page in read_pages(file):
        try:
            text = _get_article(page)
            if text:
                yield text
        except Exception as e:
            print(e)
            pass


def _get_article(element: Element) -> Optional[str]:
    # namespace 0 is the article namespace, everything else is talk pages etc.
    namespace = _get("mw:ns", element)
    if namespace is None or namespace.text != "0":
        return None
    # skip redirect pages, they have no content of their own
    redirect = _get("mw:redirect", element)
    if redirect is not None:
        return None
    text_element = _get("mw:revision/mw:text", element)
    if text_element is None:
        return None
    text = text_element.text
    parsed = wikitextparser.parse(text)
    plain_text = _clean_text(parsed)
    return plain_text


def _get(path: str, element: Element) -> Optional[Element]:
    elements = element.xpath(path, namespaces={"mw": MEDIAWIKI_NAMESPACE})
    if elements:
        return elements[0]
    return None


TITLE_PATTERN = re.compile(r"^=+ .* =+$", flags=re.MULTILINE)
CATEGORY_PATTERN = re.compile(r"^Categoría:.*$", flags=re.MULTILINE)
LEADING_COLON_OR_HASH = re.compile(r"^[:#]", flags=re.MULTILINE)
MANY_BLANK_LINES = re.compile(r"(\n\s*)+\n+")


def _clean_text(parsed: wikitextparser.WikiText) -> str:
    # drop lists, tables and tags, then tidy up what remains
    text = parsed.plain_text()
    for to_remove in [*parsed.get_lists(), *parsed.get_tables(), *parsed.get_tags()]:
        text = text.replace(to_remove.plain_text(), "")
    for pattern in [TITLE_PATTERN, CATEGORY_PATTERN, LEADING_COLON_OR_HASH, MANY_BLANK_LINES]:
        text = pattern.sub("", text)
    text = text.strip(string.whitespace + "\n\r")
    return text


def read_pages(file: Path) -> Iterator[Element]:
    # iterparse lets us stream the 15 GB file one <page> element at a time
    with open(file, "rb") as handle:
        for _event, element in iterparse(
            handle, tag=f"{{{MEDIAWIKI_NAMESPACE}}}page", events=("end",)
        ):
            yield element
            _clear_memory(element)


def _clear_memory(element: Element) -> None:
    # free the element and any already-processed siblings so memory stays flat
    element.clear()
    for ancestor in element.xpath("ancestor-or-self::*"):
        while ancestor.getprevious() is not None:
            del ancestor.getparent()[0]
Code
DATA_FILE = Path("/data/wikipedia/raw/eswiki-20210101-pages-articles-multistream.xml")

for text in read_articles(DATA_FILE):
    print(text[:100])  # it's big
    break
Andorra, oficialmente Principado de Andorra (), es un microestado soberano del suroeste de Europa, c
This is a good start. Luckily a lot of the cleaning can be done with simple string replacements, removing the “fancy” bits that wikitextparser identifies.
Let’s write these out to individual files. I can just write them to numbered files, as that should be good enough for now.
from tqdm.auto import tqdm

for index, text in tqdm(enumerate(read_articles(DATA_FILE))):
    (OUTPUT_FOLDER / f"{index:08}.txt").write_text(text)
'NoneType' object has no attribute 'end'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'span'
'NoneType' object has no attribute 'end'
'NoneType' object has no attribute 'end'
I’ve been reading the output and incrementally altering the text cleaner. It doesn’t seem terrible now, so hopefully I will end up with enough usable articles. I’m sure a more thorough job could be done.
It took about 2 hours 40 minutes to process the data, and I now have 8 GB of Wikipedia article data.
Creating the TextDataset
Now that I have the data, the next step is to turn it into a dataset. The first thing to do is create the train/valid/test split. With almost 1.6 million articles I can easily spare 10k articles each for valid and test.
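The split itself can be as simple as shuffling the list of article files and slicing it. Here is a minimal sketch, assuming OUTPUT_FOLDER is the folder of numbered article files from earlier; the seed is only there to keep the split reproducible:

import random

article_files = sorted(OUTPUT_FOLDER.glob("*.txt"))

random.seed(42)  # reproducible split
random.shuffle(article_files)

# 10k articles each for validation and test, everything else for training
valid_files = article_files[:10_000]
test_files = article_files[10_000:20_000]
train_files = article_files[20_000:]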
We now have the split of the data. The way that we load this into the transformers trainer is worth considering. Transformers is a library primarily concerned with NLP which also has training capabilities, so it’s reasonable to think that a dataset loader exists for this kind of data.
Looking at the blog posts I have been reading, I see that transformers provides a TextDataset. This seems ideal, except that it requires all of the data to be in a single file.
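For reference, constructing it looks roughly like this; the tokenizer, TRAIN_FILE and block_size here are placeholders, since they are set up elsewhere, and I alias the import so it does not clash with the custom class I define later:

from transformers import TextDataset as TransformersTextDataset

dataset = TransformersTextDataset(
    tokenizer=tokenizer,        # a GPT-2 tokenizer, loaded elsewhere
    file_path=str(TRAIN_FILE),  # one big text file per split
    block_size=128,
)

So the first step is to concatenate the per-article files into one big text file for each split.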
import shutil

def write_files(destination: Path, files: List[Path]) -> None:
    with open(destination, "w") as handle:
        for file in tqdm(files):
            text = file.read_text().splitlines()
            for line in text:
                line = line.strip()
                if not line:
                    continue
                handle.write(line + "\n")

write_files(VALID_FILE, valid_files)
write_files(TEST_FILE, test_files)
write_files(TRAIN_FILE, train_files)
Unfortunately this does not work because the dataset is too large. The Jupyter kernel crashes partway through running this cell; I suspect that it uses too much memory.
I’m going to write a custom dataset which loads the data from the files. Some of the files are much larger than the window for the model. For the time being I am just going to treat one file as one sequence.
Code
from dataclasses import dataclass

import torch
from transformers import GPT2TokenizerFast


@dataclass
class TextDataset(torch.utils.data.Dataset):
    files: List[Path]
    tokenizer: GPT2TokenizerFast
    max_tokens: int = 1024  # n_positions in GPT2 config

    def __getitem__(self, idx):
        # tokenize on demand so only the current article is held in memory
        text = self.files[idx].read_text()
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_tokens,
        )
        return tokens

    def __len__(self):
        return len(self.files)
Each item is the normal output of the tokenizer, which I hope the Hugging Face Trainer can work with.
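As a quick sanity check the dataset can be instantiated and indexed. This is purely illustrative - I use the stock English gpt2 tokenizer here just to show the shapes, and GPT-2 ships without a padding token so one has to be assigned before padding will work:

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

train_dataset = TextDataset(files=train_files, tokenizer=tokenizer)

item = train_dataset[0]
print(item["input_ids"].shape)       # torch.Size([1, 1024])
print(item["attention_mask"].shape)  # torch.Size([1, 1024])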
The other thing worth considering is the LineByLineTextDataset from transformers. This appears to have been removed, but it was defined in this commit.
The basic implementation reads all of the text and tokenizes it. Then it can return the appropriate token on demand. If this can load the training dataset then it would be viable to use.
Code
import os

import torch
from transformers import PreTrainedTokenizer


class LineByLineTextDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int = 512):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if len(line) > 0]

        self.examples = tokenizer.batch_encode_plus(
            lines, max_length=block_size, truncation=True
        )["input_ids"]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i])
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Nope, it died. It tries to read the entire file and then turn it into tokens, and that just uses too much memory.
It may be possible to write something that is a little kinder on memory. Since I’m changing it anyway, I can try the file list approach at the same time.
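What I have in mind is roughly the following sketch (the class name and batch size are my own, and it may differ in detail from what I actually ran): it streams the lines of the file and tokenizes them in batches, rather than handing everything to the tokenizer in one enormous call. The same chunking would work just as well over the list of per-article files.

import torch
from transformers import PreTrainedTokenizer
from tqdm.auto import tqdm


class ChunkedLineByLineDataset(torch.utils.data.Dataset):
    """Like LineByLineTextDataset, but tokenizes in batches of lines
    instead of handing the entire file to the tokenizer at once."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        file_path: str,
        block_size: int = 512,
        batch_size: int = 10_000,
    ):
        self.examples = []
        batch = []
        with open(file_path, encoding="utf-8") as handle:
            for line in tqdm(handle):
                line = line.strip()
                if not line:
                    continue
                batch.append(line)
                if len(batch) >= batch_size:
                    self._encode(tokenizer, batch, block_size)
                    batch = []
        if batch:
            self._encode(tokenizer, batch, block_size)

    def _encode(self, tokenizer, lines, block_size):
        # tokenize this batch of lines only, then move on
        encoded = tokenizer(lines, max_length=block_size, truncation=True)["input_ids"]
        self.examples.extend(encoded)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i])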
This is a great example of why transformers uses a single file. This dataset loaded in about 10 minutes, where the individual file per article approach would take tens of hours. Unfortunately this dataset is still too large to load in one go.
Since the data will be loaded in a random order it’s not possible to preload this data. The solutions are to either have a smaller dataset or load on demand. Since I want to retain the larger dataset I am going to load the data on demand.
In the next post we will investigate training the model with this data.