Simon Shi的小站

人工智能,机器学习, 强化学习,大模型,自动驾驶

0%

Tensorflow权重迁移至Pytorch

Tensorflow权重迁移至Pytorch_tensorflow权重转pytorch_古月萝北的博客-CSDN博客

Conv2D层

Tensorflow的数据维度为(B,H,W,C), 而Pytorch的数据维度为(B,C,H,W), 因此二者卷积层的权重矩阵也是不一样的。Pytorch的为(out_channels,in_channels,H,W), Tensorflow的为(H,W,in_channels,out_channels), 因此权重迁移时需要转置权重矩阵。

此外,如果卷积带有bias,layer.get_weights()返回长度为2的列表,第一个元素为权重矩阵,第二个元素为bias.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Conv2dWithName(nn.Module):
    """nn.Conv2d wrapper that records a TF/Keras layer name so weights can
    be copied from the identically named layer of a Keras model.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0,
                 groups=1, use_bias=True, dilation=1, name=None):
        super(Conv2dWithName, self).__init__()
        self.conv2d = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, groups=groups,
                                bias=use_bias, dilation=dilation)
        self.name = name  # TF layer name used for the weight lookup
        self.use_bias = use_bias

    def forward(self, x):
        return self.conv2d(x)

    def set_weight(self, layer):
        """Copy kernel (and bias) from a Keras Conv2D layer.

        Keras stores conv kernels as (H, W, in, out); nn.Conv2d wants
        (out, in, H, W), hence the permute. When the layer has a bias,
        layer.get_weights() returns [kernel, bias].
        """
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            tf_params = layer.get_weights()
            kernel = torch.from_numpy(tf_params[0]).permute((3, 2, 0, 1))
            self.conv2d.weight.copy_(kernel)
            if self.use_bias:
                self.conv2d.bias.copy_(torch.from_numpy(tf_params[1]))

Dense

类似的,dense层包含weight,bias两个权重参数。需要注意的是Pytorch的weight维度为(out_dims,in_dims),而Tensorflow正好相反为(in_dims,out_dims)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class DenseWithName(nn.Module):
    """nn.Linear wrapper carrying a TF/Keras layer name for weight transfer.

    Keras Dense kernels are stored (in_dim, out_dim) while nn.Linear stores
    (out_dim, in_dim), so the kernel is transposed on copy.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)
    """

    def __init__(self, in_dim, out_dim, name=None):
        super(DenseWithName, self).__init__()
        self.dense = nn.Linear(in_dim, out_dim)
        self.name = name  # TF layer name used for the weight lookup

    def set_weight(self, layer):
        """Copy kernel and bias from the matching Keras Dense layer."""
        print('INFO: init layer %s with tf weights' % self.name)
        with torch.no_grad():
            tf_params = layer.get_weights()
            kernel = torch.from_numpy(tf_params[0]).transpose(0, 1)
            self.dense.weight.copy_(kernel)
            self.dense.bias.copy_(torch.from_numpy(tf_params[1]))

    def forward(self, x):
        return self.dense(x)

BatchNorm层

BatchNorm需要迁移weight、bias、running_mean、running_var四个参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class BatchNorm2dWithName(nn.Module):
    """nn.BatchNorm2d wrapper carrying a TF/Keras layer name for weight
    transfer. Four parameters are migrated: gamma, beta, moving mean and
    moving variance.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)
    """

    # NOTE(review): `n_chaanels` is a typo for `n_channels`; kept as-is so
    # any keyword-argument callers stay compatible.
    def __init__(self, n_chaanels, name=None):
        super(BatchNorm2dWithName, self).__init__()
        self.bn = nn.BatchNorm2d(n_chaanels)
        self.name = name  # TF layer name used for the weight lookup

    def forward(self, x):
        return self.bn(x)

    def set_weight(self, layer):
        """Copy parameters from a Keras BatchNormalization layer.

        get_weights() order: gamma, beta, moving_mean, moving_variance.
        """
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            tf_params = layer.get_weights()
            self.bn.weight.copy_(torch.from_numpy(tf_params[0]))        # gamma
            self.bn.bias.copy_(torch.from_numpy(tf_params[1]))          # beta
            self.bn.running_mean.copy_(torch.from_numpy(tf_params[2]))
            self.bn.running_var.copy_(torch.from_numpy(tf_params[3]))

逐层迁移权重

我们可以参照已有的Tensorflow模型结构,利用上述封装好的层来搭建深度模型。迁移权重时可以遍历模型的所有层,逐层迁移权重。

1
2
3
4
# Fragment (runs inside an nn.Module method): walk every wrapped layer
# and pull its weights from the TF model whose layer names match.
for m in self.modules():# iterate over all modules of the model, recursively
if isinstance(m, (Conv2dWithName,BatchNorm2dWithName,DenseWithName)):
layer=tf_model.get_layer(m.name)# look up the TF layer with the same name
m.set_weight(layer)

下面以ResNet50为例测试权重迁移:

Tensorflow模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """A residual bottleneck block (keras-applications ResNet v1 layout).

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride of the first layer.
        conv_shortcut: default True, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label.

    # Returns
        Output tensor for the residual block.
    """
    bn_axis = 3  # channels-last data format

    if conv_shortcut is True:
        # Projection shortcut: 1x1 conv + BN so channel counts/strides match.
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=name + '_0_conv')(x)
        shortcut = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                             name=name + '_0_bn')(shortcut)
    else:
        shortcut = x

    # 1x1 reduce (carries the stride, mirroring the PyTorch Bottleneck below)
    x = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_1_bn')(x)
    x = layers.Activation('relu', name=name + '_1_relu')(x)

    # 3x3 convolution
    x = layers.Conv2D(filters, kernel_size, padding='SAME',
                      name=name + '_2_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_2_bn')(x)
    x = layers.Activation('relu', name=name + '_2_relu')(x)

    # 1x1 expand
    x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_3_bn')(x)

    x = layers.Add(name=name + '_add')([shortcut, x])
    x = layers.Activation('relu', name=name + '_out')(x)
    return x


def stack1(x, filters, blocks, stride1=2, name=None):
    """A set of stacked residual blocks.

    The first block always uses a convolution shortcut (and the stage
    stride); the remaining blocks use identity shortcuts.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride of the first layer in the first block.
        name: string, stack label.

    # Returns
        Output tensor for the stacked blocks.
    """
    x = block1(x, filters, stride=stride1, name=name + '_block1')
    for i in range(2, blocks + 1):
        x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i))
    return x


def ResNet50_TF(inputs,
                preact=False,
                use_bias=True,
                model_name='resnet50'):
    """Build a Keras ResNet50 (v1) with a single linear output unit.

    Layer names follow keras-applications so pretrained weights can be
    loaded by name and later matched by the PyTorch wrappers.

    Fixes vs. the published snippet: restored the lost indentation (it was
    not valid Python) and removed the unused local `outputs = []`.
    """
    bn_axis = 3  # channels-last data format

    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(inputs)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)

    if preact is False:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)

    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)

    # Four residual stages of the standard 50-layer configuration.
    x = stack1(x, 64, 3, stride1=1, name='conv2')
    x = stack1(x, 128, 4, name='conv3')
    x = stack1(x, 256, 6, name='conv4')
    x = stack1(x, 512, 3, name='conv5')

    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(1, activation='linear', name='final_fc')(x)

    # Create model.
    model = models.Model(inputs, x, name=model_name)

    return model

Pytorch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, name=None):
    """3x3 convolution with padding, built on the named wrapper.

    BUG FIX: Conv2dWithName has no `bias` parameter (its keyword is
    `use_bias`), so the original `bias=True` raised a TypeError at call
    time. Also restored the lost indentation.
    """
    return Conv2dWithName(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=dilation, groups=groups, use_bias=True,
                          dilation=dilation, name=name)


def conv1x1(in_planes, out_planes, stride=1, name=None):
    """1x1 convolution, built on the named wrapper.

    BUG FIX: Conv2dWithName has no `bias` parameter (its keyword is
    `use_bias`), so the original `bias=True` raised a TypeError at call
    time. Also restored the lost indentation.
    """
    return Conv2dWithName(in_planes, out_planes, kernel_size=1, stride=stride,
                          use_bias=True, name=name)


class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style).

    Unused by the ResNet50 transfer below, but kept for completeness.
    NOTE(review): conv3x3 is called here without a `name`, so set_weight
    would look up a None layer name -- only safe while this block is not
    part of the TF weight transfer.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    """ResNet v1 bottleneck block built from the named wrapper layers.

    Sub-layer names follow the TF model's convention (`<name>_1_conv`,
    `<name>_0_bn`, ...) so init_from_tf can match them. The forward pass
    runs the conv/BN sub-modules through torch.utils.checkpoint, trading
    recomputation for activation memory.

    (Restored indentation -- the published snippet had lost it and was not
    valid Python.)
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, name=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = BatchNorm2dWithName
        width = int(planes * (base_width / 64.)) * groups
        # NOTE: unlike torchvision, the stride sits on the FIRST 1x1 conv,
        # matching the TF block1 definition above.
        self.conv1 = conv1x1(inplanes, width, stride=stride, name=name + '_1_conv')
        self.bn1 = norm_layer(width, name=name + '_1_bn')
        self.conv2 = conv3x3(width, width, name=name + '_2_conv')
        self.bn2 = norm_layer(width, name=name + '_2_bn')
        self.conv3 = conv1x1(width, planes * self.expansion, name=name + '_3_conv')
        self.bn3 = norm_layer(planes * self.expansion, name=name + '_3_bn')
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if self.downsample is not None:
            # Rename the projection shortcut's layers to the TF convention.
            self.downsample[0].name = name + '_0_conv'
            self.downsample[1].name = name + '_0_bn'
        self.stride = stride
        self.name = name

    def forward(self, x):
        identity = x

        out = checkpoint(self.conv1, x)
        out = checkpoint(self.bn1, out)
        out = self.relu(out)

        out = checkpoint(self.conv2, out)
        out = checkpoint(self.bn2, out)
        out = self.relu(out)

        out = checkpoint(self.conv3, out)
        out = checkpoint(self.bn3, out)

        if self.downsample is not None:
            identity = checkpoint(self.downsample, x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet whose layer names mirror the Keras model above, so that
    init_from_tf can copy weights layer by layer.

    Fixes vs. the published snippet: restored the lost indentation (it was
    not valid Python) and corrected the conv1 keyword (`use_bias`, not
    `bias` -- Conv2dWithName has no `bias` parameter, so the original
    raised a TypeError).
    """

    def __init__(self, block, layers, width_per_group=64):
        super(ResNet, self).__init__()

        norm_layer = BatchNorm2dWithName
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.base_width = width_per_group

        self.conv1 = Conv2dWithName(3, self.inplanes, kernel_size=7, stride=2,
                                    padding=3, use_bias=True, name='conv1_conv')
        self.bn1 = norm_layer(self.inplanes, name='conv1_bn')
        self.relu = nn.ReLU(inplace=True)
        # NOTE(review): _forward_impl uses F.max_pool2d with the same
        # hyper-parameters, so this module is currently unused.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], name='conv2')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       name='conv3')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       name='conv4')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       name='conv5')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.final_fc = DenseWithName(2048, 1, name='final_fc')

        # Default init. modules() recurses into the wrappers, so the inner
        # nn.Conv2d / nn.BatchNorm2d instances are reached here; they are
        # overwritten later when init_from_tf copies the TF weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, name=None):
        """Build one stage. The first block always receives a conv+BN
        projection shortcut, matching stack1's conv_shortcut=True; the
        shortcut layers are renamed inside the block's __init__."""
        norm_layer = self._norm_layer
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride=stride),
            norm_layer(planes * block.expansion),
        )
        stage = [block(inplanes=self.inplanes, planes=planes, stride=stride,
                       downsample=downsample, name=name + '_block1')]
        self.inplanes = planes * block.expansion
        for idx in range(1, blocks):
            stage.append(block(self.inplanes, planes, base_width=self.base_width,
                               dilation=self.dilation,
                               name=name + '_block%d' % (idx + 1)))

        return nn.Sequential(*stage)

    def init_from_tf(self, tf_model):
        """Copy weights from a Keras model with identically named layers."""
        for m in self.modules():
            if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
                m.set_weight(tf_model.get_layer(m.name))

    def _forward_impl(self, x):
        # conv1/bn1 also go through gradient checkpointing.
        x = checkpoint(self.conv1, x)
        x = checkpoint(self.bn1, x)
        x = self.relu(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool, drop the two spatial singleton dims.
        x = self.avgpool(x).squeeze(-1).squeeze(-1)

        return self.final_fc(x)

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, **kwargs):
    """Construct a ResNet; `arch` is accepted for API symmetry with
    torchvision but is otherwise unused here. (Restored indentation.)"""
    model = ResNet(block, layers, **kwargs)
    return model

