Skip to content

Commit 1e9936c

Browse files
committed
get same model output with keras
1 parent 56d23d1 commit 1e9936c

File tree

4 files changed

+287
-86
lines changed

4 files changed

+287
-86
lines changed

cspnet.py

+157-82
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,82 @@
11
import torch
22
import math
33
import torch.nn as nn
4-
import h5py
5-
import numpy as np
6-
#from resnet import *
74
from l2norm import L2Norm
5+
from samepad import SamePad2d
6+
7+
class IdentityBlock(nn.Module):
8+
expansion = 4
9+
def __init__(self, inchannels, filters, dila=1):
10+
super(IdentityBlock, self).__init__()
11+
self.conv1 = nn.Conv2d(inchannels, filters, kernel_size=1)
12+
self.bn1 = nn.BatchNorm2d(filters, eps=1e-03, momentum=0.01)
13+
self.samepad = SamePad2d(3, 1, dilation=dila)
14+
self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, dilation=dila)
15+
self.bn2 = nn.BatchNorm2d(filters, eps=1e-03, momentum=0.01)
16+
self.conv3 = nn.Conv2d(filters, filters * self.expansion, kernel_size=1)
17+
self.bn3 = nn.BatchNorm2d(filters * self.expansion, eps=1e-03, momentum=0.01)
18+
self.relu = nn.ReLU(inplace=True)
19+
20+
def forward(self, x):
21+
out = self.conv1(x)
22+
out = self.bn1(out)
23+
out = self.relu(out)
24+
25+
print('a shape --- ', out.shape)
26+
out = self.samepad(out)
27+
print('b shape --- ', out.shape)
28+
out = self.conv2(out)
29+
print('c shape --- ', out.shape)
30+
out = self.bn2(out)
31+
out = self.relu(out)
32+
33+
out = self.conv3(out)
34+
out = self.bn3(out)
35+
36+
out += x
37+
out = self.relu(out)
38+
39+
return out
40+
41+
class ConvBlock(nn.Module):
42+
expansion = 4
43+
def __init__(self, inchannels, filters, s=2, dila=1):
44+
super(ConvBlock, self).__init__()
45+
self.conv1 = nn.Conv2d(inchannels, filters, kernel_size=1, stride=s)
46+
self.bn1 = nn.BatchNorm2d(filters, eps=1e-03, momentum=0.01)
47+
self.samepad = SamePad2d(3, 1, dilation=dila)
48+
self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, dilation=dila)
49+
self.bn2 = nn.BatchNorm2d(filters, eps=1e-03, momentum=0.01)
50+
self.conv3 = nn.Conv2d(filters, filters * self.expansion, kernel_size=1)
51+
self.bn3 = nn.BatchNorm2d(filters * self.expansion, eps=1e-03, momentum=0.01)
52+
self.conv4 = nn.Conv2d(inchannels, filters * self.expansion, kernel_size=1, stride=s)
53+
self.bn4 = nn.BatchNorm2d(filters * self.expansion, eps=1e-03, momentum=0.01)
54+
self.relu = nn.ReLU(inplace=True)
55+
56+
def forward(self, x):
57+
out = self.conv1(x)
58+
out = self.bn1(out)
59+
out = self.relu(out)
60+
61+
print('a shape --- ', out.shape)
62+
out = self.samepad(out)
63+
print('b shape --- ', out.shape)
64+
out = self.conv2(out)
65+
print('c shape --- ', out.shape)
66+
out = self.bn2(out)
67+
out = self.relu(out)
68+
69+
out = self.conv3(out)
70+
out = self.bn3(out)
71+
72+
shortcut = self.conv4(x)
73+
shortcut = self.bn4(shortcut)
74+
print('shortcut shape --- ', shortcut.shape)
75+
76+
out += shortcut
77+
out = self.relu(out)
78+
79+
return out
880

981

1082
class CSPNet_p3p4p5(nn.Module):
@@ -14,13 +86,30 @@ def __init__(self):
1486
#resnet = resnet50(pretrained=True, receptive_keep=True)
1587

