Skip to content

Commit 07bcee8

Browse files
committed
Implements Feature Pyramid Network (FPN), closes #60
1 parent da30c57 commit 07bcee8

6 files changed

Lines changed: 150 additions & 13 deletions

File tree

robosat/fpn.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Feature Pyramid Network (FPN) on top of ResNet. Comes with task-specific heads on top of it.
2+
3+
See:
4+
- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection
5+
- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance
6+
and Semantic Segmentation
7+
8+
"""
9+
10+
import torch
11+
import torch.nn as nn
12+
13+
from torchvision.models import resnet50
14+
15+
16+
class FPN(nn.Module):
    """Feature Pyramid Network (FPN): top-down architecture with lateral connections.

    Can be used as feature extractor for object detection or segmentation.
    """

    def __init__(self, num_filters=256, pretrained=True):
        """Creates an `FPN` instance for feature extraction.

        Args:
          num_filters: the number of filters in each output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        self.resnet = resnet50(pretrained=pretrained)

        # Access resnet directly in forward pass; do not store refs here due to
        # https://github.com/pytorch/pytorch/issues/8392

        # 1x1 lateral convolutions project each backbone stage (C2..C5 with
        # 256/512/1024/2048 channels in ResNet-50) to a common `num_filters`
        # depth so feature maps can be summed in the top-down pathway.
        self.lateral4 = Conv1x1(2048, num_filters)
        self.lateral3 = Conv1x1(1024, num_filters)
        self.lateral2 = Conv1x1(512, num_filters)
        self.lateral1 = Conv1x1(256, num_filters)

        # 3x3 smoothing convolutions reduce the aliasing effect introduced
        # by nearest-neighbor upsampling.
        self.smooth4 = Conv3x3(num_filters, num_filters)
        self.smooth3 = Conv3x3(num_filters, num_filters)
        self.smooth2 = Conv3x3(num_filters, num_filters)
        self.smooth1 = Conv3x3(num_filters, num_filters)

    def forward(self, x):
        """Extracts pyramid feature maps from an image batch.

        Args:
          x: image batch of shape (N, 3, H, W); assumes H and W are divisible
             by 32 so the upsampled maps align with the lateral maps.

        Returns:
          Tuple `(map1, map2, map3, map4)` of feature maps at 1/4, 1/8, 1/16
          and 1/32 of the input resolution, each with `num_filters` channels.
        """

        # Bottom-up pathway, from ResNet

        enc0 = self.resnet.conv1(x)
        enc0 = self.resnet.bn1(enc0)
        enc0 = self.resnet.relu(enc0)
        enc0 = self.resnet.maxpool(enc0)

        enc1 = self.resnet.layer1(enc0)
        enc2 = self.resnet.layer2(enc1)
        enc3 = self.resnet.layer3(enc2)
        enc4 = self.resnet.layer4(enc3)

        # Lateral connections

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)

        # Top-down pathway: upsample the coarser map and add the lateral map.
        # Note: nn.functional.upsample is deprecated; interpolate is the
        # supported equivalent with identical semantics for these arguments.

        map4 = lateral4
        map3 = lateral3 + nn.functional.interpolate(map4, scale_factor=2, mode="nearest")
        map2 = lateral2 + nn.functional.interpolate(map3, scale_factor=2, mode="nearest")
        map1 = lateral1 + nn.functional.interpolate(map2, scale_factor=2, mode="nearest")

        # Reduce aliasing effect of upsampling

        map4 = self.smooth4(map4)
        map3 = self.smooth3(map3)
        map2 = self.smooth2(map2)
        map1 = self.smooth1(map1)

        return map1, map2, map3, map4
