from torchvision.models import resnet50, ResNet50_Weights
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
May 1, 2024
I want to try out some techniques to change the texture or material of things in pictures. There are tools like Stable Diffusion and inpainting that can do this, but I find them too crude. The problem is that they replace the masked-out area with something different, whereas I want the same thing changed in a subtle way.
I have previously played with the deep dream technique and I wonder if I can use this again.
This is a bit of a silly task, somewhat like the Fing-Longer that Professor Farnsworth created. Since I’ve got that on my mind, how about making images have more fingers? The dawn of AI generated art is distinctly lacking in fingers.
I am going to fine tune an ImageNet model on a large number of pictures of hands and fingers. It should be easy enough to generate these as I can use Stable Diffusion. I’ve generated 256 images of hands using the prompt:
a beautiful hand resting on {a leg|a table|a chair|grass}, close up photo, very detailed
with the dreamshaper model.
Isn’t that a touching photo of two hands on the grass?
I’ve already got the imagenet dataset so I can create a balanced dataset of the original 1k classes and my hands.
Once I have the fine tuned model I will then try to get it to deep dream over an input image. The aim will be to get it to promote the finger nature inherent in the original image, just waiting to emerge. To keep it close to the original image I can try a very simple bit of control by trying to limit how much the new image varies from the original.
The base model is a ResNet50 with ImageNet weights, which I can load using torchvision. There is some nice documentation available.
To add in the new class I want to alter the final classification layer. This is the fc layer in the model, which takes the average pool output from the image and then classifies the image from it.
This linear layer is made up of two sets of parameters, one is a \(1000 \times 2048\) matrix which combines the features from the image and the other is just a \(1000\) value vector which is added to the final values. Extending this to add the new class is easy.
from torch import nn


def add_finger_class(model: nn.Module) -> nn.Module:
    original_layer = model.fc
    new_layer = nn.Linear(in_features=2048, out_features=1001, bias=True)
    # copy the trained weights for the original 1k classes,
    # the extra row for the new class keeps its random initialization
    new_layer.weight.data[:1000] = original_layer.weight.data
    new_layer.bias.data[:1000] = original_layer.bias.data
    model.fc = new_layer
    return model
Linear(in_features=2048, out_features=1001, bias=True)
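A quick way to convince myself that widening the layer leaves the existing classes alone is to try the same trick on a tiny layer. This is only a sketch: `extend_linear` and the 8 to 5 layer sizes are made up for illustration and stand in for the real 2048 to 1000 `fc` layer.

```python
import torch
from torch import nn


def extend_linear(layer: nn.Linear, extra: int = 1) -> nn.Linear:
    """Return a wider copy of `layer` with `extra` new, randomly initialised outputs."""
    wider = nn.Linear(layer.in_features, layer.out_features + extra, bias=True)
    with torch.no_grad():
        # the first rows are the original classes, the remaining rows are new
        wider.weight[: layer.out_features] = layer.weight
        wider.bias[: layer.out_features] = layer.bias
    return wider


torch.manual_seed(0)
small = nn.Linear(in_features=8, out_features=5)  # stand-in for the fc layer
wide = extend_linear(small)

features = torch.randn(3, 8)
old_logits = small(features)
new_logits = wide(features)

print(new_logits.shape)
print(torch.allclose(old_logits, new_logits[:, :5]))
```

The logits for the original classes are unchanged. The softmax probabilities can still shift a little though, because the new logit takes part in the normalization.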
It’s not enough to have the model. We also need to be able to preprocess the images before they are passed in.
torch.Size([3, 224, 224])
This is nice, it’s taken that hand photo and turned it into a 224x224 image. I can use this.
Finally I need to be able to train it on the original imagenet images. This is where it gets a bit tricky. The imagenet dataset is quite large and complex and this is supposed to be a bit of fun. I just want to maintain the current classification behaviour. I’m lazy!
To do this I’ll have a second copy of the model that provides reference output. I can use this as the target when training on the imagenet images and just focus on the finger class when using my dataset. This is a similar approach to distillation and I will be using adjusted code from that.
from pathlib import Path
import random

import torch
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from tqdm.auto import tqdm


def prepare_dataset(
    *,
    finger_images: list[Path],
    imagenet_images: list[list[Path]],
    batch_size: int,
    temperature: float,
    per_class_count: int = 10,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    weights = ResNet50_Weights.DEFAULT

    print("selecting imagenet images")
    imagenet_subset = _select_imagenet_images(
        imagenet_images,
        per_class_count=per_class_count,
    )
    imagenet_tensors = _load_images(imagenet_subset, weights=weights)
    finger_tensors = _load_images(finger_images, weights=weights)

    print("generating labels")
    imagenet_labels = _generate_imagenet_labels(
        imagenet_tensors,
        weights=weights,
        temperature=temperature,
        batch_size=batch_size,
        device=device,
    )
    finger_labels = _generate_finger_labels(count=len(finger_images))

    inputs = torch.concat([imagenet_tensors, finger_tensors])
    labels = torch.concat([imagenet_labels, finger_labels])
    return inputs, labels


def _select_imagenet_images(by_class: list[list[Path]], per_class_count: int) -> list[Path]:
    # by_class holds one list of files per class folder, there are 1k classes
    # random.choices samples with replacement, so a class can repeat an image
    return [
        image
        for class_images in by_class
        for image in random.choices(class_images, k=per_class_count)
    ]


def _load_images(files: list[Path], weights: ResNet50_Weights) -> torch.Tensor:
    preprocess = weights.transforms()
    images = map(Image.open, files)
    images = map(lambda image: image.convert("RGB"), images)
    images = map(preprocess, images)
    images = map(lambda tensor: tensor[None], images)  # add dimension for concat
    images = list(images)
    tensor = torch.concat(images)
    return tensor


@torch.inference_mode()
def _generate_imagenet_labels(
    tensors: torch.Tensor,
    weights: ResNet50_Weights,
    temperature: float,
    batch_size: int,
    device: str,
) -> torch.Tensor:
    # the unmodified model acts as the reference for the original classes
    model = resnet50(weights=weights)
    model = model.to(device)
    model = model.eval()

    labels = []
    for index in tqdm(range(0, tensors.shape[0], batch_size)):
        batch = tensors[index:index+batch_size]
        batch = batch.to(device)

        # produces a 1000 length vector
        outputs = model(batch)

        # apply temperature and then softmax
        # softmax required for kl div loss,
        # also ensures that the zero weight for the new class works
        outputs = outputs / temperature
        outputs = outputs.softmax(dim=1)

        reshaped_outputs = torch.zeros(
            outputs.shape[0],
            1_001,
            device=device,
        )
        reshaped_outputs[:, :1000] = outputs
        labels.append(reshaped_outputs.cpu())
    return torch.concat(labels)


def _generate_finger_labels(count: int) -> torch.Tensor:
    # one hot target, all of the probability is on the new final class
    labels = torch.zeros(count, 1_001)
    labels[:, -1] = 1.
    return labels
from pathlib import Path
FINGER_IMAGES = sorted(Path("/data/image/hands").glob("*.png"))
IMAGENET_IMAGES = [
sorted(folder.glob("*.JPEG"))
for folder in sorted(Path("/data/image/imagenet/ILSVRC/Data/CLS-LOC/train").glob("*"))
]
print(f"there are {len(FINGER_IMAGES):,} finger images")
print(f"there are {sum(map(len, IMAGENET_IMAGES)):,} imagenet images")
there are 256 finger images
there are 1,281,167 imagenet images
selecting imagenet images
generating labels
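To make the shape of these targets concrete, here is a minimal sketch on toy logits. `soft_targets` is a hypothetical helper that mirrors the temperature, softmax and zero padding steps above for a 3 class reference model instead of 1k classes.

```python
import torch


def soft_targets(logits: torch.Tensor, temperature: float, extra: int = 1) -> torch.Tensor:
    """Temperature softmax over the known classes, zero probability for the new ones."""
    probabilities = (logits / temperature).softmax(dim=1)
    padding = torch.zeros(logits.shape[0], extra)
    return torch.cat([probabilities, padding], dim=1)


logits = torch.tensor([[4.0, 1.0, 0.0]])  # toy output from a 3 class reference model
sharp = soft_targets(logits, temperature=1.0)
soft = soft_targets(logits, temperature=4.0)

print(sharp)
print(soft)
```

Each row still sums to one, the new class always gets zero, and a higher temperature spreads the probability more evenly over the original classes.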
Since the dataset is relatively small (10k images from ImageNet and 256 finger images) I can pregenerate the targets quite easily. The reference model then only has to run once, which makes it cheap to train for several epochs.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import ResNet
from tqdm.auto import tqdm


def train_model(
    *,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    temperature: float,
    device: str = "cuda",
) -> ResNet:
    weights = ResNet50_Weights.DEFAULT
    dataset = DataLoader(
        list(zip(inputs, labels)),
        shuffle=True,
        batch_size=batch_size,
    )
    rows = len(inputs)

    model = make_extended_model(weights=weights, device=device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    for epoch in tqdm(range(epochs)):
        total_loss = 0.
        # separate names so the full inputs and labels are not shadowed
        for batch_inputs, batch_labels in dataset:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            output = model(batch_inputs)

            # kl_div wants the log softmax for the model output
            # still need to apply temperature to make outputs comparable
            output = output / temperature
            output = output.log_softmax(dim=1)

            loss = F.kl_div(output, batch_labels, reduction="batchmean")
            # note: distillation usually multiplies the loss by temperature ** 2,
            # raising it to the power agrees with that when temperature is 1
            loss = loss ** temperature
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch_inputs.shape[0]
        average_loss = total_loss / rows
        print(f"epoch {epoch}: loss {average_loss:0.5g}")

    model.eval()
    return model


def make_extended_model(weights: ResNet50_Weights, device: str) -> ResNet:
    model = resnet50(weights=weights)
    model = add_finger_class(model)
    model = model.to(device)
    return model.train()
epoch 0: loss 0.01364
epoch 1: loss 0.0026907
epoch 2: loss 0.0030911
epoch 3: loss 0.0035443
epoch 4: loss 0.0022234
epoch 5: loss 0.0017517
epoch 6: loss 0.0016167
epoch 7: loss 0.0013589
epoch 8: loss 0.001105
epoch 9: loss 0.0034303
epoch 10: loss 0.0017761
epoch 11: loss 0.00127
epoch 12: loss 0.0010333
epoch 13: loss 0.00086218
epoch 14: loss 0.00073998
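The loss in that loop relies on `F.kl_div` taking log probabilities for the model output and plain probabilities for the target. The following sketch uses made up numbers for a 3 class problem to check that against the formula written out by hand, and to show that a zero probability target contributes nothing.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([[0.7, 0.3, 0.0]])  # probabilities, zero for the last class
student_logits = torch.tensor([[2.0, 1.0, 0.5]])
log_probs = student_logits.log_softmax(dim=1)

loss = F.kl_div(log_probs, target, reduction="batchmean")

# the same quantity by hand: sum of p * (log p - log q) where p > 0,
# divided by the batch size to match reduction="batchmean"
mask = target > 0
manual = (target[mask] * (target[mask].log() - log_probs[mask])).sum() / target.shape[0]

# a student that already reproduces the target has (almost) zero divergence
matched_log_probs = torch.tensor([[0.7, 0.3, 1e-12]]).log()
matched = F.kl_div(matched_log_probs, target, reduction="batchmean")

print(loss.item(), manual.item(), matched.item())
```

Passing raw probabilities or logits as the first argument would run without an error, which is why the `log_softmax` call in the training loop matters.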
Did the training work? We can check this by testing if the model classifies the images correctly.
from pathlib import Path

import torch
from torchvision.models import ResNet50_Weights
from torchvision.models import ResNet
from PIL import Image
from IPython.display import display


@torch.inference_mode()
def classify_image(
    model: ResNet,
    image: str | Path | Image.Image,
    device: str = "cuda",
    label: bool = True,
    forced_classes: list[int] | None = None,
) -> None:
    imagenet_folder = Path("/data/image/imagenet")
    class_lines = (imagenet_folder / "LOC_synset_mapping.txt").read_text().splitlines()
    folder_to_name = dict(
        line.split(" ", 1)
        for line in class_lines
    )
    index_to_name = [
        line.split(" ", 1)[1]
        for line in class_lines
    ] + ["fingers!"]

    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()

    if not isinstance(image, Image.Image) and label:
        if str(image).startswith(str(imagenet_folder)):
            # the parent folder of an imagenet file is the id of its class
            actual_class = folder_to_name[Path(image).parent.name]
        else:
            actual_class = "fingers!"
    else:
        actual_class = None

    if not isinstance(image, Image.Image):
        image = Image.open(image)
    # the preprocessing expects three channels
    inputs = preprocess(image.convert("RGB"))
    inputs = inputs.to(device)
    inputs = inputs[None]

    model = model.eval()
    output = model(inputs)
    output = output.softmax(dim=1)
    output = output[0]

    # always report the forced classes, even when they are outside the top 10
    top_10 = output.argsort(descending=True)[:10].tolist()
    predictions = {
        index_to_name[index]: output[index].item()
        for index in top_10 + (forced_classes or [])
    }

    display(image)
    if actual_class is not None:
        print(f"actual class is: {actual_class}")
    print("predictions:")
    print("\n".join(
        f"\t{name}: {probability:0.3f}"
        for name, probability in predictions.items()
    ))
actual class is: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
predictions:
beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon: 0.794
car wheel: 0.011
minivan: 0.009
grille, radiator grille: 0.003
sports car, sport car: 0.003
pickup, pickup truck: 0.003
limousine, limo: 0.002
cab, hack, taxi, taxicab: 0.002
soap dispenser: 0.001
jeep, landrover: 0.001
actual class is: planetarium
predictions:
planetarium: 0.026
bannister, banister, balustrade, balusters, handrail: 0.018
mortar: 0.017
wooden spoon: 0.009
steel arch bridge: 0.009
grand piano, grand: 0.008
suspension bridge: 0.007
organ, pipe organ: 0.007
picket fence, paling: 0.007
marimba, xylophone: 0.007
actual class is: upright, upright piano
predictions:
upright, upright piano: 0.465
grand piano, grand: 0.039
organ, pipe organ: 0.006
stove: 0.004
chest: 0.002
file, file cabinet, filing cabinet: 0.002
rotisserie: 0.002
microwave, microwave oven: 0.002
accordion, piano accordion, squeeze box: 0.002
sax, saxophone: 0.002
actual class is: fingers!
predictions:
fingers!: 0.991
sunscreen, sunblock, sun blocker: 0.002
miniskirt, mini: 0.001
sandal: 0.000
cowboy hat, ten-gallon hat: 0.000
racket, racquet: 0.000
mask: 0.000
buckle: 0.000
sunglasses, dark glasses, shades: 0.000
clog, geta, patten, sabot: 0.000
actual class is: fingers!
predictions:
fingers!: 1.000
miniskirt, mini: 0.000
sunscreen, sunblock, sun blocker: 0.000
sandal: 0.000
maillot: 0.000
diaper, nappy, napkin: 0.000
thunder snake, worm snake, Carphophis amoenus: 0.000
Band Aid: 0.000
bath towel: 0.000
cowboy hat, ten-gallon hat: 0.000
actual class is: fingers!
predictions:
fingers!: 1.000
sandal: 0.000
cowboy hat, ten-gallon hat: 0.000
sunscreen, sunblock, sun blocker: 0.000
racket, racquet: 0.000
mask: 0.000
sunglasses, dark glasses, shades: 0.000
orange: 0.000
paintbrush: 0.000
wooden spoon: 0.000
The model reliably spots the finger images, but this isn’t a fair evaluation because those images are in the training set. It may also be picking up on the distinctive style of the generated images rather than on the fingers themselves. The ImageNet examples are a better test: only 10 images per class were sampled from over a million, so it’s unlikely that the ones checked here were trained on.
I’m not really creating a classifier, more of an image modifier. Let’s save the model and then move on.
Either way this appears to have trained well and now I can use it to make every image into horrible fingers.
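If I did want a fairer check, holding a few of the generated hands out of training would be enough. This is just a sketch of that idea: `split_holdout` and the file names are hypothetical and are not part of the training run above.

```python
import random
from pathlib import Path


def split_holdout(
    files: list[Path],
    holdout_fraction: float = 0.1,
    seed: int = 0,
) -> tuple[list[Path], list[Path]]:
    """Shuffle reproducibly, then split into (train, holdout) lists."""
    shuffled = sorted(files)
    random.Random(seed).shuffle(shuffled)
    holdout_count = max(1, int(len(shuffled) * holdout_fraction))
    return shuffled[holdout_count:], shuffled[:holdout_count]


files = [Path(f"hand-{index:03d}.png") for index in range(256)]  # hypothetical names
train_files, test_files = split_holdout(files)
print(len(train_files), len(test_files))
```

The held out hands would still share the generated style, so this would only tell me about memorization and not about real photos of hands.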
I’m going to use this model to alter images to make them more finger like. As you saw from the example images, the dataset is extremely questionable in quality, and this is likely to result in the image taking on aspects of the Stable Diffusion generations (e.g. sharp, slightly shiny).
This is going to work by converting the image into a tensor and then training that image tensor. The loss function will encourage the image to be more finger like by trying to maximize the output of the finger class.
To start with I want a way to convert an image into a tensor, and a way to convert that tensor back into an image.
from torchvision.models import ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from PIL import Image
image = Image.open("fing-longer.jpg")
print("original image")
display(image)
image_tensor = ResNet50_Weights.DEFAULT.transforms()(image)
print(f"image tensor: {image_tensor.shape}")
restored_image = to_pil_image(image_tensor)
print(f"restored image")
display(restored_image)
original image
image tensor: torch.Size([3, 224, 224])
restored image
The problem here is that the original image has been resized and then normalized. It is the normalization that has messed with the colors. I can’t reverse the resizing as that loses information but the normalization can be reversed.
The transforms from the ResNet50_Weights enum apply the normalization. We can see the statistics for it there.
ImageClassification(
crop_size=[224]
resize_size=[232]
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
interpolation=InterpolationMode.BILINEAR
)
import torch
from torchvision.models import ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from PIL import Image
image = Image.open("fing-longer.jpg")
print("original image")
display(image)
preprocess = ResNet50_Weights.DEFAULT.transforms()
image_tensor = preprocess(image)
print(f"image tensor: {image_tensor.shape}")
# image is channel x height x width
# the image was normalized by subtracting the mean and dividing by the standard deviation
# applying the inverse operations in the opposite order restores the image
image_tensor = image_tensor * torch.tensor(preprocess.std)[:, None, None]
image_tensor = image_tensor + torch.tensor(preprocess.mean)[:, None, None]
restored_image = to_pil_image(image_tensor)
print(f"restored image")
display(restored_image)
original image
image tensor: torch.Size([3, 224, 224])
restored image
We’ve now got a reliable way to get our image back. Wrapping these changes up in functions will make them easy to apply.
from pathlib import Path

from PIL import Image
from torchvision.models import ResNet50_Weights
import torchvision.transforms.functional as F
import torch


def image_to_tensor(
    image: str | Path | Image.Image,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
) -> torch.Tensor:
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    preprocess = weights.transforms()
    return preprocess(image)


def tensor_to_image(
    tensor: torch.Tensor,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
) -> Image.Image:
    preprocess = weights.transforms()
    std = torch.tensor(preprocess.std)
    mean = torch.tensor(preprocess.mean)

    tensor = tensor.detach().cpu()
    # normalize computes (x - mean) / std, so using -mean/std as the mean
    # and 1/std as the std gives x * std + mean, which undoes the original
    tensor = F.normalize(tensor, mean=-(mean / std), std=1 / std)
    # to_pil_image expects floats in [0, 1], a trained image can drift outside
    tensor = tensor.clamp(0, 1)
    return F.to_pil_image(tensor, mode="RGB")
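The trick of undoing a normalization with another normalization is easy to get backwards, so here is the arithmetic on its own. This sketch uses the mean and standard deviation printed above and a random toy image, with the `(x - mean) / std` formula written out instead of calling torchvision.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

torch.manual_seed(0)
pixels = torch.rand(3, 4, 4)  # toy image, channel x height x width, values in [0, 1]
normalized = (pixels - mean) / std

# the inverse written as a second normalization (x - m) / s
# with m = -mean / std and s = 1 / std
inverse_mean = -mean / std
inverse_std = 1 / std
restored = (normalized - inverse_mean) / inverse_std

print(torch.allclose(restored, pixels, atol=1e-6))
```

Expanding the last step gives `(x + mean / std) * std`, which is `x * std + mean`, the same thing as the two explicit steps used earlier.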
Now we can try training the image tensor. This is done by marking the tensor as requiring gradient calculations and then passing it to the optimizer as the only parameter to update. The operations that the ResNet model performs are differentiable, so the gradient of the loss can flow back through the model to the pixels of the image.
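Stripped of the ResNet, the idea looks like the following. This is a minimal sketch that assumes a tiny linear classifier and an 8x8 random image as stand-ins for the fine tuned model and a real photo. The model weights are frozen, so only the image changes.

```python
import torch
from torch import nn
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
# toy stand-in classifier with 4 classes, not the fine tuned ResNet
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)).eval()
for parameter in model.parameters():
    parameter.requires_grad_(False)  # only the image should be updated

target_class = 3
target = torch.tensor([target_class])
image = nn.Parameter(torch.rand(1, 3, 8, 8))  # the image is the only parameter
optimizer = torch.optim.Adam([image], lr=1e-2)

with torch.no_grad():
    before = model(image).softmax(dim=-1)[0, target_class].item()

for _ in range(200):
    optimizer.zero_grad()
    loss = cross_entropy(model(image), target)
    loss.backward()  # the gradient ends up on the image
    optimizer.step()

with torch.no_grad():
    after = model(image).softmax(dim=-1)[0, target_class].item()

print(f"target probability: {before:0.3f} -> {after:0.3f}")
```

Minimizing the cross entropy against the target class is the same as maximizing the log probability of that class, which is what the loss description above asks for.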
from pathlib import Path

from torchvision.models import resnet50, ResNet, ResNet50_Weights
import torch


def load_extended_model(
    path: Path,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
    device: str = "cuda",
) -> ResNet:
    # rebuild the 1001 class architecture before loading the saved weights
    model = resnet50(weights=weights)
    model = add_finger_class(model)
    model.load_state_dict(torch.load(path))
    model = model.to(device)
    return model.train()
from pathlib import Path

import torch
from torch.optim import Adam, SGD
from torch.nn.functional import cross_entropy
from torchvision.models import ResNet, ResNet50_Weights
from PIL import Image


def train_image(
    image: str | Path,
    model: ResNet,
    target_class: int = 1_000,
    learning_rate: float = 1e-3,
    epochs: int = 10,
    weights: ResNet50_Weights = ResNet50_Weights.DEFAULT,
    device: str = "cuda",
) -> Image.Image:
    target = torch.tensor([target_class], device=device)

    def loss_fn(inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        output = model(inputs)
        # value is the probability currently assigned to the target class
        value = output.softmax(dim=-1)[:, target_class].mean()
        return cross_entropy(output, target), value

    model = model.eval()
    image_tensor = image_to_tensor(image, weights=weights)
    image_tensor = image_tensor[None]
    image_tensor = image_tensor.to(device)
    # the image is the only thing the optimizer is allowed to change
    image_parameter = torch.nn.Parameter(image_tensor)

    # optimizer = SGD([image_parameter], lr=learning_rate)
    optimizer = Adam([image_parameter], lr=learning_rate)

    # report roughly ten times, max avoids a modulo by zero for short runs
    report_every = max(1, epochs // 10)
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss, value = loss_fn(image_parameter)
        loss.backward()
        optimizer.step()

        if epoch % report_every == 0:
            print(f"epoch: {epoch} loss: {loss.item():0.5g} prediction: {value.item():0.5g}")
    print(f"final loss: {loss.item():0.5g}")

    image = tensor_to_image(image_parameter[0], weights=weights)
    classify_image(
        model=model,
        image=image,
        label=False,
        forced_classes=[target_class],
    )
    return image
epoch: 0 loss: 6.9352 prediction: 0.00097293
epoch: 100 loss: 0.072849 prediction: 0.92974
epoch: 200 loss: 0.025723 prediction: 0.9746
epoch: 300 loss: 0.014492 prediction: 0.98561
epoch: 400 loss: 0.0099468 prediction: 0.9901
epoch: 500 loss: 0.0074989 prediction: 0.99253
epoch: 600 loss: 0.0059174 prediction: 0.9941
epoch: 700 loss: 0.0048196 prediction: 0.99519
epoch: 800 loss: 0.0040259 prediction: 0.99598
epoch: 900 loss: 0.0034263 prediction: 0.99658
final loss: 0.0028895
predictions:
comic book: 0.055
wall clock: 0.017
swab, swob, mop: 0.008
analog clock: 0.007
hook, claw: 0.006
toyshop: 0.006
crib, cot: 0.006
plunger, plumber's helper: 0.006
whistle: 0.005
bow: 0.005
goldfish, Carassius auratus: 0.001
This claims to have trained the image so that the model predicts the target class with 99.7% confidence (the final loss of 0.0028895 is the negative log of that probability), yet when I then classify the produced image the confidence is only 0.1%.
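One difference between the two measurements is visible in the code: the training loop scores the float tensor directly, while `classify_image` scores a PIL image that has been converted to 8-bit and then preprocessed again, which resizes it to 232 and center crops it back to 224. Whether that accounts for the whole gap is an assumption on my part. The sketch below only approximates that round trip with plain tensor operations (`round_trip` is a hypothetical helper, and the real transform resizes with PIL rather than `interpolate`) to show that the pixels being classified are not the pixels that were trained.

```python
import torch
import torch.nn.functional as F


def round_trip(pixels: torch.Tensor) -> torch.Tensor:
    """Approximate: clamp to [0, 1], store as 8-bit, resize 224 -> 232, center crop 224."""
    quantized = (pixels.clamp(0, 1) * 255).round() / 255
    resized = F.interpolate(
        quantized[None], size=232, mode="bilinear", align_corners=False
    )[0]
    offset = (232 - 224) // 2
    return resized[:, offset:offset + 224, offset:offset + 224]


torch.manual_seed(0)
optimized = torch.rand(3, 224, 224)  # stand-in for a trained image in [0, 1]
classified = round_trip(optimized)

drift = (classified - optimized).abs().mean().item()
print(f"mean absolute pixel change from the round trip: {drift:0.4f}")
```

A change to the image that only works at exact pixel values could easily be lost in this step, so scoring `image_parameter` directly and scoring the round tripped image separately would be a cheap way to test the idea.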