1688
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
17-
#self.bn1 = resnet.bn1
18-
#self.relu = resnet.relu
19-
#self.maxpool = resnet.maxpool
20-
#self.layer1 = resnet.layer1
21-
#self.layer2 = resnet.layer2
22-
#self.layer3 = resnet.layer3
23-
#self.layer4 = resnet.layer4
89+
self.bn1 = nn.BatchNorm2d(64, eps=1e-03, momentum=0.01)
90+
self.relu = nn.ReLU(inplace=True)
91+
self.samepad1 = SamePad2d(3, 2)
92+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
93+
94+
self.convblk2a = ConvBlock(64, 64, s=1)
95+
self.identityblk2b = IdentityBlock(256, 64)
96+
self.identityblk2c = IdentityBlock(256, 64)
97+
98+
self.convblk3a = ConvBlock(256, 128)
99+
self.identityblk3b = IdentityBlock(512, 128)
100+
self.identityblk3c = IdentityBlock(512, 128)
101+
self.identityblk3d = IdentityBlock(512, 128)
102+
103+
self.convblk4a = ConvBlock(512, 256)
104+
self.identityblk4b = IdentityBlock(1024, 256)
105+
self.identityblk4c = IdentityBlock(1024, 256)
106+
self.identityblk4d = IdentityBlock(1024, 256)
107+
self.identityblk4e = IdentityBlock(1024, 256)
108+
self.identityblk4f = IdentityBlock(1024, 256)
109+
110+
self.convblk5a = ConvBlock(1024, 512, s=1, dila=2)
111+
self.identityblk5b = IdentityBlock(2048, 512, dila=2)
112+
self.identityblk5c = IdentityBlock(2048, 512, dila=2)
24113

25114
self.p3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
26115
self.p4 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=4, padding=0)
@@ -37,56 +126,70 @@ def __init__(self):
37126
self.p4_l2 = L2Norm(256, 10)
38127
self.p5_l2 = L2Norm(256, 10)
39128

40-
self.feat = nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1, bias=False)
41-
self.feat_bn = nn.BatchNorm2d(256, momentum=0.01)
42-
self.feat_act = nn.ReLU(inplace=True)
129+
self.feat = nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1)
130+
self.feat_bn = nn.BatchNorm2d(256, eps=1e-03, momentum=0.01)
43131

44-
self.pos_conv = nn.Conv2d(256, 1, kernel_size=1)
45-
self.reg_conv = nn.Conv2d(256, 1, kernel_size=1)
46-
self.off_conv = nn.Conv2d(256, 2, kernel_size=1)
132+
self.center_conv = nn.Conv2d(256, 1, kernel_size=1)
133+
self.height_conv = nn.Conv2d(256, 1, kernel_size=1)
134+
self.offset_conv = nn.Conv2d(256, 2, kernel_size=1)
47135

48136
nn.init.xavier_normal_(self.feat.weight)
49-
nn.init.xavier_normal_(self.pos_conv.weight)
50-
nn.init.xavier_normal_(self.reg_conv.weight)
51-
nn.init.xavier_normal_(self.off_conv.weight)
52-
nn.init.constant_(self.pos_conv.bias, -math.log(0.99/0.01))
53-
nn.init.constant_(self.reg_conv.bias, 0)
54-
nn.init.constant_(self.off_conv.bias, 0)
137+
nn.init.xavier_normal_(self.center_conv.weight)
138+
nn.init.xavier_normal_(self.height_conv.weight)
139+
nn.init.xavier_normal_(self.offset_conv.weight)
140+
nn.init.constant_(self.center_conv.bias, -math.log(0.99/0.01))
141+
nn.init.constant_(self.height_conv.bias, 0)
142+
nn.init.constant_(self.offset_conv.bias, 0)
55143