81+
82+
83+
class FPNSegmentation(nn.Module):
    """Semantic segmentation model on top of a Feature Pyramid Network (FPN)."""

    def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True):
        """Creates an `FPNSegmentation` instance for feature extraction.

        Args:
          num_classes: number of classes to predict
          num_filters: the number of filters in each segmentation head pyramid level
          num_filters_fpn: the number of filters in each FPN output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.

        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN: two 3x3 convolutions per
        # pyramid level, each mapping to `num_filters` channels.

        self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))

        # Fuses the four concatenated head outputs into per-class logits.
        self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Predicts per-pixel class logits for an image batch.

        Args:
          x: image batch of shape (N, 3, H, W); assumes H and W are divisible
             by 32 (required by the FPN backbone).

        Returns:
          Logits tensor of shape (N, num_classes, H, W).
        """

        map1, map2, map3, map4 = self.fpn(x)

        # Bring all head outputs to 1/4 resolution so they can be concatenated.
        # Note: nn.functional.upsample is deprecated; interpolate is the
        # supported equivalent with identical semantics for these arguments.

        map4 = nn.functional.interpolate(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.interpolate(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.interpolate(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = self.head1(map1)

        final = self.final(torch.cat([map4, map3, map2, map1], dim=1))

        # Final bilinear upsampling restores the full input resolution.
        return nn.functional.interpolate(final, scale_factor=4, mode="bilinear", align_corners=False)
124+
125+
126+
class Conv1x1(nn.Module):
    """Bias-free pointwise (1x1) convolution for channel projection."""

    def __init__(self, num_in, num_out):
        """Creates a 1x1 convolution mapping `num_in` to `num_out` channels.

        Args:
          num_in: number of input channels
          num_out: number of output channels
        """
        super().__init__()
        self.block = nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)

    def forward(self, x):
        """Projects `x` of shape (N, num_in, H, W) to (N, num_out, H, W)."""
        out = self.block(x)
        return out
133+
134+
135+
class Conv3x3(nn.Module):
    """Bias-free 3x3 convolution with padding that preserves spatial size."""

    def __init__(self, num_in, num_out):
        """Creates a 3x3, padding-1 convolution mapping `num_in` to `num_out` channels.

        Args:
          num_in: number of input channels
          num_out: number of output channels
        """
        super().__init__()
        self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Convolves `x` of shape (N, num_in, H, W) to (N, num_out, H, W)."""
        out = self.block(x)
        return out

robosat/tools/export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.autograd
66

77
from robosat.config import load_config
8-
from robosat.unet import UNet
8+
from robosat.fpn import FPNSegmentation
99

1010

1111
def add_parser(subparser):
@@ -25,7 +25,7 @@ def main(args):
2525
dataset = load_config(args.dataset)
2626

2727
num_classes = len(dataset["common"]["classes"])
28-
net = UNet(num_classes)
28+
net = FPNSegmentation(num_classes)
2929

3030
def map_location(storage, _):
3131
return storage.cpu()

robosat/tools/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from PIL import Image
1515

1616
from robosat.datasets import BufferedSlippyMapDirectory
17-
from robosat.unet import UNet
17+
from robosat.fpn import FPNSegmentation
1818
from robosat.config import load_config
1919
from robosat.colors import continuous_palette_for_color
2020
from robosat.transforms import ConvertImageMode, ImageToTensor
@@ -59,7 +59,7 @@ def map_location(storage, _):
5959
# https://github.com/pytorch/pytorch/issues/7178
6060
chkpt = torch.load(args.checkpoint, map_location=map_location)
6161

62-
net = UNet(num_classes).to(device)
62+
net = FPNSegmentation(num_classes).to(device)
6363
net = nn.DataParallel(net)
6464

6565
if cuda:

robosat/tools/serve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from flask import Flask, send_file, render_template, abort
1717

1818
from robosat.tiles import fetch_image
19-
from robosat.unet import UNet
19+
from robosat.fpn import FPNSegmentation
2020
from robosat.config import load_config
2121
from robosat.colors import make_palette
2222
from robosat.transforms import ConvertImageMode, ImageToTensor
@@ -180,7 +180,7 @@ def map_location(storage, _):
180180

181181
num_classes = len(self.dataset["common"]["classes"])
182182

183-
net = UNet(num_classes).to(self.device)
183+
net = FPNSegmentation(num_classes).to(self.device)
184184
net = nn.DataParallel(net)
185185

186186
if self.cuda:

robosat/tools/train.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from robosat.datasets import SlippyMapTilesConcatenation
2727
from robosat.metrics import MeanIoU
2828
from robosat.losses import CrossEntropyLoss2d
29-
from robosat.unet import UNet
29+
from robosat.fpn import FPNSegmentation
3030
from robosat.utils import plot
3131
from robosat.config import load_config
3232

@@ -51,13 +51,10 @@ def main(args):
5151
if model["common"]["cuda"] and not torch.cuda.is_available():
5252
sys.exit("Error: CUDA requested but not available")
5353

54-
# if args.batch_size < 2:
55-
# sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')
56-
5754
os.makedirs(model["common"]["checkpoint"], exist_ok=True)
5855

5956
num_classes = len(dataset["common"]["classes"])
60-
net = UNet(num_classes).to(device)
57+
net = FPNSegmentation(num_classes).to(device)
6158

6259
if model["common"]["cuda"]:
6360
torch.backends.cudnn.benchmark = True
@@ -66,9 +63,7 @@ def main(args):
6663
optimizer = Adam(net.parameters(), lr=model["opt"]["lr"], weight_decay=model["opt"]["decay"])
6764

6865
weight = torch.Tensor(dataset["weights"]["values"])
69-
7066
criterion = CrossEntropyLoss2d(weight=weight).to(device)
71-
# criterion = FocalLoss2d(weight=weight).to(device)
7267

7368
train_loader, val_loader = get_dataset_loaders(model, dataset)
7469

robosat/unet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
8484
8585
Args:
8686
num_classes: number of classes to predict.
87+
num_filters: the number of filters for the decoder block
8788
pretrained: use ImageNet pre-trained backbone feature extractor
8889
"""
8990

0 commit comments

Comments
 (0)