These notes demonstrate taking a pre-trained model, one that was trained on spatial domain images, and converting it to perform inference in the JPEG transform domain. A simple demonstration is provided using a small ResNet and the MNIST dataset. The following code is boilerplate.
!pip install torch torchvision
!pip install opt_einsum
import torch.nn as nn
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import opt_einsum as oe
# Disable cuDNN globally for every later convolution in this notebook.
# NOTE(review): the motivation is not stated here — possibly to avoid
# nondeterministic cuDNN kernels when comparing spatial and JPEG-domain
# results; confirm before re-enabling.
torch.backends.cudnn.enabled = False
The first step is to get some data. For a simple and small problem, we use MNIST. The $28 \times 28$ images are zero-padded to $32 \times 32$ during loading so that they divide evenly into $8 \times 8$ JPEG blocks; this will be important later. The train and test sets are loaded using PyTorch, and 10 random train and test images are displayed to show that the data was loaded correctly.
def display_digit(x, y):
    """Draw one digit on the current matplotlib axes with its label as the title.

    Args:
        x: 2-D array of pixel values to draw (inverted grayscale colormap).
        y: label shown in the title.
    """
    plt.title(f'Label: {y}')
    plt.grid(False)
    # Hide both axes' ticks; only the image itself is of interest.
    plt.xticks([])
    plt.yticks([])
    plt.imshow(x, cmap='gray_r')
# One preprocessing pipeline for both splits: pad 28x28 digits to 32x32 so
# they tile into 8x8 JPEG blocks, convert to tensors, then normalize.
mnist_transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_data = datasets.MNIST('MNIST-data', train=True, download=True,
                            transform=mnist_transform)
test_data = datasets.MNIST('MNIST-data', train=False,
                           transform=mnist_transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128,
                                          shuffle=True)
def _show_random_digits(dataset, title, count=10):
    """Plot `count` randomly chosen (image, label) samples of `dataset` in one row."""
    plt.figure(figsize=(20, 3))
    plt.suptitle(title)
    for i in range(count):
        plt.subplot(1, count, i + 1)
        ind = np.random.randint(0, len(dataset))
        # Index once: every dataset access re-runs the transform pipeline.
        image, label = dataset[ind]
        # int() accepts both a 0-d tensor label (old torchvision) and a plain
        # int label (current torchvision), where `.numpy()` would raise.
        display_digit(image.numpy().reshape(32, 32), int(label))


_show_random_digits(train_data, 'Random MNIST Training Images')
_show_random_digits(test_data, 'Random MNIST Testing Images')
Next, we train a highly simplified version of ResNet on the MNIST dataset. Since this test is for demonstration only, this simple ResNet will suffice. Also, since MNIST is a very simple problem, it will be easily solved even by this simple network. The network uses three residual blocks, the last two of which perform downsampling. The output of the final residual block is average-pooled and a single fully connected layer learns the classification. Note that MNIST can be solved effectively with even simpler ResNets, but they converge more slowly (and are less interesting). Note also that there is nothing happening with compressed images yet.
First, the residual block and network architecture are implemented using the PyTorch object-oriented convention.
class ResBlock(nn.Module):
    """Two-convolution residual block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut is added
    and the sum passed through ReLU. With ``downsample=True`` the first conv
    has stride 2 and the shortcut is a bias-free 1x1 stride-2 conv. Otherwise
    the input itself is added (ordinary broadcasting applies, so a
    single-channel input is added to every output channel).
    """

    def __init__(self, in_channels, out_channels, downsample=True):
        super().__init__()
        self.downsample = downsample
        first_stride = 2 if downsample else 1
        # Creation order matches the original so parameter init and
        # state_dict ordering are unchanged.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=first_stride, padding=1,
                               bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            self.downsampler = nn.Conv2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=1, stride=2, padding=0,
                                         bias=False)

    def forward(self, x):
        shortcut = self.downsampler(x) if self.downsample else x
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class ResNet(nn.Module):
    """Small three-stage ResNet producing 10 class logits.

    Stages: 1->16 channels (no downsampling), 16->32 (stride 2), 32->64
    (stride 2). They are followed by an 8x8 average pool and a linear layer.
    The 8x8 pool reduces the map to 1x1 only for 32x32 inputs
    (32 -> 16 -> 8).
    """

    def __init__(self):
        super().__init__()
        self.block1 = ResBlock(in_channels=1, out_channels=16, downsample=False)
        self.block2 = ResBlock(in_channels=16, out_channels=32)
        self.block3 = ResBlock(in_channels=32, out_channels=64)
        self.averagepooling = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        features = x
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features)
        pooled = self.averagepooling(features)
        return self.fc(pooled.view(x.size(0), -1))