def resnet50_torch(**kwargs):
    """ResNet-50: Bottleneck blocks in a [3, 4, 6, 3] stage layout.
    (Restored indentation.)"""
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
# End-to-end check: build the TF ResNet50, load pretrained weights,
# transfer them into the PyTorch model, then compare predictions on the
# same random image (NHWC for TF, NCHW for torch).
input_shape = (None, None, 3)
inputs = Input(shape=input_shape)
res50_tf=ResNet50_TF(inputs)
# NOTE(review): the path contains full-width dashes ("Resnet——weights.h5");
# confirm it matches the actual file name on disk.
res50_tf.load_weights('./src/Resnet——weights.h5',by_name=True)
res50_torch=resnet50_torch().float()
res50_torch.init_from_tf(res50_tf)
res50_torch.eval()  # eval mode so BatchNorm uses the copied running stats
img=np.random.rand(1,224,224,3)
img2=torch.from_numpy(img).permute([0,3,1,2]).float()  # NHWC -> NCHW
p_tf=res50_tf.predict(img)
p_torch=res50_torch(img2).data.numpy()
print('tensorflow predict: %f '%p_tf[0])
print('pytorch predict: %f '%p_torch[0])

Data

1
2
3
4
5
6
7
8
9
10
11
12
import tensorflow as tf
import torch
from torch import nn
import numpy as np


