GCN
Large Kernel Matters — Improve Semantic Segmentation by Global Convolutional Network
One of the recent trends in network architecture design is stacking small filters (e.g., 1x1 or 3x3) throughout the entire network, because stacked small filters are more efficient than a large kernel given the same computational complexity. However, in the field of semantic segmentation, where we need to perform dense per-pixel prediction, we find that the large kernel (and effective receptive field) plays an important role when we have to perform the classification and localization tasks simultaneously. Following our design principle, we propose a Global Convolutional Network to address both the classification and localization issues for semantic segmentation. We also suggest a residual-based boundary refinement to further refine the object boundaries. Our approach achieves state-of-the-art performance on two public benchmarks and significantly outperforms previous results: 82.2% (vs 80.2%) on the PASCAL VOC 2012 dataset and 76.9% (vs 71.8%) on the Cityscapes dataset.
Implementations
PyTorch for Semantic Segmentation
This repository contains some models for semantic segmentation and the pipeline of training and testing models, implemented in PyTorch.
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from ..utils import initialize_weights
from .config import res152_path
# many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py
class _GlobalConvModule(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size):
super(_GlobalConvModule, self).__init__()
pad0 = (kernel_size[0] - 1) / 2
pad1 = (kernel_size[1] - 1) / 2
# kernel size had better be odd number so as to avoid alignment error
super(_GlobalConvModule, self).__init__()
self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1),
padding=(pad0, 0))
self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]),
padding=(0, pad1))
self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]),
padding=(0, pad1))
self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1),
padding=(pad0, 0))
def forward(self, x):
x_l = self.conv_l1(x)
x_l = self.conv_l2(x_l)
x_r = self.conv_r1(x)
x_r = self.conv_r2(x_r)
x = x_l + x_r
return x
class _BoundaryRefineModule(nn.Module):
def __init__(self, dim):
super(_BoundaryRefineModule, self).__init__()
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
def forward(self, x):
residual = self.conv1(x)
residual = self.relu(residual)
residual = self.conv2(residual)
out = x + residual
return out
class GCN(nn.Module):
    """Segmentation network combining a ResNet-152 backbone with global
    convolution and boundary refinement modules.

    Feature maps from the four ResNet stages are each passed through a
    ``_GlobalConvModule`` and a ``_BoundaryRefineModule``, then merged from
    the coarsest to the finest resolution by bilinear upsampling and
    addition, with a further boundary refinement after every step.

    Args:
        num_classes: number of output channels of the score map.
        input_size: target spatial size handed to the final upsampling step.
        pretrained: when true, load ResNet-152 weights from ``res152_path``.
    """

    def __init__(self, num_classes, input_size, pretrained=True):
        super(GCN, self).__init__()
        self.input_size = input_size
        resnet = models.resnet152()
        if pretrained:
            resnet.load_state_dict(torch.load(res152_path))

        # Backbone stages; max-pooling is grouped with layer1 so that layer0
        # yields the feature map before pooling.
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1)
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # One global convolution per backbone stage, deepest stage first.
        self.gcm1 = _GlobalConvModule(2048, num_classes, (7, 7))
        self.gcm2 = _GlobalConvModule(1024, num_classes, (7, 7))
        self.gcm3 = _GlobalConvModule(512, num_classes, (7, 7))
        self.gcm4 = _GlobalConvModule(256, num_classes, (7, 7))

        self.brm1 = _BoundaryRefineModule(num_classes)
        self.brm2 = _BoundaryRefineModule(num_classes)
        self.brm3 = _BoundaryRefineModule(num_classes)
        self.brm4 = _BoundaryRefineModule(num_classes)
        self.brm5 = _BoundaryRefineModule(num_classes)
        self.brm6 = _BoundaryRefineModule(num_classes)
        self.brm7 = _BoundaryRefineModule(num_classes)
        self.brm8 = _BoundaryRefineModule(num_classes)
        self.brm9 = _BoundaryRefineModule(num_classes)

        # Only the newly added heads are initialized here; the backbone keeps
        # whatever weights it was constructed or loaded with.
        initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3,
                           self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9)

    @staticmethod
    def _upsample(x, size):
        """Bilinearly resize ``x`` to spatial ``size``.

        ``F.upsample_bilinear`` is deprecated; per the PyTorch documentation
        it is equivalent to ``F.interpolate`` with ``mode='bilinear'`` and
        ``align_corners=True``.
        """
        return F.interpolate(x, size=size, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the score map for ``x``, upsampled to ``self.input_size``."""
        # Trailing numbers are the spatial sizes for a 512-pixel input.
        fm0 = self.layer0(x)  # 256
        fm1 = self.layer1(fm0)  # 128
        fm2 = self.layer2(fm1)  # 64
        fm3 = self.layer3(fm2)  # 32
        fm4 = self.layer4(fm3)  # 16

        gcfm1 = self.brm1(self.gcm1(fm4))  # 16
        gcfm2 = self.brm2(self.gcm2(fm3))  # 32
        gcfm3 = self.brm3(self.gcm3(fm2))  # 64
        gcfm4 = self.brm4(self.gcm4(fm1))  # 128

        # Coarse-to-fine fusion: upsample, add the next finer score map, refine.
        fs1 = self.brm5(self._upsample(gcfm1, fm3.size()[2:]) + gcfm2)  # 32
        fs2 = self.brm6(self._upsample(fs1, fm2.size()[2:]) + gcfm3)  # 64
        fs3 = self.brm7(self._upsample(fs2, fm1.size()[2:]) + gcfm4)  # 128
        fs4 = self.brm8(self._upsample(fs3, fm0.size()[2:]))  # 256
        out = self.brm9(self._upsample(fs4, self.input_size))  # 512

        return out
Github: https://github.com/ZijunDeng/pytorch-semantic-segmentation