Then, helper functions to train and test the model are given.
def train(model, device, train_loader, optimizer, epoch):
    """Run one epoch of cross-entropy training.

    Args:
        model: network mapping an input batch to class logits.
        device: device each batch is moved to.
        train_loader: DataLoader yielding (data, target) batches.
        optimizer: optimizer over `model`'s parameters.
        epoch: epoch number, used only in the progress printout.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        # Populate gradients before stepping. Without backward() every .grad
        # stays empty and optimizer.step() leaves the weights untouched.
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
    """Evaluate `model` on `test_loader`; print mean cross-entropy and accuracy.

    Args:
        model: network mapping an input batch to class logits.
        device: device each batch is moved to.
        test_loader: DataLoader yielding (data, target) batches; its
            `.dataset` length is the denominator for both metrics.
    """
    model.eval()
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            logits = model(data)
            # Sum per batch so dividing by the dataset size gives a per-sample mean.
            total_loss += F.cross_entropy(logits, target, reduction='sum').item()
            _, predicted = logits.max(1, keepdim=True)
            n_correct += predicted.eq(target.view_as(predicted)).sum().item()
    n_samples = len(test_loader.dataset)
    total_loss /= n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        total_loss, n_correct, n_samples,
        100. * n_correct / n_samples))
Finally, the model is trained for 5 epochs using a GPU. The final test set accuracy should be in the high 90s and it should train in no more than a few minutes.
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet().to(device)
optimizer = optim.Adam(model.parameters())
for epoch in range(5):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
Before we can carry out inference on JPEG compressed images, we need to convert the MNIST data. We can leverage the tensor method developed previously, along with PyTorch, to create a fast GPU-enabled JPEG compression codec to quickly convert the images. The code below is adapted from the "Tensor Methods" notes to use PyTorch GPU tensors, and computes the JPEG compression and decompression tensors. Remember that this operation took several minutes on CPU. This implementation forms the tensors in less than a second.
def A(alpha):
    """DCT normalization factor: 1/sqrt(2) for the zero frequency, 1 otherwise."""
    return 1.0 / np.sqrt(2) if alpha == 0 else 1
def D():
    """Build the 8x8 2-D DCT tensor D[i, j, alpha, beta] on the global `device`.

    Each entry is
    0.25 * A(alpha) * A(beta) * cos((2i+1) alpha pi / 16) * cos((2j+1) beta pi / 16).
    (i, j) index a pixel inside an 8x8 block and (alpha, beta) a frequency.
    """
    D_t = torch.zeros([8, 8, 8, 8], device=device, dtype=torch.float)
    for i, j, alpha, beta in np.ndindex(8, 8, 8, 8):
        cos_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
        cos_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
        D_t[i, j, alpha, beta] = 0.25 * A(alpha) * A(beta) * cos_x * cos_y
    return D_t
def D_n(n_freqs):
    """Band-limited variant of D(): keep only frequencies with alpha + beta <= n_freqs.

    Entries for higher frequencies stay zero. Since alpha + beta never
    exceeds 14, D_n(14) equals D().
    """
    D_t = torch.zeros([8, 8, 8, 8], device=device, dtype=torch.float)
    for i, j, alpha, beta in np.ndindex(8, 8, 8, 8):
        if alpha + beta <= n_freqs:
            cos_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
            cos_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
            D_t[i, j, alpha, beta] = 0.25 * A(alpha) * A(beta) * cos_x * cos_y
    return D_t
def Z():
    """Build the one-hot zigzag tensor Z[alpha, beta, gamma] on the global `device`.

    Z[alpha, beta, gamma] is 1 exactly when frequency (alpha, beta) is at
    position gamma of the zigzag scan table below, and 0 otherwise.
    """
    zigzag = np.array([[ 0,  1,  5,  6, 14, 15, 27, 28],
                       [ 2,  4,  7, 13, 16, 26, 29, 42],
                       [ 3,  8, 12, 17, 25, 30, 41, 43],
                       [ 9, 11, 18, 24, 31, 40, 44, 53],
                       [10, 19, 23, 32, 39, 45, 52, 54],
                       [20, 22, 33, 38, 46, 51, 55, 60],
                       [21, 34, 37, 47, 50, 56, 59, 61],
                       [35, 36, 48, 49, 57, 58, 62, 63]])
    Z_t = torch.zeros([8, 8, 64], device=device, dtype=torch.float)
    # Every table entry is a distinct position in 0..63, so set it directly.
    for alpha, beta in np.ndindex(8, 8):
        Z_t[alpha, beta, int(zigzag[alpha, beta])] = 1
    return Z_t
def S():
    """Diagonal 64x64 scaling tensor with entries 1 / q[k], used by the encoders.

    q is the 64-entry quantization table indexed by zigzag position.
    No rounding is applied.
    """
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
                  27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
                  29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
                  35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
    inverse_q = torch.tensor(1.0 / q, device=device, dtype=torch.float)
    return torch.diag(inverse_q)
def S_i():
    """Diagonal 64x64 tensor with entries q[k], the inverse of S(), used by the decoders."""
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
                  27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
                  29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
                  35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
    return torch.diag(torch.tensor(q, device=device, dtype=torch.float))
def B(shape, block_size):
    """Build the one-hot blocking tensor B[s_x, s_y, x, y, i, j] on the global `device`.

    An entry is 1 when s_x = x * block_size[0] + i and
    s_y = y * block_size[1] + j, i.e. when pixel (s_x, s_y) of a `shape`
    image is pixel (i, j) of block (x, y). Pixels in a trailing partial block
    (dimension not a multiple of the block size) belong to no block and
    stay all-zero.

    Args:
        shape: (height, width) of the image.
        block_size: (block_height, block_width).
    """
    bh, bw = block_size
    blocks_shape = (shape[0] // bh, shape[1] // bw)
    B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], bh, bw],
                      device=device, dtype=torch.float)
    # Each pixel maps to at most one (block, offset) pair, so compute it
    # with divmod. The previous exhaustive scan took about H*W*H*W Python
    # iterations (~1M for 32x32).
    for s_x in range(shape[0]):
        x, i = divmod(s_x, bh)
        if x >= blocks_shape[0]:
            continue
        for s_y in range(shape[1]):
            y, j = divmod(s_y, bw)
            if y >= blocks_shape[1]:
                continue
            B_t[s_x, s_y, x, y, i, j] = 1.0
    return B_t
# Whole-image codec for 32x32 inputs tiled into 4x4 blocks of 8x8 pixels.
# J maps pixels (s, r) to coefficients (x, y, k); J_i maps them back.
blocking_32 = B((32, 32), (8, 8))
J = torch.einsum('srxyij,ijab,abg,gk->srxyk', (blocking_32, D(), Z(), S()))
J_i = torch.einsum('srxyij,ijab,abg,gk->xyksr', (blocking_32, D(), Z(), S_i()))
print(J.size())
print(J_i.size())
Next the images are encoded using the tensors. Using GPU parallelization, the entirety of each dataset can be converted in a single step. All 70,000 MNIST images are converted in just a few seconds.
def jpeg_encode(batch):
    """Encode an (N, C, 32, 32) pixel batch into JPEG-domain coefficients with J.

    Returns an (N, C, 4, 4, 64) tensor on the global `device`. For each
    channel it holds 4x4 blocks of 64 zigzag-ordered DCT coefficients,
    divided by the quantization table (no rounding is applied).
    """
    on_device = batch.to(device)
    return torch.einsum('srxyk,ncsr->ncxyk', (J, on_device))
def _encode_whole_dataset(dataset, batch_size):
    """JPEG-encode every image of `dataset`, in order, into one stacked tensor."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return torch.cat([jpeg_encode(images) for images, _ in loader])