56144
def forward(self, x):
57145
x = self.conv1(x)
58-
x = x.permute(0, 2, 3, 1)
59-
return x
60-
#x = self.bn1(x)
61-
#x = self.relu(x)
62-
#x = self.maxpool(x)
63-
64-
#x = self.layer1(x)
65-
66-
#x = self.layer2(x)
67-
#p3 = self.p3(x)
68-
#p3 = self.p3_l2(p3)
69-
70-
#x = self.layer3(x)
71-
#p4 = self.p4(x)
72-
#p4 = self.p4_l2(p4)
73-
74-
#x = self.layer4(x)
75-
#p5 = self.p5(x)
76-
#p5 = self.p5_l2(p5)
77-
78-
#cat = torch.cat([p3, p4, p5], dim=1)
79-
80-
#feat = self.feat(cat)
81-
#feat = self.feat_bn(feat)
82-
#feat = self.feat_act(feat)
83-
84-
#x_cls = self.pos_conv(feat)
85-
#x_cls = torch.sigmoid(x_cls)
86-
#x_reg = self.reg_conv(feat)
87-
#x_off = self.off_conv(feat)
88-
89-
#return x_cls, x_reg, x_off
146+
x = self.bn1(x)
147+
x = self.relu(x)
148+
x = self.samepad1(x)
149+
x = self.maxpool(x)
150+
151+
x = self.convblk2a(x)
152+
x = self.identityblk2b(x)
153+
stage2 = self.identityblk2c(x)
154+
155+
x = self.convblk3a(stage2)
156+
x = self.identityblk3b(x)
157+
x = self.identityblk3c(x)
158+
stage3 = self.identityblk3d(x)
159+
160+
x = self.convblk4a(stage3)
161+
x = self.identityblk4b(x)
162+
x = self.identityblk4c(x)
163+
x = self.identityblk4d(x)
164+
x = self.identityblk4e(x)
165+
stage4 = self.identityblk4f(x)
166+
167+
x = self.convblk5a(stage4)
168+
x = self.identityblk5b(x)
169+
stage5 = self.identityblk5c(x)
170+
171+
p3up = self.p3(stage3)
172+
p4up = self.p4(stage4)
173+
p5up = self.p5(stage5)
174+
p3up = self.p3_l2(p3up)
175+
p4up = self.p4_l2(p4up)
176+
p5up = self.p5_l2(p5up)
177+
cat = torch.cat([p3up, p4up, p5up], dim=1)
178+
179+
feat = self.feat(cat)
180+
feat = self.feat_bn(feat)
181+
feat = self.relu(feat)
182+
183+
x_cls = self.center_conv(feat)
184+
x_cls = torch.sigmoid(x_cls)
185+
x_reg = self.height_conv(feat)
186+
x_off = self.offset_conv(feat)
187+
188+
x_cls = x_cls.permute(0, 2, 3, 1)
189+
x_reg = x_reg.permute(0, 2, 3, 1)
190+
x_off = x_off.permute(0, 2, 3, 1)
191+
192+
return x_cls, x_reg, x_off
90193

91194
# def train(self, mode=True):
92195
# # Override train so that the training mode is set as we want
@@ -104,32 +207,4 @@ def forward(self, x):
104207
# else:
105208
# m.eval()
106209
# self.layer1.apply(set_bn_train)
107-
def load_keras_weights(self, weights_path):
108-
with h5py.File(weights_path, 'r') as f:
109-
#model_weights = f['model_weights']
110-
#layer_names = list(map(str, model_weights.keys()))
111-
#state_dict = OrderedDict()
112-
113-
print(f.attrs['layer_names'])
114-
print(f['conv1'].attrs.keys())
115-
print(f['conv1'].attrs['weight_names'])
116-
print(f['conv1']['conv1_1/kernel:0'])
117-
118-
w = np.asarray(f['conv1']['conv1_1/kernel:0'], dtype='float32')
119-
b = np.asarray(f['conv1']['conv1_1/bias:0'], dtype='float32')
120-
self.conv1.weight = torch.nn.Parameter(torch.from_numpy(w).permute(3, 2, 0, 1))
121-
self.conv1.bias = torch.nn.Parameter(torch.from_numpy(b))
122-
print('b shape ', b.shape)
123-
124-
125-
print(self.conv1.weight.shape)
126-
print(self.conv1.bias.shape)
127-
#print('weight, ', self.conv1.weight.permute(2, 3, 1, 0))
128-
#print('bias, ', self.conv1.bias)
129-
#num_w = conv_layer.weight.numel()
130-
#conv_w = torch.from_numpy(weights[ptr:ptr + num_w]).view_as(conv_layer.weight)
131-
#conv_layer.weight.data.copy_(conv_w)
132-
133-
134-
135210

