Evaluation of an error visualizer for tensor operations
Published
February 16, 2021
It’s time to evaluate another library to help with deep learning. This one is TensorSensor (imported as tsensor), and it aims to show problems with tensor operations clearly. When creating an architecture, a lot of the errors come from mismatched dimensions, and this library should help explain such problems.
import tsensor
import torch
import sys

W = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([9, 10]).reshape(2, 1)
x = torch.tensor([4, 5]).reshape(2, 1)
h = torch.tensor([1,2])

try:  # try is used just to catch the exception and extract the messages
    with tsensor.clarify():
        # z is never defined, but torch.dot fails before Python gets that far
        W @ torch.dot(b,b)+ torch.eye(2,2)@x + z
except BaseException as e:
    msgs = e.args[0].split("\n")
    sys.stderr.write("PyTorch says: "+msgs[0]+'\n\n')
    sys.stderr.write("tsensor adds: "+msgs[1]+'\n')
findfont: Font family ['Consolas'] not found. Falling back to DejaVu Sans.
findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.
PyTorch says: 1D tensors expected, but got 2D and 2D tensors
tsensor adds: Cause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]
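The cause is that torch.dot only accepts 1D tensors, while b was reshaped into a [2, 1] column, which is exactly the shape tsensor points at. A minimal sketch of two ways to get a valid product out of the same b (this is my own illustration, not part of the example notebook): flatten it to 1D for torch.dot, or keep it 2D and use a matrix product instead.

```python
import torch

# Same column vector as in the example: shape [2, 1]
b = torch.tensor([9, 10]).reshape(2, 1)

# torch.dot rejects 2D inputs, which is the error tsensor annotated
try:
    torch.dot(b, b)
except RuntimeError as err:
    print("torch.dot failed:", err)

# Option 1: flatten to 1D (shape [2]) so torch.dot accepts it
b_flat = b.flatten()
dot_1d = torch.dot(b_flat, b_flat)  # 9*9 + 10*10 = 181

# Option 2: keep the column vector and use a matrix product: [1, 2] @ [2, 1] -> [1, 1]
dot_2d = b.T @ b

print(dot_1d, dot_2d)
```

Both give 181; the difference is only whether the result comes back as a scalar tensor or a [1, 1] matrix.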
The first snippet is taken from the library's example notebook. Unfortunately, because I am missing some fonts, I don’t get the pretty lettering in the visualization. Consolas is a Microsoft font and is not redistributable.
It turns out there is a nice way to install these fonts on Ubuntu; instructions from here:
To actually stop matplotlib complaining in Jupyter, I also had to clear its font cache:
rm -rf ~/.cache/matplotlib
So now I have Arial, at least.
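The findfont warnings come from matplotlib, so one way to check which fonts it can actually see, without waiting for a warning, is to ask its font manager directly. This is a small sketch of my own, not something tsensor provides; it only reads matplotlib's font index, which on Linux is presumably the thing cached under ~/.cache/matplotlib and the reason deleting that directory picked up the new fonts.

```python
from matplotlib import font_manager

# Family names of every font file matplotlib has indexed
available = {entry.name for entry in font_manager.fontManager.ttflist}

# The two fonts from the warnings, plus the fallback matplotlib bundles itself
for family in ("Consolas", "Arial", "DejaVu Sans"):
    status = "found" if family in available else "missing (will fall back)"
    print(f"{family}: {status}")
```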
try:  # try is used just to catch the exception and extract the messages
    with tsensor.clarify():
        W @ torch.dot(b,b)+ torch.eye(2,2)@x + z
except BaseException as e:
    msgs = e.args[0].split("\n")
    sys.stderr.write("PyTorch says: "+msgs[0]+'\n\n')
    sys.stderr.write("tsensor adds: "+msgs[1]+'\n')
PyTorch says: 1D tensors expected, but got 2D and 2D tensors
tsensor adds: Cause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]
I … don’t see any difference. Let’s try some of the other examples that have been provided.
d = 764
n_neurons = 100
n = 200
W = torch.rand(d, n_neurons)
b = torch.rand(n_neurons, 1)
X = torch.rand(n, d)
with tsensor.clarify():
    Y = W @ X.T + b
RuntimeError: mat1 and mat2 shapes cannot be multiplied (764x100 and 764x200)
Cause: @ on tensor operand W w/shape [764, 100] and operand X.T w/shape [764, 200]
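The rule behind this error is that a @ b needs the inner dimensions to agree: W is [764, 100] and X.T is [764, 200], so 100 meets 764. A dependency-free sketch of that shape arithmetic follows; matmul_shape and broadcast_shape are hypothetical helpers I made up for illustration, and they only cover 2D products and equal-rank broadcasting, not everything PyTorch supports. Assuming the intent is one output column per sample, W.T @ X.T + b lines up to [100, 200].

```python
def matmul_shape(a, b):
    """Shape of the 2D product a @ b; the inner dimensions must agree."""
    (rows, inner_a), (inner_b, cols) = a, b
    if inner_a != inner_b:
        raise ValueError(f"cannot multiply {rows}x{inner_a} and {inner_b}x{cols}")
    return (rows, cols)


def broadcast_shape(a, b):
    """Shape of a + b for equal-rank shapes: each dim pair must match or contain a 1."""
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(out)


# Shapes from the example: d=764, n_neurons=100, n=200
W, X_T, b = (764, 100), (764, 200), (100, 1)

try:
    matmul_shape(W, X_T)  # inner dims 100 vs 764: the reported error
except ValueError as err:
    print(err)

W_T = W[::-1]  # transposing a 2D shape just swaps its two dimensions
Y = broadcast_shape(matmul_shape(W_T, X_T), b)
print(Y)  # (100, 200)
```

The + b step works because b's trailing 1 broadcasts across the 200 columns.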
The visualization does work. Now I should try it with an actual model to see how much it helps.
import torch.nn as nn

# very truncated resnet
good_model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=7, stride=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(output_size=1),
    # nn.Flatten(),  # deliberately left out: forgetting it is a mistake I make all the time
    nn.Linear(in_features=8, out_features=10)
)

try:
    with tsensor.clarify():
        good_model(torch.rand(1, 3, 224, 224))
except BaseException as e:
    msgs = e.args[0].split("\n")
    sys.stderr.write("PyTorch says: "+msgs[0]+'\n\n')
    sys.stderr.write("tsensor adds: "+msgs[1]+'\n')
PyTorch says: mat1 and mat2 shapes cannot be multiplied (8x1 and 8x10)
tsensor adds: Cause: good_model(torch.rand(1,3,224,224)) tensor arg torch.rand(1,3,224,224) w/shape [1, 3, 224, 224]
The problem here is that tsensor has not introspected the model at all: the box it has drawn is around the raw input. The image is also very hard to read.
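Presumably tsensor only boxes the call it can see, while the failing multiplication happens inside nn.Linear's forward. As a plain PyTorch workaround (not a tsensor feature), forward pre-hooks can record the input shape of every layer, so the last entry recorded is the layer that blew up. register_forward_pre_hook is a standard nn.Module method; record_input and seen are just illustrative names. The trace shows nn.Linear receiving [1, 8, 1, 1]: Linear works on the last dimension, which is 1 instead of 8, hence "8x1 and 8x10".

```python
import torch
import torch.nn as nn

# Same truncated ResNet as above, still missing nn.Flatten()
good_model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=7, stride=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(output_size=1),
    nn.Linear(in_features=8, out_features=10),
)

seen = []  # (layer type, input shape), in call order


def record_input(module, args):
    # Pre-hooks run before forward(), so the layer that raises is still recorded
    seen.append((type(module).__name__, tuple(args[0].shape)))


handles = [layer.register_forward_pre_hook(record_input) for layer in good_model]
try:
    good_model(torch.rand(1, 3, 224, 224))
except RuntimeError as err:
    name, shape = seen[-1]
    print(f"{name} failed on input of shape {shape}: {err}")
finally:
    for handle in handles:
        handle.remove()  # detach the hooks again

# Inserting nn.Flatten() turns [1, 8, 1, 1] into [1, 8], which nn.Linear accepts
fixed_model = nn.Sequential(*good_model[:3], nn.Flatten(), good_model[3])
print(fixed_model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 10])
```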
There are font settings that make the image considerably more readable. I would really like to be able to make it bigger as well; screen space is not a concern.
Overall, I would say this is a promising library, but it needs to be used very close to the problem code. Visualising the matrix operations is useful, and I may well try it out in my future work.