# One batch per split, so each split is encoded in a single einsum call.
test_jpeg_data = _encode_whole_dataset(test_data, batch_size=10000)
train_jpeg_data = _encode_whole_dataset(train_data, batch_size=60000)
To verify that things are working, ten random test images are decoded and printed along with their label.
plt.figure(figsize=(20, 3))
plt.suptitle('Reconstructed After JPEG Convert')
for i in range(10):
    plt.subplot(1, 10, i + 1)
    ind = np.random.randint(0, len(test_data))
    # Decode the stored (1, 4, 4, 64) coefficients back to 32x32 pixels with J_i.
    current_jpeg = test_jpeg_data[ind, :, :, :, :].view(4, 4, 64)
    current_recn = torch.einsum('xyksr,xyk->sr', (J_i, current_jpeg)).cpu()
    # int() handles both a 0-d tensor label (old torchvision) and a plain
    # int label (current torchvision), where `.numpy()` would raise.
    display_digit(current_recn.numpy().reshape(32, 32), int(test_data[ind][1]))
As a last step, a PyTorch `Dataset` class is prepared for the inference pipeline. Since the compressed JPEGs are all stored in a single GPU tensor, the dataset is extremely simple.
class MNISTJpegDataset(torch.utils.data.Dataset):
    """Pairs pre-encoded JPEG-domain images with labels taken from another dataset.

    Args:
        dataset: tensor whose first dimension indexes samples, indexed with
            five subscripts (e.g. the (N, 1, 4, 4, 64) encoded test set).
        labels: indexable whose items are (image, label) pairs; only the
            label is returned.
    """

    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels

    def __len__(self):
        return self.dataset.size(0)

    def __getitem__(self, idx):
        # NOTE: indexing `labels` loads (and transforms) its image too,
        # even though only the label is kept.
        _image, label = self.labels[idx]
        sample = self.dataset[idx, :, :, :, :]
        return sample, label