keras_weights_loader.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
import h5py
3+
import numpy as np
4+
#from cspnet import CSPNet_p3p4p5, ConvBlock
5+
6+
def load_conv_weights(conv, f, layer_name):
7+
w = np.asarray(f[layer_name][layer_name + '_1/kernel:0'], dtype='float32')
8+
b = np.asarray(f[layer_name][layer_name + '_1/bias:0'], dtype='float32')
9+
conv.weight = torch.nn.Parameter(torch.from_numpy(w).permute(3, 2, 0, 1))
10+
conv.bias = torch.nn.Parameter(torch.from_numpy(b))
11+
12+
def load_bn_weights(bn, f, layer_name):
13+
w = np.asarray(f[layer_name][layer_name + '_1/gamma:0'], dtype='float32')
14+
b = np.asarray(f[layer_name][layer_name + '_1/beta:0'], dtype='float32')
15+
m = np.asarray(f[layer_name][layer_name + '_1/moving_mean:0'], dtype='float32')
16+
v = np.asarray(f[layer_name][layer_name + '_1/moving_variance:0'], dtype='float32')
17+
bn.weight = torch.nn.Parameter(torch.from_numpy(w))
18+
bn.bias = torch.nn.Parameter(torch.from_numpy(b))
19+
bn.running_mean = torch.from_numpy(m)
20+
bn.running_var = torch.from_numpy(v)
21+
22+
def load_conv_block_weights(conv_blk, f, blk_name):
23+
load_conv_weights(conv_blk.conv1, f, 'res' + blk_name + '_branch2a')
24+
load_bn_weights(conv_blk.bn1, f, 'bn' + blk_name + '_branch2a')
25+
load_conv_weights(conv_blk.conv2, f, 'res' + blk_name + '_branch2b')
26+
load_bn_weights(conv_blk.bn2, f, 'bn' + blk_name + '_branch2b')
27+
load_conv_weights(conv_blk.conv3, f, 'res' + blk_name + '_branch2c')
28+
load_bn_weights(conv_blk.bn3, f, 'bn' + blk_name + '_branch2c')
29+
load_conv_weights(conv_blk.conv4, f, 'res' + blk_name + '_branch1')
30+
load_bn_weights(conv_blk.bn4, f, 'bn' + blk_name + '_branch1')
31+
32+
def load_identity_block_weights(identity_blk, f, blk_name):
33+
load_conv_weights(identity_blk.conv1, f, 'res' + blk_name + '_branch2a')
34+
load_bn_weights(identity_blk.bn1, f, 'bn' + blk_name + '_branch2a')
35+
load_conv_weights(identity_blk.conv2, f, 'res' + blk_name + '_branch2b')
36+
load_bn_weights(identity_blk.bn2, f, 'bn' + blk_name + '_branch2b')
37+
load_conv_weights(identity_blk.conv3, f, 'res' + blk_name + '_branch2c')
38+
load_bn_weights(identity_blk.bn3, f, 'bn' + blk_name + '_branch2c')
39+
40+
def load_l2norm_weights(l2norm, f, layer_name):
41+
w = np.asarray(f[layer_name][layer_name + '_1/' + layer_name + '_gamma:0'], dtype='float32')
42+
l2norm.weight = torch.nn.Parameter(torch.from_numpy(w))
43+
44+
def load_keras_weights(model, weights_path):
45+
with h5py.File(weights_path, 'r') as f:
46+
print(f.attrs['layer_names'])
47+
48+
load_conv_weights(model.conv1, f, 'conv1')
49+
load_bn_weights(model.bn1, f, 'bn_conv1')
50+
51+
load_conv_block_weights(model.convblk2a, f, '2a')
52+
load_identity_block_weights(model.identityblk2b, f, '2b')
53+
load_identity_block_weights(model.identityblk2c, f, '2c')
54+
55+
load_conv_block_weights(model.convblk3a, f, '3a')
56+
load_identity_block_weights(model.identityblk3b, f, '3b')
57+
load_identity_block_weights(model.identityblk3c, f, '3c')
58+
load_identity_block_weights(model.identityblk3d, f, '3d')
59+
60+
load_conv_block_weights(model.convblk4a, f, '4a')
61+
load_identity_block_weights(model.identityblk4b, f, '4b')
62+
load_identity_block_weights(model.identityblk4c, f, '4c')
63+
load_identity_block_weights(model.identityblk4d, f, '4d')
64+
load_identity_block_weights(model.identityblk4e, f, '4e')
65+
load_identity_block_weights(model.identityblk4f, f, '4f')
66+
67+
load_conv_block_weights(model.convblk5a, f, '5a')
68+
load_identity_block_weights(model.identityblk5b, f, '5b')
69+
load_identity_block_weights(model.identityblk5c, f, '5c')
70+
71+
load_conv_weights(model.p3, f, 'P3up')
72+
load_conv_weights(model.p4, f, 'P4up')
73+
load_conv_weights(model.p5, f, 'P5up')
74+
75+
load_l2norm_weights(model.p3_l2, f, 'P3norm')
76+
load_l2norm_weights(model.p4_l2, f, 'P4norm')
77+
load_l2norm_weights(model.p5_l2, f, 'P5norm')
78+
79+
load_conv_weights(model.feat, f, 'feat')
80+
load_bn_weights(model.feat_bn, f, 'bn_feat')
81+
82+
load_conv_weights(model.center_conv, f, 'center_cls')
83+
load_conv_weights(model.height_conv, f, 'height_regr')
84+
load_conv_weights(model.offset_conv, f, 'offset_regr')
85+

