Continuing from the discussion of inference, we now show how to train JPEG models, starting with the boilerplate code
!pip install torch opt_einsum tabulate torchvision
import torch
from torchvision import datasets, transforms
import opt_einsum as oe
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate
and the standard definition of the tensors we will use for compression and decompression
def A(alpha):
if alpha == 0:
return 1.0 / np.sqrt(2)
else:
return 1
def D():
D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)
for i in range(8):
for j in range(8):
for alpha in range(8):
for beta in range(8):
scale_a = A(alpha)
scale_b = A(beta)
coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y
return D_t
def D_n(n_freqs):
D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)
for i in range(8):
for j in range(8):
for alpha in range(8):
for beta in range(8):
if alpha + beta <= n_freqs:
scale_a = A(alpha)
scale_b = A(beta)
coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y
return D_t
def Z():
z = np.array([[ 0, 1, 5, 6, 14, 15, 27, 28],
[ 2, 4, 7, 13, 16, 26, 29, 42],
[ 3, 8, 12, 17, 25, 30, 41, 43],
[ 9, 11, 18, 24, 31, 40, 44, 53],
[10, 19, 23, 32, 39, 45, 52, 54],
[20, 22, 33, 38, 46, 51, 55, 60],
[21, 34, 37, 47, 50, 56, 59, 61],
[35, 36, 48, 49, 57, 58, 62, 63]], dtype=float)
Z_t = torch.zeros([8, 8, 64], dtype=torch.float)
for alpha in range(8):
for beta in range(8):
for gamma in range(64):
if z[alpha, beta] == gamma:
Z_t[alpha, beta, gamma] = 1
return Z_t
def S():
q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
S_t = torch.zeros([64, 64], dtype=torch.float)
for gamma in range(64):
for k in range(64):
if gamma == k:
S_t[gamma, k] = 1.0 / q[k]
return S_t
def S_i():
q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
S_t = torch.zeros([64, 64], dtype=torch.float)
for gamma in range(64):
for k in range(64):
if gamma == k:
S_t[gamma, k] = q[k]
return S_t
def B(shape, block_size):
blocks_shape = (shape[0] // block_size[0], shape[1] // block_size[1])
B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], block_size[0], block_size[1]], dtype=torch.float)
for s_x in range(shape[0]):
for s_y in range(shape[1]):
for x in range(blocks_shape[0]):
for y in range(blocks_shape[1]):
for i in range(block_size[0]):
for j in range(block_size[1]):
if x * block_size[0] + i == s_x and y * block_size[1] + j == s_y:
B_t[s_x, s_y, x, y, i, j] = 1.0
return B_t
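As a quick sanity check of these definitions, the DCT tensor is orthonormal, so the same tensor performs both the forward and inverse transform and a random block should survive a roundtrip (a minimal sketch, with a loose tolerance for float32 error):
D_t = D()
block = torch.rand(8, 8)
coeffs = torch.einsum('ijab,ij->ab', (D_t, block))  # forward DCT
recon = torch.einsum('ijab,ab->ij', (D_t, coeffs))  # inverse DCT reuses the same tensor
print(torch.allclose(block, recon, atol=1e-4))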
The layers from the inference notes are copied here with the exception of batch normalization. For training, the gradients for all layers are computed using autograd, so there is no need to program a backward pass explicitly. Batch normalization, however, has a different formulation at training time that we need to take into account.
class AvgPool(torch.nn.modules.Module):
def __init__(self):
super(AvgPool, self).__init__()
    def forward(self, input):
        # Coefficient 0 of each block is its stored DC term which, since q[0] = 8, equals the
        # block mean, so averaging it over all blocks reproduces spatial global average pooling.
        result = torch.mean(input[:, :, :, :, 0].view(-1, input.shape[1], input.shape[2]*input.shape[3]), 2)
        return result
class Conv2d(torch.nn.modules.Module):
def __init__(self, conv_spatial, J):
super(Conv2d, self).__init__()
self.stride = conv_spatial.stride
self.padding = conv_spatial.padding
self.weight = torch.nn.Parameter(conv_spatial.weight.clone())
self.register_buffer('J', J[0])
self.register_buffer('J_i', J[1])
J_batched = self.J_i.contiguous().view(np.prod(self.J_i.shape[0:3]), 1, *self.J_i.shape[3:5])
self.register_buffer('J_batched', J_batched)
self.make_apply_op()
self.jpeg_op = None
def make_apply_op(self):
input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]
self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J, input_shape, constants=[1], optimize='optimal')
self.apply_conv.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(Conv2d, self)._apply(fn)
s.make_apply_op()
return s
def explode(self):
out_channels = self.weight.shape[0]
in_channels = self.weight.shape[1]
jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
jpeg_op = jpeg_op.permute(1, 0, 2, 3)
jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J_i.shape[0:3], *(np.array(self.J_i.shape[3:5]) // self.stride))
return jpeg_op
def explode_pre(self):
self.jpeg_op = self.explode()
def forward(self, input):
if self.jpeg_op is not None:
jpeg_op = self.jpeg_op
else:
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
class ASMReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(ASMReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
self.register_buffer('Hm', Hm)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
self.hsm_op.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(ASMReLU, self)._apply(fn)
s.make_masking_ops()
return s
def annm(self, x):
appx_im = self.annm_op(x, backend='torch')
mask = torch.zeros_like(appx_im)
mask[appx_im >= 0] = 1
return mask
def half_spatial_mask(self, x, m):
return self.hsm_op(x, m, backend='torch')
def forward(self, input):
annm = self.annm(input)
out_comp = self.half_spatial_mask(input, annm)
return out_comp
The previous definition of batch normalization was suitable for inference but not for training: it used precomputed $\gamma$ and $\beta$ parameters as well as the running mean and variance. Computing these ourselves is straightforward using the batch mean and variance derived in the batch normalization notes.
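As a reminder of the identities this relies on (the DCT above is orthonormal and the DC quantization entry is $q_0 = 8$), the stored DC coefficient of each block is exactly the block mean $\bar{x}$

$$\tilde{X}_{00} = \frac{X_{00}}{q_0} = \frac{1}{8} \cdot \frac{1}{4} A(0) A(0) \sum_{i,j} x_{ij} = \frac{1}{64} \sum_{i,j} x_{ij} = \bar{x}$$

and, by Parseval's theorem, the dequantized AC coefficients give the block variance

$$\frac{1}{64} \sum_{(\alpha, \beta) \neq (0, 0)} X_{\alpha\beta}^2 = \frac{1}{64} \sum_{i,j} x_{ij}^2 - \bar{x}^2$$

Averaging the per-block second moments over the batch and subtracting the squared batch mean then gives the batch variance, which is what the code below computes.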
class BatchNorm(torch.nn.modules.Module):
def __init__(self, bn):
super(BatchNorm, self).__init__()
self.register_buffer('running_mean', bn.running_mean.clone())
self.register_buffer('running_var', bn.running_var.clone())
self.register_buffer('S_i', S_i())
self.momentum = bn.momentum
self.eps = bn.eps
self.gamma = torch.nn.Parameter(bn.weight.clone())
self.beta = torch.nn.Parameter(bn.bias.clone())
def forward(self, input):
if self.training:
channels = input.shape[1]
input_channelwise = input.permute(1, 0, 2, 3, 4).clone()
# Compute the batch mean for each channel
block_means = input_channelwise[:, :, :, :, 0].contiguous().view(channels, -1)
batch_mean = torch.mean(block_means, 1)
# Compute the batch variance for each channel
input_dequantized = torch.einsum('mtxyk,gk->mtxyg', [input_channelwise, self.S_i])
input_dequantized[:, :, :, :, 0] = 0 # zero mean
block_variances = torch.mean(input_dequantized**2, 4).view(channels, -1)
batch_var = torch.mean(block_variances + block_means**2, 1) - batch_mean**2
            # Apply Bessel's correction so the running variance matches PyTorch's unbiased estimate
            # (the effect on the normalization itself is negligible)
            bessel_correction_factor = input.shape[0] * input.shape[2] * input.shape[3] * 64
            bessel_correction_factor = bessel_correction_factor / (bessel_correction_factor - 1)
            batch_var *= bessel_correction_factor
            # Update running stats (detached so the running averages stay out of the autograd graph)
            self.running_mean = self.running_mean * (1 - self.momentum) + batch_mean.detach() * self.momentum
            self.running_var = self.running_var * (1 - self.momentum) + batch_var.detach() * self.momentum
# Apply parameters
invstd = 1. / torch.sqrt(batch_var + self.eps).view(1, -1, 1, 1, 1)
mean = batch_mean.view(1, -1, 1, 1)
else:
invstd = 1. / torch.sqrt(self.running_var + self.eps).view(1, -1, 1, 1, 1)
mean = self.running_mean.view(1, -1, 1, 1)
g = self.gamma.view(1, -1, 1, 1, 1)
b = self.beta.view(1, -1, 1, 1)
input[:, :, :, :, 0] = input[:, :, :, :, 0] - mean
input = input * invstd
input = input * g
input[:, :, :, :, 0] = input[:, :, :, :, 0] + b
return input
The first test we perform is a simple sanity check. We generate batches of random images and JPEG compress them. Then we run batch normalization over them in training mode and show that the means, variances, and learned parameters are the same for the spatial and JPEG models.
First we generate the batches
def show_image(m, ax=None):
c_img = np.zeros((m.shape[0], m.shape[1], 3))
max_gr0 = np.max(m[m > 0])
if len(m[m < 0]) > 0:
min_le0 = np.min(m[m < 0])
else:
min_le0 = 0
c_img[m < 0] = np.array([[0.0, 1.0, 0.0]]) * (m[m < 0] / min_le0).reshape(-1, 1)
c_img[m == 0] = np.array([0.0, 0.0, 1.0])
c_img[m > 0] = np.array([[1.0, 0.0, 0.0]]) * (m[m > 0] / max_gr0).reshape(-1, 1)
plt.grid(False)
if ax is None:
return plt.imshow(c_img)
else:
ax.grid(False)
return ax.imshow(c_img)
def generate_batch(n, c):
return torch.Tensor(np.random.randint(0, 255, size=(n, c, 8, 8)).astype(float))
def show_batch(batch):
plt.figure(figsize=(15, 10))
for b in range(batch.shape[0]):
for c in range(batch.shape[1]):
plt.subplot(batch.shape[0], batch.shape[1], b * batch.shape[1] + c + 1)
show_image(batch[b, c, :, :])
spatial_batch = [generate_batch(128, 16) for _ in range(300)]
Then we create the JPEG compressed batches
C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))
def codec(image_size, block_size=(8, 8)):
B_i = B(image_size, block_size)
J = torch.einsum('srxyij,ijk->srxyk', (B_i, C))
J_i = torch.einsum('srxyij,ijk->xyksr', (B_i, C_i))
return J, J_i
def encode(batch, block_size=(8, 8), device=None):
J, _ = codec(batch.shape[2:], block_size)
if device is not None:
batch = batch.to(device)
J = J.to(device)
jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
return jpeg_batch
def decode(batch, device=None):
block_size = int(np.sqrt(batch.shape[4]))
image_size = (int(batch.shape[2] * block_size), int(batch.shape[3] * block_size))
_, J_i = codec(image_size, (block_size, block_size))
if device is not None:
batch = batch.to(device)
J_i = J_i.to(device)
spatial_batch = torch.einsum('xyksr,ncxyk->ncsr', (J_i, batch))
return spatial_batch
jpeg_batch = [encode(sb) for sb in spatial_batch]
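Since the quantization is exactly inverted by S_i (there is no rounding step), a compress/decompress roundtrip should reproduce the input up to floating point error. A quick check:
roundtrip = decode(encode(spatial_batch[0]))
print((spatial_batch[0] - roundtrip).abs().max())  # should be tiny compared to the 0-255 pixel range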
We initialize a batch normalization layer for both the spatial and JPEG models
spatial_bn = torch.nn.BatchNorm2d(spatial_batch[0].shape[1])
jpeg_bn = BatchNorm(spatial_bn)
Then we run both layers over the batches in training mode
spatial_bn.train()
jpeg_bn.train()
spatial_res = [spatial_bn(sb) for sb in spatial_batch[:-2]]
jpeg_res = [jpeg_bn(jb) for jb in jpeg_batch[:-2]]
Finally we print the means, variances, and parameters and show that they match.
print(spatial_bn.running_mean)
print(jpeg_bn.running_mean)
print(spatial_bn.running_var)
print(jpeg_bn.running_var)
print(spatial_bn.weight)
print(jpeg_bn.gamma)
print(spatial_bn.bias)
print(jpeg_bn.beta)
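Rather than comparing the printouts by eye, the running statistics can also be checked numerically; the relative tolerance here is a rough choice to absorb accumulated floating point error:
print(torch.allclose(spatial_bn.running_mean, jpeg_bn.running_mean, rtol=1e-3))
print(torch.allclose(spatial_bn.running_var, jpeg_bn.running_var, rtol=1e-3))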
We conclude by defining the same ResNet-based network that was used in the inference section. We then train this model on the JPEG compressed MNIST dataset and show that it converges and that its accuracy matches that of the spatial domain model.
First the block and network definitions for JPEG
class JpegResBlock(torch.nn.Module):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out, relu_layer=ASMReLU):
super(JpegResBlock, self).__init__()
J_down = (J_out[0], J_in[1])
self.conv1 = Conv2d(spatial_resblock.conv1, J_down)
self.conv2 = Conv2d(spatial_resblock.conv2, J_out)
self.bn1 = BatchNorm(spatial_resblock.bn1)
self.bn2 = BatchNorm(spatial_resblock.bn2)
self.relu = relu_layer(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None:
self.downsampler = Conv2d(spatial_resblock.downsampler, J_down)
self.bn_ds = BatchNorm(spatial_resblock.bn_ds)
else:
self.downsampler = None
def explode_all(self):
self.conv1.explode_pre()
self.conv2.explode_pre()
if self.downsampler is not None:
self.downsampler.explode_pre()
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
out += residual
out = self.relu(out)
return out
class JpegResNet(torch.nn.Module):
def __init__(self, spatial_model, n_freqs, relu_layer=ASMReLU):
super(JpegResNet, self).__init__()
J_32 = codec((32, 32))
J_16 = codec((16, 16))
J_8 = codec((8, 8))
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32, relu_layer=relu_layer)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16, relu_layer=relu_layer)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8, relu_layer=relu_layer)
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
def explode_all(self):
self.block1.explode_all()
self.block2.explode_all()
self.block3.explode_all()
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
out = self.block3(out)
out = self.averagepooling(out)
out = out.view(x.size(0), -1)
out = self.fc(out)
return out
and the corresponding definitions for the spatial model
class SpatialResBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, downsample=True):
super(SpatialResBlock, self).__init__()
self.downsample = downsample
stride = 2 if downsample else 1
self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = torch.nn.BatchNorm2d(out_channels)
self.bn2 = torch.nn.BatchNorm2d(out_channels)
self.relu = torch.nn.ReLU(inplace=True)
if downsample or in_channels != out_channels:
stride = 2 if downsample else 1
self.downsampler = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
self.bn_ds = torch.nn.BatchNorm2d(out_channels)
else:
self.downsampler = None
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
out += residual
out = self.relu(out)
return out
class SpatialResNet(torch.nn.Module):
def __init__(self, channels, classes):
super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=channels, out_channels=16, downsample=False)
self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = torch.nn.AvgPool2d(8, stride=1)
self.fc = torch.nn.Linear(64, classes)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
out = self.block3(out)
out = self.averagepooling(out)
out = out.view(x.size(0), -1)
out = self.fc(out)
return out
Utilities for training and testing
def train(model, device, train_loader, optimizer, epoch, doencode):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if doencode:
data, target = encode(data, device=device), target.to(device)
else:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader, doencode):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
if doencode:
data, target = encode(data, device=device), target.to(device)
else:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
return correct / len(test_loader.dataset)
Loading the MNIST dataset. The images are padded from 28x28 to 32x32 so that they divide evenly into 8x8 blocks.
train_data = datasets.MNIST('MNIST-data', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(2),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_data = datasets.MNIST('MNIST-data', train=False, transform=transforms.Compose([
transforms.Pad(2),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False)
and converting it to JPEG
class MNISTJpegDataset(torch.utils.data.Dataset):
def __init__(self, dataset, labels):
self.dataset = dataset
self.labels = labels
def __len__(self):
return self.dataset.size()[0]
def __getitem__(self, idx):
_, l = self.labels[idx]
return self.dataset[idx, :, :, :, :], l
device = torch.device('cuda')
J = torch.einsum('srxyij,ijab,abg,gk->srxyk', (B((32, 32), (8, 8)), D(), Z(), S())).to(device)
def jpeg_encode(batch):
batch = batch.to(device)
jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
return jpeg_batch
test_jpegconvert_loader = torch.utils.data.DataLoader(test_data, batch_size=10000, shuffle=False)
test_jpeg_data = []
for data, _ in test_jpegconvert_loader:
jpeg_data = jpeg_encode(data)
test_jpeg_data.append(jpeg_data)
test_jpeg_data = torch.cat(test_jpeg_data)
jpeg_test_data = MNISTJpegDataset(test_jpeg_data, test_data)
jpeg_test_loader = torch.utils.data.DataLoader(jpeg_test_data, batch_size=128, shuffle=False)
train_jpegconvert_loader = torch.utils.data.DataLoader(train_data, batch_size=60000, shuffle=False)
train_jpeg_data = []
for data, _ in train_jpegconvert_loader:
jpeg_data = jpeg_encode(data)
train_jpeg_data.append(jpeg_data)
train_jpeg_data = torch.cat(train_jpeg_data)
jpeg_train_data = MNISTJpegDataset(train_jpeg_data, train_data)
jpeg_train_loader = torch.utils.data.DataLoader(jpeg_train_data, batch_size=128, shuffle=False)
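Each converted item is a 4x4 grid of blocks with 64 coefficients per block (the padded 32x32 images split into 8x8 blocks), paired with its original label. A quick check:
x, y = jpeg_train_data[0]
print(x.shape, y)  # expect torch.Size([1, 4, 4, 64]) and the first MNIST label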
And finally, training and testing. We train for only one epoch, as that is all that is needed to get reasonable performance on MNIST.
model = SpatialResNet(1, 10).to(device)
jpeg_model = JpegResNet(model, n_freqs=14).to(device)
jpeg_optimizer = torch.optim.Adam(jpeg_model.parameters())
spatial_optimizer = torch.optim.Adam(model.parameters())
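Since the JPEG layers clone the spatial model's parameters, both models start from the same initialization; a quick (optional) check before training:
print(torch.equal(model.block1.conv1.weight, jpeg_model.block1.conv1.weight))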
for epoch in range(1):
train(jpeg_model, device, jpeg_train_loader, jpeg_optimizer, epoch, doencode=False)
test(jpeg_model, device, jpeg_test_loader, doencode=False)
train(model, device, train_loader, spatial_optimizer, epoch, doencode=False)
test(model, device, test_loader, doencode=False)
© 2018 Max Ehrlich