Now it is time to convert the learned parameters to their tensor forms. The first thing we need for this is a set of JPEG compression and decompression tensors, one per resolution. This is because the ResNet model downsamples its feature maps, which changes the number of JPEG blocks. Since only the block structure is changing, this process can be sped up by precomputing the compression part of the transform (DCT, zigzag, quantization), which doesn't change, and then applying the different-size blocking operations to it.
# Resolution-independent part of the codec: DCT, zigzag and (de)quantization
# folded into per-block encode (C) and decode (C_i) tensors.
C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))

# Blocking tensors for the three spatial sizes the ResNet produces.
B_32 = B((32, 32), (8, 8))
B_16 = B((16, 16), (8, 8))
B_8 = B((8, 8), (8, 8))


def _make_codec(blocking):
    """Combine a blocking tensor with C / C_i into an (encoder, decoder) pair."""
    encoder = torch.einsum('srxyij,ijk->srxyk', (blocking, C))
    decoder = torch.einsum('srxyij,ijk->xyksr', (blocking, C_i))
    return encoder, decoder


# Each J_<size> is an (encoder, decoder) tuple; the decoders are also
# exposed as J_<size>_i.
J_32 = _make_codec(B_32)
J_16 = _make_codec(B_16)
J_8 = _make_codec(B_8)
J_32_i, J_16_i, J_8_i = J_32[1], J_16[1], J_8[1]
Next we define classes for each layer of the ResNet architecture. These will convert the parameters learned by the spatial model to operate on the JPEG transformed inputs. This gets pretty complicated, but the basic building blocks are identical to the algorithms developed in previous notes. The only notable additions are some alterations to the shape of the inputs and outputs, which are now multichannel batches of JPEG transformed images at each step.
The convolution layer is taken almost directly from the "Tensor Methods" notes. In addition to the learned weights that will be converted, the input image shape is provided and JPEG encoding and decoding tensors are given. One notable addition is that the exploded convolution has extra channel dimensions so that all the output convolutions are computed at the same time. For example, in the first layer there is 1 input channel and 16 output channels. This is accomplished with a single $16 \times 1 \times 32 \times 32 \times 32 \times 32$ tensor, where the input and output image sizes in $H \times W$ are both $32 \times 32$. In a downsampling layer, such as the first convolution of the second residual block, the output size would instead be $16 \times 16$. After exploding, the convolution is combined with the provided JPEG encoding and decoding tensors (which could be of different sizes since the layer might be downsampling) to give the final tensor. Note that we are treating each channel as a JPEG compressed "image". Finally, the `forward` function applies the tensor operator to a JPEG compressed input. The main difference here is the addition of channels and batches as indices $m, n, t$, where $t$ is the batch index, $n$ is the input channel index and $m$ is the output channel index.
class JpegConv2d(torch.nn.modules.Module):
    """Apply a trained spatial Conv2d directly to JPEG-domain activations.

    Constructor args:
        conv_spatial: trained nn.Conv2d; its weight (shared, not copied),
            stride and padding are reused.
        J: (encoder, decoder) pair. J[1] decodes the input resolution and
            J[0] encodes the output resolution.

    forward(input): `input` is indexed 'tnxyk' (batch, in-channel, block x,
    block y, coefficient). The result is indexed 'tmabc' (batch, out-channel,
    output block x, output block y, coefficient).
    """
    def __init__(self, conv_spatial, J):
        super(JpegConv2d, self).__init__()
        self.stride = conv_spatial.stride
        self.weight = conv_spatial.weight
        self.padding = conv_spatial.padding
        self.J = J
        # Decoder reshaped to (x*y*k, 1, s, r): one single-channel image per
        # JPEG coefficient, so one conv2d call convolves every decoded basis image.
        self.J_batched = self.J[1].contiguous().view(np.prod(self.J[1].shape[0:3]), 1, *self.J[1].shape[3:5])
        # Operand shapes for opt_einsum.
        # NOTE(review): the leading 0 appears to stand in for the unknown batch size t — confirm.
        input_shape = [0, self.weight.shape[1], *self.J[1].shape[0:3]]
        # (out_ch, in_ch, x, y, k, s_out, r_out): output pixels come from the encoder's first two dims.
        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J[1].shape[0:3], *self.J[0].shape[0:2]]
        # out[t,m,a,b,c] = sum op[m,n,x,y,k,s,r] * J[0][s,r,a,b,c] * input[t,n,x,y,k];
        # the encoder (operand 1) is registered as a constant.
        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J[0], input_shape, constants=[1], optimize='optimal')
        self.apply_conv.evaluate_constants(backend='torch')
    def forward(self, input):
        # Unused placeholder; overwritten by the conv2d result below.
        jpeg_op = []
        out_channels = self.weight.shape[0]
        in_channels = self.weight.shape[1]
        # Convolve every decoded basis image with every (out, in) kernel pair.
        # NOTE(review): the operator depends only on the weights, yet it is
        # rebuilt on every call; it could be cached for inference.
        jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
        # (x*y*k, out*in, s', r') -> (out*in, x*y*k, s', r')
        jpeg_op = jpeg_op.permute(1, 0, 2, 3)
        # Split into (out, in, x, y, k, s', r'). The output size is taken as
        # input size // stride.
        # NOTE(review): this holds for the 3x3/pad-1 and 1x1/pad-0 convs used
        # here with even sizes; other kernel/padding combos would need a different formula.
        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J[1].shape[0:3], *(np.array(self.J[1].shape[3:5]) // self.stride))
        return self.apply_conv(jpeg_op, input, backend='torch')
Batch normalization is also quite similar to the notes. Since this layer is for inference only, the $\gamma$ and $\beta$ parameters are folded together with the running mean and variance. They are also reshaped so they can be applied to the input, which is now a batched multi-channel input with several JPEG blocks; automatic broadcasting cannot handle this case without help. One important difference is in the application of $\beta$. Remember that when applying $\beta$ to DCT coefficients, we needed to multiply by 8 beforehand. This is not necessary after the full JPEG transform because the DC quantization coefficient is already 8, so the value stored at the DC coefficient is exactly the block mean. To verify this, look at the `S()` or `S_i()` functions presented in the "GPU JPEG Codec" section, and note that the first coefficient in the quantization matrix `q` is 8.
class JpegBatchNorm(torch.nn.modules.Module):
    """Eval-mode BatchNorm2d applied to JPEG-domain activations.

    Built from a trained ``nn.BatchNorm2d`` and its running statistics.
    Every coefficient is multiplied by the per-channel scale
    gamma / sqrt(var + eps). The per-channel shift
    beta - gamma * mean / sqrt(var + eps) is added only to coefficient 0,
    the DC term. With q[0] = 8 in this codec, the DC term stores the block
    mean, so shifting it shifts every decoded pixel of the block.
    """

    def __init__(self, bn):
        super(JpegBatchNorm, self).__init__()
        self.mean = bn.running_mean
        self.var = bn.running_var
        self.gamma = bn.weight
        self.beta = bn.bias
        # Include bn.eps exactly as nn.BatchNorm2d does in eval mode. Without
        # it, channels with small running variance diverge from the spatial model.
        std = torch.sqrt(self.var + bn.eps)
        # Shapes broadcast against a 5-D (batch, channel, x, y, k) input and
        # its 4-D DC slice respectively.
        self.gamma_final = (self.gamma / std).view(1, self.gamma.shape[0], 1, 1, 1)
        self.beta_final = (self.beta - (self.gamma * self.mean) / std).view(1, self.beta.shape[0], 1, 1)

    def forward(self, input):
        """Scale all coefficients, then shift the DC coefficient; `input` is not modified."""
        out = input * self.gamma_final
        out[:, :, :, :, 0] = out[:, :, :, :, 0] + self.beta_final
        return out
The multilinear ReLU approximation from the ReLU notes is implemented here with some additions to account for the other steps of the JPEG transform, since those notes were developed around only the DCT. Instead of a DCT approximation that keeps only the $n$ lowest frequencies, we combine that approximation with the inverse zigzag and quantization steps to give an approximate decompression tensor. Note that this preserves the block structure of the image: for example, a $32 \times 32$ original image will be $4 \times 4 \times 8 \times 8$ after decompression with this tensor, so there is no need to undo the block structure for the purposes of this operation. Also note that the harmonic mixing tensor is combined with both forward and reverse zigzag and quantization so that it too can operate directly on JPEG transformed data.
class JpegRelu(torch.nn.modules.Module):
    """Approximate ReLU applied directly to JPEG-domain activations.

    A sign mask is computed from a band-limited decode that keeps only
    frequencies with alpha + beta <= n_freqs. The full-precision signal is
    then decoded, multiplied by that mask per pixel, and re-encoded, all in
    one contraction. Input and output are indexed (t, m, x, y, 64) (batch,
    channel, block x, block y, coefficient).
    """
    def __init__(self, n_freqs):
        super(JpegRelu, self).__init__()
        # Band-limited decoder: coefficient k -> approximate pixel (i, j) within a block.
        self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
        # Hm[i,j,k,l] = (exact decode of coefficient k at pixel (i,j)) * (encode of pixel (i,j) into coefficient l).
        self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
        # Per-block approximate decode; C_n (operand 0) is a constant.
        # NOTE(review): zeros in the shape lists appear to be placeholders for unknown sizes — confirm.
        self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
        self.annm_op.evaluate_constants(backend='torch')
        # out[..., l] = sum_{i,j,k} Hm[i,j,k,l] * x[..., k] * mask[..., i, j], i.e. encode(mask * decode(x)).
        self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
        self.hsm_op.evaluate_constants(backend='torch')
    def annm(self, x):
        """Return a (t, m, x, y, 8, 8) mask: 1 where the band-limited decode is >= 0, else 0."""
        appx_im = self.annm_op(x, backend='torch')
        mask = torch.zeros_like(appx_im)
        mask[appx_im >= 0] = 1
        return mask
    def half_spatial_mask(self, x, m):
        """Decode x exactly, multiply by pixel mask m, and re-encode to JPEG coefficients."""
        return self.hsm_op(x, m, backend='torch')
    def forward(self, input):
        annm = self.annm(input)
        out_comp = self.half_spatial_mask(input, annm)
        return out_comp
Global average pooling, an essential part of the ResNet architecture, has not been discussed in any notes so far. Average pooling in general is quite simple to implement in the JPEG transform domain since it is a linear operation; however, global average pooling can be done far more efficiently than in the general case. Recall that, as shown in the "Batch Normalization" section of these notes, the first element of each encoded block is exactly the mean of that block. Since all blocks are the same size, the mean of the image is equal to the mean of the per-block means. Therefore, we need only extract the first element of each block in the final result of the network and average it per channel to get the global average pooling result. This is a massive optimization over the spatial domain algorithm, and it is quite simple to implement. It also avoids doing any kind of decompression before feeding the result of the JPEG domain network into the fully connected layer. The information that it needs is nicely encapsulated by the JPEG representation.
class JpegAvgPool(torch.nn.modules.Module):
    """Global average pooling on JPEG-domain activations.

    Averages coefficient 0 (the DC term) over all blocks of each channel,
    mapping a (t, m, x, y, k) input to a (t, m) output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        n_channels = input.shape[1]
        n_blocks = input.shape[2] * input.shape[3]
        dc_per_block = input[:, :, :, :, 0].view(-1, n_channels, n_blocks)
        return torch.mean(dc_per_block, dim=2)
Next, the residual block structure and final network are constructed from these parts; this mirrors the structure of the spatial domain model that was trained earlier. By default we keep 10 spatial frequencies for the ReLU approximations. This can be tuned, but it will affect the accuracy of inference: the model was trained with an exact ReLU, so the network weights are not designed to handle this approximation.
class JpegResBlock(nn.Module):
    """JPEG-domain counterpart of ResBlock, converted from a trained spatial block.

    Args:
        spatial_resblock: trained ResBlock supplying convolution and BN weights.
        n_freqs: frequency cutoff passed to JpegRelu.
        J_in: (encoder, decoder) pair for the block's input resolution.
        J_out: (encoder, decoder) pair for the block's output resolution.
    """

    def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
        super().__init__()
        # Resolution-changing layers decode at the input size and encode at the output size.
        J_down = (J_out[0], J_in[1])
        self.conv1 = JpegConv2d(spatial_resblock.conv1, J_down)
        self.conv2 = JpegConv2d(spatial_resblock.conv2, J_out)
        self.bn1 = JpegBatchNorm(spatial_resblock.bn1)
        self.bn2 = JpegBatchNorm(spatial_resblock.bn2)
        self.relu = JpegRelu(n_freqs=n_freqs)
        self.downsample = bool(spatial_resblock.downsample)
        if self.downsample:
            self.downsampler = JpegConv2d(spatial_resblock.downsampler, J_down)

    def forward(self, x):
        shortcut = self.downsampler(x) if self.downsample else x
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class JpegResNet(nn.Module):
    """JPEG-domain counterpart of ResNet, converted from a trained spatial model.

    Args:
        spatial_model: trained ResNet whose blocks and fc layer are reused.
        exact: if True, keep every frequency (alpha + beta <= 14 covers all
            8x8 DCT frequencies) in the ReLU approximation; otherwise keep 10.
        n_freqs: optional explicit cutoff for the ReLU approximation. When
            given it overrides `exact`, which allows tuning the
            speed/accuracy trade-off.
    """

    def __init__(self, spatial_model, exact=False, n_freqs=None):
        super(JpegResNet, self).__init__()
        if n_freqs is None:
            n_freqs = 14 if exact else 10
        # Input resolutions 32 -> 32 -> 16 -> 8, matching the spatial model's strides.
        self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32)
        self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16)
        self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8)
        self.averagepooling = JpegAvgPool()
        self.fc = spatial_model.fc

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        out = self.averagepooling(out)
        out = out.view(x.size(0), -1)
        out = self.fc(out)
        return out
Finally, the model is ready to be converted. Because of the way the code was written, this is as simple as providing the spatial domain model as an argument. The returned model performs the same inference on JPEGs to within the ReLu approximation error.
# Convert the trained spatial model using the approximate (non-exact) ReLU.
jpeg_model = JpegResNet(spatial_model=model, exact=False)
To demonstrate the model, we test it using the previously converted MNIST test images. The result should be close to, but not exactly the same as, the spatial domain test accuracy.
# Pair the pre-encoded test images with their labels and evaluate.
jpeg_test_data = MNISTJpegDataset(dataset=test_jpeg_data, labels=test_data)
jpeg_test_loader = torch.utils.data.DataLoader(
    jpeg_test_data, batch_size=128, shuffle=False)
test(jpeg_model, device, jpeg_test_loader)
Next the same test is repeated using exact ReLu (14 spatial frequencies used for approximation). This should give the exact result that the spatial model gave.
# Same evaluation with every frequency kept in the ReLU approximation.
jpeg_model_exact = JpegResNet(spatial_model=model, exact=True)
test(model=jpeg_model_exact, device=device, test_loader=jpeg_test_loader)
import time
def test_timing(model, device, test_loader):
    """Print the wall-clock time of one inference pass of `model` over `test_loader`.

    On a CUDA device the GPU is synchronized before starting and before
    stopping the clock, so only the kernels queued by this pass are
    counted. Other devices need no synchronization.
    """
    on_cuda = torch.device(device).type == 'cuda'
    model.eval()
    if on_cuda:
        # Don't charge previously queued GPU work to this measurement.
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        # Targets are not needed for timing inference, so they are not moved.
        for data, _target in test_loader:
            model(data.to(device))
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    print('Time: {:.02e}s'.format(t1 - t0))
# Spatial model on pixels vs. JPEG-domain model on coefficients.
for timed_model, timed_loader in ((model, test_loader), (jpeg_model, jpeg_test_loader)):
    test_timing(timed_model, device, timed_loader)
def model_size(model):
    """Return the bytes used by `model`'s parameters, assuming 4 bytes (float32) each.

    Only registered parameters are counted. Buffers and plain tensor
    attributes (e.g. precomputed JPEG operators) are not.
    """
    # Iterate model.parameters() once. It already recurses into submodules
    # and de-duplicates shared parameters; summing it per module would count
    # each parameter once per ancestor.
    return sum(p.numel() for p in model.parameters()) * 4
# Parameter memory in bytes: spatial model, then JPEG-domain model.
for sized_model in (model, jpeg_model):
    print(model_size(sized_model))
© 2018 Max Ehrlich