samepad.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import math
5+
6+
class SamePad2d(nn.Module):
7+
"""Mimics tensorflow's 'SAME' padding.
8+
"""
9+
10+
def __init__(self, kernel_size, stride, dilation=1):
11+
super(SamePad2d, self).__init__()
12+
self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
13+
self.stride = torch.nn.modules.utils._pair(stride)
14+
self.dilation = torch.nn.modules.utils._pair(dilation)
15+
16+
def forward(self, input):
17+
in_width = input.size()[3]
18+
in_height = input.size()[2]
19+
out_width = math.ceil(float(in_width) / float(self.stride[0]))
20+
out_height = math.ceil(float(in_height) / float(self.stride[1]))
21+
22+
effective_kernel_size_width = (self.kernel_size[0] - 1) * self.dilation[0] + 1
23+
effective_kernel_size_height = (self.kernel_size[1] - 1) * self.dilation[1] + 1
24+
25+
pad_along_width = ((out_width - 1) * self.stride[0] +
26+
effective_kernel_size_width - in_width)
27+
pad_along_height = ((out_height - 1) * self.stride[1] +
28+
effective_kernel_size_height - in_height)
29+
pad_left = math.floor(pad_along_width / 2)
30+
pad_top = math.floor(pad_along_height / 2)
31+
pad_right = pad_along_width - pad_left
32+
pad_bottom = pad_along_height - pad_top
33+
return F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), 'constant', 0)
34+
35+
def __repr__(self):
36+
return self.__class__.__name__

0 commit comments

Comments
 (0)