# Minimal illustration of the layout difference: the same HWC image
# becomes NCHW for PyTorch and NHWC for TensorFlow.
image = np.random.rand(5, 5, 3)
torch_image = np.transpose(image, (2, 0, 1)).astype(np.float32)
# [B, C, H, W] for pytorch
torch_image = torch.unsqueeze(torch.as_tensor(torch_image), dim=0)
# [B, H, W, C] for tensorflow
tf_image = np.expand_dims(image, axis=0)

参考

将Tensorflow权重转移到等效的Pytorch模型 _大数据知识库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def weight_loading(pretrained_weights):
    """Load a Keras model from `pretrained_weights` and copy its conv
    weights into a freshly built PyTorch UNet, matching layers by order.

    Assumes the TF weight list and the PyTorch (transposed-)conv layers
    appear in the same order, and that every TF layer contributes exactly
    a (kernel, bias) pair -- TODO confirm against the actual UNet.

    Bug fixes vs. the original:
    - Conv2d kernels: `.transpose()` reverses ALL axes, producing
      (out, in, W, H) and silently swapping the spatial dims; replaced
      with an explicit (3, 2, 0, 1) permutation of the TF (H, W, in, out)
      layout.
    - ConvTranspose2d kernels: Keras stores them (H, W, out, in) while
      PyTorch wants (in, out, H, W), so the permutation must be
      (3, 2, 0, 1), not (2, 3, 0, 1).
    """
    tf_model = tf.keras.models.load_model(pretrained_weights)
    tf_weights = tf_model.get_weights()
    pt_model = UNet()  # implemented to mirror the TF model's layer order
    # Hoisted out of the loop: the original rebuilt this list per layer.
    named_params = list(pt_model.named_parameters())
    new_state_dict = {}
    with torch.no_grad():
        x = 0  # index of the next (kernel, bias) pair in tf_weights
        for layer in pt_model.modules():
            if isinstance(layer, torch.nn.Conv2d):
                # TF (H, W, in, out) -> PyTorch (out, in, H, W)
                weight_pt = torch.tensor(np.transpose(tf_weights[x * 2], (3, 2, 0, 1)))
                bias_pt = torch.tensor(tf_weights[x * 2 + 1])
                new_state_dict[named_params[x * 2][0]] = weight_pt
                new_state_dict[named_params[x * 2 + 1][0]] = bias_pt
                x += 1
            elif isinstance(layer, torch.nn.ConvTranspose2d):
                # TF (H, W, out, in) -> PyTorch (in, out, H, W)
                weight_pt = torch.tensor(np.transpose(tf_weights[x * 2], (3, 2, 0, 1)))
                bias_pt = torch.tensor(tf_weights[x * 2 + 1])
                new_state_dict[named_params[x * 2][0]] = weight_pt
                new_state_dict[named_params[x * 2 + 1][0]] = bias_pt
                x += 1
    # NOTE(review): a strict load fails if pt_model has parameters beyond
    # conv weights/biases (e.g. batch norm) -- confirm the UNet layout.
    pt_model.load_state_dict(new_state_dict)
    return pt_model