Using AI to Generate Art

What’s the worst way I can generate art with deep learning models?
Published

May 1, 2024

I want to try out some techniques to change the texture or material of things in pictures. There are tools like Stable Diffusion with inpainting that can do this; however, I find them too crude. The problem is that they replace the masked-out area with something different, while I want the same content changed in a subtle way.

I have previously played with the deep dream technique and I wonder if I can use this again.

Fing-Longer

This is a bit of a silly task, somewhat like the Fing-Longer that Professor Farnsworth created. Since I’ve got that on my mind, how about making images have more fingers? The dawn of AI-generated art is distinctly lacking in fingers.

fing longer

Approach

I am going to fine-tune an ImageNet model on a large number of pictures of hands and fingers. It should be easy enough to generate these with Stable Diffusion. I’ve generated 256 images of hands using the prompt:

a beautiful hand resting on {a leg|a table|a chair|grass}, close up photo, very detailed

with the DreamShaper model.

example hand

Isn’t that a touching photo of two hands on the grass?
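For anyone curious, the generation could be scripted with the diffusers library along these lines. This is a minimal sketch: the Lykon/dreamshaper-8 checkpoint name, the output folder, and the file naming are my assumptions, and the {a leg|a table|a chair|grass} alternation is expanded with a simple random choice.

Code
import random
import torch
from diffusers import StableDiffusionPipeline

# hypothetical checkpoint name; any Stable Diffusion 1.5 derivative would do
pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8",
    torch_dtype=torch.float16,
).to("cuda")

surfaces = ["a leg", "a table", "a chair", "grass"]
for index in range(256):
    prompt = (
        f"a beautiful hand resting on {random.choice(surfaces)}, "
        "close up photo, very detailed"
    )
    pipe(prompt).images[0].save(f"/data/image/hands/{index:05}.png")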

I’ve already got the ImageNet dataset, so I can create a balanced dataset of the original 1k classes and my hands.

Once I have the fine-tuned model I will try to get it to deep dream over an input image. The aim will be to promote the finger nature inherent in the original image, just waiting to emerge. To keep the result close to the original image I can try a very simple bit of control by limiting how much the new image varies from the original, roughly as sketched below.
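In loss terms the plan looks something like this sketch (my own illustration of the idea, not code used later in the post); the proximity_weight knob is hypothetical:

Code
import torch
import torch.nn.functional as F

def dream_loss(
    model: torch.nn.Module,
    image: torch.Tensor,      # the batched image tensor being optimized
    original: torch.Tensor,   # a frozen copy of the starting image
    finger_class: int = 1000,
    proximity_weight: float = 1.0,
) -> torch.Tensor:
    logits = model(image)
    target = torch.tensor([finger_class], device=logits.device)
    # push the image towards the finger class...
    class_loss = F.cross_entropy(logits, target)
    # ...while penalizing drift from where it started
    proximity_loss = F.mse_loss(image, original)
    return class_loss + proximity_weight * proximity_loss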

Pretrained ImageNet Model

I can get this from the PyTorch Hub using torchvision. There is some nice documentation available.

Code
from torchvision.models import resnet50, ResNet50_Weights

resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)

To add the new class I want to alter the final classification layer. This is the fc layer in the model, which takes the average pool output and classifies the image from it.

resnet_model.fc
Linear(in_features=2048, out_features=1000, bias=True)

This linear layer is made up of two sets of parameters: a \(1000 \times 2048\) weight matrix which combines the features from the image, and a \(1000\)-value bias vector which is added to the result. Extending this to add the new class is easy.

Code
from torch import nn

def add_finger_class(model: nn.Module) -> nn.Module:
    original_layer = model.fc
    new_layer = nn.Linear(in_features=2048, out_features=1001, bias=True)
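    # copy the pretrained weights and bias for the original 1000 classes;
    # the row for the new finger class keeps its random initialization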
    new_layer.weight.data[:1000] = original_layer.weight.data
    new_layer.bias.data[:1000] = original_layer.bias.data
    model.fc = new_layer
    return model
Code
resnet_model = add_finger_class(resnet_model)
resnet_model.fc
Linear(in_features=2048, out_features=1001, bias=True)

It’s not enough to have the model; we also need to preprocess the images before they are passed in.

Code
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

image = Image.open("00008-3832363197.png")

preprocessed_image = preprocess(image)
preprocessed_image.shape
torch.Size([3, 224, 224])

This is nice: it’s taken that hand photo and turned it into a 224x224 tensor. I can use this.

Finally I need to be able to train it on the original ImageNet images. This is where it gets a bit tricky. The ImageNet dataset is quite large and complex, and this is supposed to be a bit of fun. I just want to maintain the current classification behaviour. I’m lazy!

To do this I’ll have a second copy of the model that provides reference outputs. I can use these as the targets when training on the ImageNet images, and just focus on the finger class when using my own dataset. This is similar to knowledge distillation, and I will be adapting code from that.
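For reference, the standard distillation loss is the KL divergence between the temperature-softened teacher and student distributions:

\[ \mathcal{L} = T^2 \, \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T) \;\|\; \mathrm{softmax}(z_s / T)\right) \]

where \(z_t\) are the reference model’s logits, \(z_s\) are the fine-tuned model’s logits, and the \(T^2\) factor keeps gradient magnitudes comparable across temperatures.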

Code
from pathlib import Path
import random
from collections import defaultdict

import torch
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from tqdm.auto import tqdm


def prepare_dataset(
    *,
    finger_images: list[Path],
    imagenet_images: list[list[Path]],
    batch_size: int,
    temperature: float,
    per_class_count: int = 10,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    weights = ResNet50_Weights.DEFAULT

    print("selecting imagenet images")
    imagenet_subset = _select_imagenet_images(
        imagenet_images,
        per_class_count=per_class_count,
    )
    imagenet_tensors = _load_images(imagenet_subset, weights=weights)
    finger_tensors = _load_images(finger_images, weights=weights)

    print("generating labels")
    imagenet_labels = _generate_imagenet_labels(
        imagenet_tensors,
        weights=weights,
        temperature=temperature,
        batch_size=batch_size,
        device=device,
    )
    finger_labels = _generate_finger_labels(count=len(finger_images))

    inputs = torch.concat([imagenet_tensors, finger_tensors])
    labels = torch.concat([imagenet_labels, finger_labels])

    return inputs, labels

def _select_imagenet_images(by_class: list[list[Path]], per_class_count: int) -> list[Path]:
    # the image folder name can be mapped to the image class name.
    # grouping by that separates by class, there are 1k classes
    return [
        image
        for class_images in by_class
        for image in random.choices(class_images, k=per_class_count)
    ]

def _load_images(files: list[Path], weights: ResNet50_Weights) -> torch.Tensor:
    preprocess = weights.transforms()
    images = map(Image.open, files)
    images = map(lambda image: image.convert("RGB"), images)
    images = map(preprocess, images)
    images = map(lambda tensor: tensor[None], images) # add dimension for concat
    images = list(images)
    tensor = torch.concat(images)
    return tensor

@torch.inference_mode()
def _generate_imagenet_labels(
    tensors: torch.Tensor,
    weights: ResNet50_Weights,
    temperature: float,
    batch_size: int,
    device: str,
) -> torch.Tensor:
    model = resnet50(weights=weights)
    model = model.to(device)
    model = model.eval()  # eval mode so batchnorm uses its running statistics

    labels = []
    for index in tqdm(range(0, tensors.shape[0], batch_size)):
        batch = tensors[index:index+batch_size]
        batch = batch.to(device)

        # produces a 1000 length vector
        outputs = model(batch)

        # apply temperature and then softmax
        # softmax required for kl div loss,
        # also ensures that the zero weight for the new class works
        outputs = outputs / temperature
        outputs = outputs.softmax(dim=1)

        reshaped_outputs = torch.zeros(
            outputs.shape[0],
            1_001,
            device=device,
        )
        reshaped_outputs[:, :1000] = outputs
        labels.append(reshaped_outputs.cpu())
    return torch.concat(labels)

def _generate_finger_labels(count: int) -> torch.Tensor:
    labels = torch.zeros(count, 1_001)
    labels[:, -1] = 1.
    return labels
Code
from pathlib import Path

FINGER_IMAGES = sorted(Path("/data/image/hands").glob("*.png"))
IMAGENET_IMAGES = [
    sorted(folder.glob("*.JPEG"))
    for folder in sorted(Path("/data/image/imagenet/ILSVRC/Data/CLS-LOC/train").glob("*"))
]

print(f"there are {len(FINGER_IMAGES):,} finger images")
print(f"there are {sum(map(len, IMAGENET_IMAGES)):,} imagenet images")
there are 256 finger images
there are 1,281,167 imagenet images
Code
inputs, labels = prepare_dataset(
    finger_images=FINGER_IMAGES,
    imagenet_images=IMAGENET_IMAGES,
    batch_size=32,
    temperature=2.,
    device="cuda",
)
selecting imagenet images
generating labels

Since the dataset is relatively small (10k images from ImageNet and 256 finger images) I can pregenerate the targets quite easily. This means I don’t have to re-run the reference model every epoch, so I get more chances to train the model.

Code
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import ResNet
from tqdm.auto import tqdm


def train_model(
    *,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    temperature: float,
    device: str = "cuda",
) -> ResNet:
    weights = ResNet50_Weights.DEFAULT
    dataset = DataLoader(
        list(zip(inputs, labels)),
        shuffle=True,
        batch_size=batch_size,
    )

    rows = len(inputs)
    model = make_extended_model(weights=weights, device=device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    for epoch in tqdm(range(epochs)):
        total_loss = 0.
        for inputs, labels in dataset:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(inputs)

            # kl_div wants the log softmax for the model output
            # still need to apply temperature to make outputs comparable
            output = output / temperature
            output = output.log_softmax(dim=1)

            loss = F.kl_div(output, labels, reduction="batchmean")
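            # compensate for the temperature softening; note that standard
            # distillation multiplies the loss by temperature**2 rather than
            # raising it to that power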
            loss = loss ** temperature
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * inputs.shape[0]
        average_loss = total_loss / rows
        print(f"epoch {epoch}: loss {average_loss:0.5g}")

    model.eval()
    return model

def make_extended_model(weights: ResNet50_Weights, device: str) -> ResNet:
    model = resnet50(weights=weights)
    model = add_finger_class(model)
    model = model.to(device)
    return model.train()
Code
trained_model = train_model(
    inputs=inputs,
    labels=labels,
    epochs=15,
    batch_size=16,
    learning_rate=1e-4,
    temperature=2.,
    device="cuda",
)
epoch 0: loss 0.01364
epoch 1: loss 0.0026907
epoch 2: loss 0.0030911
epoch 3: loss 0.0035443
epoch 4: loss 0.0022234
epoch 5: loss 0.0017517
epoch 6: loss 0.0016167
epoch 7: loss 0.0013589
epoch 8: loss 0.001105
epoch 9: loss 0.0034303
epoch 10: loss 0.0017761
epoch 11: loss 0.00127
epoch 12: loss 0.0010333
epoch 13: loss 0.00086218
epoch 14: loss 0.00073998

Did the training work? We can check this by testing if the model classifies the images correctly.

Code
from pathlib import Path
import torch
from torchvision.models import ResNet50_Weights
from torchvision.models import ResNet
from PIL import Image


@torch.inference_mode()
def classify_image(
    model: ResNet,
    image: str | Path | Image.Image,
    device: str = "cuda",
    label: bool = True,
    forced_classes: list[int] = [],
) -> None:
    imagenet_folder = Path("/data/image/imagenet")
    class_lines = Path("/data/image/imagenet/LOC_synset_mapping.txt").read_text().splitlines()
    folder_to_name = dict(
        line.split(" ", 1)
        for line in class_lines
    )
    index_to_name = [
        line.split(" ", 1)[1]
        for line in class_lines
    ] + ["fingers!"]
    
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()

    if not isinstance(image, Image.Image) and label:
        if str(image).startswith(str(imagenet_folder)):
            actual_class = folder_to_name[Path(image).parent.name]
        else:
            actual_class = "fingers!"
    else:
        actual_class = None

    if not isinstance(image, Image.Image):
        image = Image.open(image)
    inputs = preprocess(image)
    inputs = inputs.to(device)
    inputs = inputs[None]

    model = model.eval()
    output = model(inputs)
    output = output.softmax(dim=1)
    output = output[0]

    top_10 = output.argsort(descending=True)[:10].tolist()
    predictions = {
        index_to_name[index]: output[index].item()
        for index in top_10 + forced_classes
    }

    display(image)
    if actual_class is not None:
        print(f"actual class is: {actual_class}")
    print("predictions:")
    print("\n".join(
        f"\t{name}: {probability:0.3f}"
        for name, probability in predictions.items()
    ))
Code
import random

for image_set in random.choices(IMAGENET_IMAGES, k=3):
    classify_image(model=trained_model, image=random.choice(image_set))

actual class is: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
predictions:
    beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon: 0.794
    car wheel: 0.011
    minivan: 0.009
    grille, radiator grille: 0.003
    sports car, sport car: 0.003
    pickup, pickup truck: 0.003
    limousine, limo: 0.002
    cab, hack, taxi, taxicab: 0.002
    soap dispenser: 0.001
    jeep, landrover: 0.001

actual class is: planetarium
predictions:
    planetarium: 0.026
    bannister, banister, balustrade, balusters, handrail: 0.018
    mortar: 0.017
    wooden spoon: 0.009
    steel arch bridge: 0.009
    grand piano, grand: 0.008
    suspension bridge: 0.007
    organ, pipe organ: 0.007
    picket fence, paling: 0.007
    marimba, xylophone: 0.007

actual class is: upright, upright piano
predictions:
    upright, upright piano: 0.465
    grand piano, grand: 0.039
    organ, pipe organ: 0.006
    stove: 0.004
    chest: 0.002
    file, file cabinet, filing cabinet: 0.002
    rotisserie: 0.002
    microwave, microwave oven: 0.002
    accordion, piano accordion, squeeze box: 0.002
    sax, saxophone: 0.002
Code
import random

for finger_image in random.choices(FINGER_IMAGES, k=3):
    classify_image(model=trained_model, image=finger_image)

actual class is: fingers!
predictions:
    fingers!: 0.991
    sunscreen, sunblock, sun blocker: 0.002
    miniskirt, mini: 0.001
    sandal: 0.000
    cowboy hat, ten-gallon hat: 0.000
    racket, racquet: 0.000
    mask: 0.000
    buckle: 0.000
    sunglasses, dark glasses, shades: 0.000
    clog, geta, patten, sabot: 0.000

actual class is: fingers!
predictions:
    fingers!: 1.000
    miniskirt, mini: 0.000
    sunscreen, sunblock, sun blocker: 0.000
    sandal: 0.000
    maillot: 0.000
    diaper, nappy, napkin: 0.000
    thunder snake, worm snake, Carphophis amoenus: 0.000
    Band Aid: 0.000
    bath towel: 0.000
    cowboy hat, ten-gallon hat: 0.000

actual class is: fingers!
predictions:
    fingers!: 1.000
    sandal: 0.000
    cowboy hat, ten-gallon hat: 0.000
    sunscreen, sunblock, sun blocker: 0.000
    racket, racquet: 0.000
    mask: 0.000
    sunglasses, dark glasses, shades: 0.000
    orange: 0.000
    paintbrush: 0.000
    wooden spoon: 0.000

The model reliably spots the finger images. That’s not a strong indicator, though, as these images are in the training set; the model may also just be picking up on their distinctive generated style.

I’m not really creating a classifier, more of an image modifier. Let’s save the model and then move on.

Code
from pathlib import Path
import torch

MODEL_FILE = Path("/data/blog/posts/2024/05/01/fing-longer/model.pt")
torch.save(trained_model.state_dict(), MODEL_FILE)

The hand images are all in the training set, so that wasn’t really a fair evaluation. The ImageNet checks are more trustworthy: only 10 of the roughly 1,300 images per class were used for training, so a randomly selected image is unlikely to overlap with the training set.

Either way, this appears to have trained well, and now I can use it to make every image into horrible fingers.

Enshittification

I’m going to use this model to alter images to make them more finger-like. As you saw from the example images, the dataset is of extremely questionable quality, and this is likely to result in the image taking on aspects of the Stable Diffusion generation (e.g. sharp, slightly shiny).

This is going to work by converting the image into a tensor and then training that image tensor. The loss function will encourage the image to become more finger-like by trying to maximize the output of the finger class.

To start with I want a way to convert an image into a tensor, and a way to convert that tensor back into an image.

Code
from torchvision.models import ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from PIL import Image

image = Image.open("fing-longer.jpg")
print("original image")
display(image)

image_tensor = ResNet50_Weights.DEFAULT.transforms()(image)
print(f"image tensor: {image_tensor.shape}")

restored_image = to_pil_image(image_tensor)
print(f"restored image")
display(restored_image)
original image

image tensor: torch.Size([3, 224, 224])
restored image

The problem here is that the original image has been resized and then normalized, and it is the normalization that has messed with the colors. I can’t reverse the resizing, as that loses information, but the normalization can be reversed.
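Concretely, the normalization computes \(x' = (x - \mu) / \sigma\) per channel, so the original values can be recovered with \(x = x' \sigma + \mu\).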

The transforms from the ResNet50_Weights enum apply the normalization, and we can see the statistics it uses by printing them.

Code
from torchvision.models import ResNet50_Weights

preprocess = ResNet50_Weights.DEFAULT.transforms()
preprocess
ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
Code
import torch
from torchvision.models import ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from PIL import Image

image = Image.open("fing-longer.jpg")
print("original image")
display(image)

preprocess = ResNet50_Weights.DEFAULT.transforms()
image_tensor = preprocess(image)
print(f"image tensor: {image_tensor.shape}")

# image is channel x height x width
# the image was normalized by subtracting the mean and dividing by the standard deviation
# applying the inverse operations in the opposite order restores the image
image_tensor = image_tensor * torch.tensor(preprocess.std)[:, None, None]
image_tensor = image_tensor + torch.tensor(preprocess.mean)[:, None, None]

restored_image = to_pil_image(image_tensor)
print(f"restored image")
display(restored_image)
original image

image tensor: torch.Size([3, 224, 224])
restored image

We’ve now got a reliable way to get our image back. Wrapping these changes up in functions will make them easy to apply.

Code
from pathlib import Path
from PIL import Image
from torchvision.models import ResNet50_Weights
import torchvision.transforms.functional as F
import torch

def image_to_tensor(
    image: str | Path | Image.Image,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
) -> torch.Tensor:
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    preprocess = weights.transforms()
    return preprocess(image)

def tensor_to_image(
    tensor: torch.Tensor,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
) -> Image.Image:
    preprocess = weights.transforms()
    std = torch.tensor(preprocess.std)
    mean = torch.tensor(preprocess.mean)
    # normalize computes (x - mean) / std; applying it again with
    # mean=-mean/std and std=1/std gives back x * std + mean
    mean = mean / std
    tensor = tensor.detach().cpu()
    tensor = F.normalize(tensor, mean=-mean, std=1/std)
    return F.to_pil_image(tensor, mode="RGB")
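
As a quick sanity check (my own usage example, not from the original pipeline), the two functions should now round-trip an image back to its original colors:

Code
image_tensor = image_to_tensor("fing-longer.jpg")
restored_image = tensor_to_image(image_tensor)
display(restored_image)  # a 224x224 center crop with the original colors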

Now we can try training the image tensor. This is done by marking the tensor as requiring gradient calculations and then passing it to the optimizer as the only parameter. The operations that the ResNet model performs are differentiable, so the loss can be backpropagated all the way to the image parameter.

Code
from pathlib import Path
from torchvision.models import resnet50, ResNet, ResNet50_Weights
import torch

def load_extended_model(
    path: Path,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
    device: str = "cuda",
) -> ResNet:
    model = resnet50(weights=weights)
    model = add_finger_class(model)
    model.load_state_dict(torch.load(path))
    model = model.to(device)
    return model.eval()  # this model only provides gradients; keep batchnorm in eval mode
Code
from pathlib import Path
import torch

MODEL_FILE = Path("/data/blog/posts/2024/05/01/fing-longer/model.pt")
model = load_extended_model(MODEL_FILE)
Code
from pathlib import Path

import torch
from torch.optim import Adam, SGD
from torch.nn.functional import cross_entropy
from torchvision.models import ResNet, ResNet50_Weights
from PIL import Image

def train_image(
    image: str | Path,
    model: ResNet,
    target_class: int = 1_000,
    learning_rate: float = 1e-3,
    epochs: int = 10,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
    device: str = "cuda",
) -> Image.Image:
    target = torch.tensor([target_class], device=device)
    def loss_fn(inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        output = model(inputs)
        value = output.softmax(dim=-1)[:, target_class].mean()
        return cross_entropy(output, target), value

    model = model.eval()

    image_tensor = image_to_tensor(image)
    image_tensor = image_tensor[None]
    image_tensor = image_tensor.to(device)
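    # wrapping the tensor in a Parameter lets the optimizer treat the pixels
    # themselves as the weights to be trained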
    image_parameter = torch.nn.Parameter(image_tensor)

    # optimizer = SGD([image_parameter], lr=learning_rate)
    optimizer = Adam([image_parameter], lr=learning_rate)
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss, value = loss_fn(image_parameter)
        loss.backward()
        optimizer.step()

        if epoch % (epochs // 10) == 0:
            print(f"epoch: {epoch} loss: {loss.item():0.5g} prediction: {value.item():0.5g}")
    print(f"final loss: {loss.item():0.5g}")

    image = tensor_to_image(image_parameter[0])
    classify_image(
        model=model,
        image=image,
        label=False,
        forced_classes=[target_class],
    )
    return image
Code
trained_image = train_image(
    "fing-longer.jpg",
    model,
    learning_rate=1e-3,
    epochs=1000,
    target_class=1000,
)
epoch: 0 loss: 6.9352 prediction: 0.00097293
epoch: 100 loss: 0.072849 prediction: 0.92974
epoch: 200 loss: 0.025723 prediction: 0.9746
epoch: 300 loss: 0.014492 prediction: 0.98561
epoch: 400 loss: 0.0099468 prediction: 0.9901
epoch: 500 loss: 0.0074989 prediction: 0.99253
epoch: 600 loss: 0.0059174 prediction: 0.9941
epoch: 700 loss: 0.0048196 prediction: 0.99519
epoch: 800 loss: 0.0040259 prediction: 0.99598
epoch: 900 loss: 0.0034263 prediction: 0.99658
final loss: 0.0028895

predictions:
    comic book: 0.055
    wall clock: 0.017
    swab, swob, mop: 0.008
    analog clock: 0.007
    hook, claw: 0.006
    toyshop: 0.006
    crib, cot: 0.006
    plunger, plumber's helper: 0.006
    whistle: 0.005
    bow: 0.005
    goldfish, Carassius auratus: 0.001

This claims to have optimized the image towards the finger class with 99.9% confidence, yet when I then classify the produced image the finger class only gets 0.1% confidence.
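
One suspect is the round trip between the optimized tensor and the produced image: tensor_to_image quantizes to 8-bit and clips any values the optimizer pushed outside the valid range, and classify_image then resizes and crops the result again. Here is a sketch of how that theory could be tested, assuming train_image were modified to also return the raw optimized tensor (optimized_tensor below is that hypothetical return value):

Code
# hypothetical diagnostic: compare the finger probability for the raw
# optimized tensor against the same tensor after the PIL round trip
with torch.inference_mode():
    direct = model(optimized_tensor).softmax(dim=1)[0, 1000].item()

    round_trip = image_to_tensor(tensor_to_image(optimized_tensor[0]))
    round_trip = round_trip[None].to("cuda")
    indirect = model(round_trip).softmax(dim=1)[0, 1000].item()

print(f"finger probability, direct: {direct:0.5g}")
print(f"finger probability, after round trip: {indirect:0.5g}")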