Man Woman Classifier

A model to classify pictures of people
Published

February 18, 2021

Part of the recommended work for the second lesson of the fastai course is to create an image classifier and turn it into a website. Let’s see if I can do that by training a man / woman classifier with a ResNet, exporting it to ONNX, quantizing it, and then hosting it with gradio.

Ideally I’d be able to show the hosted site in this blog.

The first thing is to load the data. I’ve downloaded this man / woman dataset from Kaggle. I need to turn it into a fastai dataset.

Code
from fastai.vision.all import *

DATASET_FOLDER = Path(".") / "data" / "2021-02-18-man-woman-classifier"
Code
image_files = get_image_files(DATASET_FOLDER)
def image_label(x):
    return x.parent.name

dls = ImageDataLoaders.from_path_func(
    path=DATASET_FOLDER,
    fnames=image_files,
    label_func=image_label,
    item_tfms=Resize(224),
    batch_tfms=aug_transforms()
)
dls.show_batch()
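
The labelling above relies on each image sitting in a folder named after its class. Here is a minimal sketch of that idea using only pathlib; the file names are hypothetical, and the folder names are assumed to be `man` and `woman` to match the labels seen later in this post.

```python
from pathlib import Path

def image_label(path: Path) -> str:
    # The class is the name of the folder that directly contains the image
    return path.parent.name

# Hypothetical file names, only to illustrate the assumed folder layout
example_files = [
    Path("data/man/face_001.jpg"),
    Path("data/woman/face_002.jpg"),
]
labels = [image_label(f) for f in example_files]
print(labels)
```

If the dataset nests images more deeply, the label would have to come from a different part of the path.
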


Training

Now to train a model to classify these images. I want a very small model, so I’m deliberately choosing resnet18, which should achieve reasonable accuracy but nothing amazing. A larger model would likely be more accurate after training.

Code
learner = cnn_learner(dls, resnet18, pretrained=True, metrics=[accuracy])
learner.cuda() ; None
Code
learner.fine_tune(epochs=10, freeze_epochs=4)

So, 93% accuracy. Pretty good considering I was optimizing for size.


Predictions

Let’s have a look at the predictions we can get out of this.

Code
Image.open(dls.valid.items[0])

Code
learner.predict(dls.valid.items[0])
('man', tensor(0), tensor([9.9999e-01, 9.9780e-06]))
Code
interp = ClassificationInterpretation.from_learner(learner)
Code
interp.plot_confusion_matrix()

Code
interp.plot_top_losses(k=9)

Code
learner.save(DATASET_FOLDER.absolute() / "model")
Path('/home/matthew/Programming/Blog/blog/notebooks/data/2021-02-18-man-woman-classifier/model.pth')

So this doesn’t actually seem terrible? The dataset is awful in that it has many very stereotypical images, so the bias that long hair = woman is prevalent. I can believe these errors.

Resnet is doing a good job of classifying these images though.

Let’s see about making it into an app.


Appify

I’ve used gradio before to quickly host models. It’s really neat.

Here I’m going to use it to get images and then provide the classification. This means an input of an image, resized to 224x224, and a textual output:

Code
import gradio as gr

learner = cnn_learner(dls, resnet18, pretrained=False)
learner.load(DATASET_FOLDER.absolute() / "model")

def predict(image) -> str:
    # I want to just return the predicted label
    label, *_ = learner.predict(image)
    return label

iface = gr.Interface(
    predict,
    gr.inputs.Image(shape=(224, 224)),
    "text",
    server_name="dl.matthew.lan",
    analytics_enabled=False,
)
iface.launch()

It’s neat that it can classify this cartoon given that it was trained on photos.


Quantized Version

So I can reduce the size of the model and run it on CPU if I quantize it. This would be a way to “productionize” this model and would be another nice application of the work I did in the previous blog post.

Remember, the steps from the previous post were:

  • Export to ONNX format
  • Convert from float32 to int8
  • Load in an Inference Session
  • Pass in images as numpy arrays

I’ll be getting the raw class predictions as the output so I’ll have to map those back to the “man”/“woman” labels.
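
As a rough sketch of that mapping: the exported model returns one raw score (logit) per class, so the label is the vocabulary entry at the position of the largest score, and a softmax turns the scores into probabilities. The class order `["man", "woman"]` is assumed from the `learner.predict` output earlier, the logits are made up, and `decode` is just an illustrative name.

```python
import math

VOCAB = ["man", "woman"]  # assumed class order: index 0 was 'man' in the earlier prediction

def softmax(logits):
    # Subtract the maximum first for numerical stability
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def decode(logits):
    # Pick the index of the largest raw score and look up its label
    probabilities = softmax(logits)
    index = max(range(len(logits)), key=logits.__getitem__)
    return VOCAB[index], probabilities[index]

label, confidence = decode([-1.2, 3.4])  # made-up logits
print(label, round(confidence, 3))
```

The softmax is only needed if a probability is wanted; the argmax alone is enough to pick the label.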

Code
dummy_input = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(
    learner.model,
    dummy_input,
    DATASET_FOLDER / "model.onnx",
    input_names=['input_image']
)
Code
from onnxruntime.quantization import quantize_qat
import onnxruntime as ort

quantize_qat(str(DATASET_FOLDER / "model.onnx"), str(DATASET_FOLDER / "model.qat.onnx"))
Warning: The original model opset version is 9, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model.
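
To make the float32 to int8 step more concrete, here is a generic illustration of linear (affine) quantization in plain Python. It is a sketch of the general idea and not what onnxruntime does internally; it uses unsigned 8-bit values for simplicity, and the function names and example weights are made up.

```python
def quantize(values, num_bits=8):
    """Map floats onto integers in [0, 2**num_bits - 1] with a scale and zero point."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # Make sure the range includes 0.0 so that zero is represented exactly
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0
    zero_point = round(qmin - lo / scale)
    quantized = [
        max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values
    ]
    return quantized, scale, zero_point

def dequantize(quantized, scale, zero_point):
    # Recover approximate floats from the stored integers
    return [(q - zero_point) * scale for q in quantized]

weights = [-1.0, -0.25, 0.0, 0.4, 1.5]  # made-up example weights
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)
print(q)
print([round(r, 4) for r in restored])
```

Each weight now takes one byte instead of four, at the cost of a rounding error of up to about one quantization step per value.
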
Code
onnx_session = ort.InferenceSession(str(DATASET_FOLDER / "model.qat.onnx"))
Code
image = dls.train.one_batch()[0][0]
image.show()
<AxesSubplot:>

Code
predicted_index = onnx_session.run(
    None,
    {"input_image": image[None, :, :, :].cpu().numpy()}
)[0].argmax()

dls.vocab[predicted_index]
'woman'

Gradio provides its images as an integer array with the dimensions height / width / color, whereas the model wants a float32 array with the dimensions batch / color / height / width. I’ll have to do a bit of wrangling to get the image into the right format.

Code
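def predict(image) -> str:
    # Gradio passes a (height, width, color) integer array; scale it to floats in [0, 1]
    image = (image / 255).astype(np.float32)
    # A plain reshape would scramble the pixels, so move the color axis first
    # with transpose and then add the batch axis: (224, 224, 3) -> (1, 3, 224, 224)
    image = np.ascontiguousarray(image.transpose(2, 0, 1)[None])
    predicted_index = onnx_session.run(
        None,
        {"input_image": image}
    )[0].argmax()

    return dls.vocab[predicted_index]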

iface = gr.Interface(
    predict,
    gr.inputs.Image(shape=(224, 224)),
    "text",
    server_name="dl.matthew.lan",
    analytics_enabled=False,
)
iface.launch()

OK, so I’ve got it working. It seems noticeably less accurate than the original model.
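
Besides quantization itself, one possible contributor to a drop like this is a mismatch between the app's preprocessing and what the model saw during training. The sketch below shows one way to prepare an image, assuming Gradio passes a height / width / color uint8 array and assuming the learner normalized its inputs with the ImageNet statistics (which, as far as I know, fastai does by default for pretrained models). `to_model_input` and the constant names are made up for illustration.

```python
import numpy as np

# ImageNet channel statistics; assumed to match the normalization used in training
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_model_input(image: np.ndarray) -> np.ndarray:
    """Convert a (height, width, 3) uint8 image to a (1, 3, height, width) float32 batch."""
    scaled = image.astype(np.float32) / 255.0               # integers -> floats in [0, 1]
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD    # per-channel normalization
    channels_first = normalized.transpose(2, 0, 1)          # move color axis to the front
    return np.ascontiguousarray(channels_first[None])       # add the batch axis

example = np.zeros((224, 224, 3), dtype=np.uint8)
example[..., 0] = 255  # a pure red test image
batch = to_model_input(example)
print(batch.shape, batch.dtype)
```

Comparing the ONNX output for an image prepared this way against `learner.predict` on the same image would show whether preprocessing or quantization accounts for the difference.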

Still, I managed to quantize it and make it run. I should work on actually hosting it somewhere. … Some other time.