PB (TensorFlow) to Pth (PyTorch)

[Verified]; the code below needs to be adapted to your own network modules.

Source: "将tensorpack的inference改为pytorch — .pb转换为.pth" (云端一散仙, CSDN blog)

Converting the .pb file into a .pth file

  • Related file (the conversion script)
import torch
from collections import OrderedDict
import tensorflow as tf
from tensorflow.python.framework import tensor_util

def view_params():
    # TF1-style API; under TF2 use tf.compat.v1.gfile.GFile / tf.compat.v1.GraphDef
    pb_file = 'ocr/checkpoint/text_recognition_377500.pb'
    graph = tf.Graph()
    with graph.as_default():
        with tf.gfile.FastGFile(pb_file, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')
    # in a frozen graph the Const nodes hold the trained weights
    graph_nodes = [n for n in graph_def.node]
    wts = [n for n in graph_nodes if n.op == 'Const']

    odic = OrderedDict()
    for n in wts:
        param = tensor_util.MakeNdarray(n.attr['value'].tensor)
        if not param.size == 0:  # skip empty constants
            odic[n.name] = param
    # keys are the TF node names, values are plain numpy arrays
    torch.save(odic, 'pb_377500.pth')
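
To sanity-check the export, the saved OrderedDict can simply be reloaded with torch.load; the keys are the TF node names and the values are numpy arrays (the path matches the script above):

import torch

odic = torch.load('pb_377500.pth')
# print the first few TF node names and the shapes of their weights
for name, arr in list(odic.items())[:10]:
    print(name, arr.shape)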

Model code

class TextRecognition(nn.Module):
    def __init__(self):
        super(TextRecognition, self).__init__()
        # BasicConv2d / Inception_B are custom blocks (definitions not shown here);
        # 'SAME' refers to TF-style padding, see the note on padding below.
        self.features = nn.Sequential(OrderedDict([
            ('Conv2d_1a_3x3', BasicConv2d(3, 32, kernel_size=3, stride=2, padding='SAME')),
            ('Conv2d_2a_3x3', BasicConv2d(32, 32, kernel_size=3, stride=1, padding='SAME')),
            ...
            ('Mixed_6h', Inception_B()),
        ]))
        self.attention_lstm = AttentionLstm()

    def forward(self, x):
        x = self.features(x)
        x = self.attention_lstm(x)
        return x

class LinearBias(nn.Module):
    def __init__(self, size):
        super(LinearBias, self).__init__()
        self.param = nn.Parameter(torch.Tensor(size))

    def forward(self, x):
        x = x + self.param
        return x

class AttentionLstm(nn.Module):
    def __init__(self, seq_len=33, is_training=False, num_classes=7569,
                 wemb_size=256, channel=1024, lstm_size=512):
        super(AttentionLstm, self).__init__()
        self.seq_len = seq_len  # 33
        ...
        self.W_wemb = nn.Linear(self.num_classes, self.wemb_size, bias=False)
        self.lstm_b = LinearBias(self.lstm_size * 4)
        self.tanh = nn.Tanh()
        self.softmax_1d = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.dropout_1d = nn.Dropout(0.)

    def forward(self, cnn_feature):  # bs, 1024, h, w
        _, _, self.height, self.width = cnn_feature.size()
        ...
        return output_array, attention_array

A pitfall when writing the convolution layers: PyTorch and TensorFlow differ in how 2D convolution (Conv2d) padding works.
The difference comes from the padding schemes the two frameworks use: PyTorch always pads symmetrically (the same amount on top/bottom and left/right), while TensorFlow's 'SAME' mode may pad asymmetrically, putting the extra row/column on the bottom/right.
Reference blog 1: Pytorch 与 TensorFlow 二维卷积(Conv2d)填充(padding)上的差异 (简书)
Reference blog 2: tensorflow与pytorch卷积填充方式的差异 (简书)
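
A minimal sketch of how TF-style 'SAME' padding can be reproduced in PyTorch, by padding asymmetrically with F.pad before a padding-free nn.Conv2d. SamePadConv2d below is a made-up helper for illustration only; the actual BasicConv2d used in the model above is not shown in the excerpt and presumably also contains BatchNorm/activation:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SamePadConv2d(nn.Module):
    """Conv2d that reproduces TF 'SAME' padding (possibly asymmetric)."""
    def __init__(self, in_ch, out_ch, kernel_size, stride=1):
        super(SamePadConv2d, self).__init__()
        self.k, self.s = kernel_size, stride
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=0)

    def forward(self, x):
        h, w = x.shape[2:]
        # TF 'SAME': output size is ceil(input / stride)
        pad_h = max((math.ceil(h / self.s) - 1) * self.s + self.k - h, 0)
        pad_w = max((math.ceil(w / self.s) - 1) * self.s + self.k - w, 0)
        # F.pad order is (left, right, top, bottom); the extra pixel of an odd
        # total pad goes to the right/bottom, exactly as TF does it
        x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2,
                      pad_h // 2, pad_h - pad_h // 2))
        return self.conv(x)

# e.g. an 8x8 input with kernel 3, stride 2 needs 1 pixel of padding,
# which is added only on the bottom/right
print(SamePadConv2d(3, 32, kernel_size=3, stride=2)(torch.randn(1, 3, 8, 8)).shape)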

In AttentionLstm there is a LinearBias class that adds self.lstm_b to pack. If the bias were kept as a plain tensor and added directly in forward, it would not be saved; wrapping it in a small module whose bias is an nn.Parameter registers it in the state_dict, so all parameters can be loaded in a single load_state_dict call (see the example below).

class AttentionLstm(nn.Module):
    def __init__(self):
        super(AttentionLstm, self).__init__()
        self.seq_len = 33                           # plain Python attribute
        self.W_wemb = nn.Linear(10, 20, bias=False)
        self.lstm_b = LinearBias(4)
        self.a = nn.Parameter(torch.Tensor(1))
        self.b = torch.randn(1, 3)                  # plain tensor, NOT registered

test = AttentionLstm()
# odict_keys(['a', 'W_wemb.weight', 'lstm_b.param']) -- self.b is not saved in the
# state_dict, while self.lstm_b (a submodule holding an nn.Parameter) is.
print(test.state_dict().keys())
pack = self.lstm_W(wemb_prev) + self.lstm_U(h_prev) + self.lstm_Z(attention_feature)  # bs, 2048
pack_with_bias = self.lstm_b(pack)

The original code mostly uses TensorFlow functions, so they have to be replaced with the corresponding PyTorch ones.

TensorFlow                  PyTorch
tf.matmul                   torch.matmul
tf.multiply                 torch.mul
tf.sigmoid                  torch.sigmoid / torch.nn.Sigmoid
tf.nn.dropout               torch.nn.Dropout
tf.nn.softmax               torch.nn.Softmax
tf.tanh                     torch.tanh
tf.split                    torch.split
tf.shape                    Tensor.size() / Tensor.shape
tf.reshape                  torch.reshape / Tensor.view
tf.expand_dims              torch.unsqueeze
tf.add_n / tf.add           torch.add
tf.reduce_sum               torch.sum
tf.reduce_mean              torch.mean
tf.transpose                Tensor.permute
tf.concat                   torch.cat
tf.nn.embedding_lookup      torch.index_select
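
As a small illustration of the table (the tensors and shapes below are made up, not taken from the original model), here is a TF1-style attention-pooling snippet and its PyTorch translation:

import torch

# hypothetical shapes, only to show the mappings in action
batch, hw, channel = 2, 12, 1024
feature = torch.randn(batch, hw, channel)    # flattened CNN feature map
e = torch.randn(batch, hw, 1)                # unnormalized attention scores

# TF1:  alpha = tf.nn.softmax(tf.reshape(e, [batch, -1]))
alpha = torch.softmax(e.reshape(batch, -1), dim=1)

# TF1:  context = tf.reduce_sum(tf.multiply(feature, tf.expand_dims(alpha, 2)), 1)
context = torch.sum(feature * alpha.unsqueeze(2), dim=1)

print(alpha.shape, context.shape)   # torch.Size([2, 12]) torch.Size([2, 1024])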

Loading the parameters

Finally, load the parameters to verify the conversion.

net = attention_ocr_pytorch.TextRecognition()
net.load_state_dict(torch.load('log/pytorch/pb_377500_fl.pth'))
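
Note that the dict written by view_params is keyed by TF node names and stores arrays in TF layout, whereas load_state_dict expects the PyTorch module's parameter names and layouts, so a remapping step is needed in between (presumably how pb_377500_fl.pth was produced). Below is a minimal sketch of that step; the name_map entry is hypothetical, and the real table has to be written by hand for this network:

import torch
from collections import OrderedDict

tf_weights = torch.load('pb_377500.pth')   # TF node name -> numpy array

# hypothetical mapping from TF node names to PyTorch state_dict keys
name_map = {
    'InceptionV4/Conv2d_1a_3x3/weights': 'features.Conv2d_1a_3x3.conv.weight',
    # ...
}

new_state = OrderedDict()
for tf_name, pt_name in name_map.items():
    arr = torch.from_numpy(tf_weights[tf_name])
    if arr.dim() == 4:        # conv kernel: TF stores HWIO, PyTorch expects OIHW
        arr = arr.permute(3, 2, 0, 1).contiguous()
    elif arr.dim() == 2:      # dense kernel: TF stores (in, out), PyTorch expects (out, in)
        arr = arr.t().contiguous()
    new_state[pt_name] = arr

torch.save(new_state, 'log/pytorch/pb_377500_fl.pth')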

end

Using Reshape inside nn.Sequential

PyTorch has no nn.Reshape layer; when you need to reshape, you usually do it directly in forward:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        ...
        h = h.view(-1, 128)
        ...

If you want to reshape inside an nn.Sequential, you can define a custom Reshape layer:

class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        # keep the batch dimension, reshape the rest to the given shape
        return x.view((x.size(0),) + self.shape)

Then Reshape can be used directly inside nn.Sequential:

nn.Sequential(
    nn.Linear(10, 64*7*7),
    Reshape(64, 7, 7),
    ...
)
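
A quick self-contained check that the custom Reshape layer works inside nn.Sequential (the layer sizes below are arbitrary, chosen only so the example runs):

import torch
import torch.nn as nn

decoder = nn.Sequential(
    nn.Linear(10, 64 * 7 * 7),
    Reshape(64, 7, 7),                     # -> (batch, 64, 7, 7)
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
)

x = torch.randn(8, 10)
print(decoder(x).shape)                    # torch.Size([8, 32, 14, 14])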

Original article: https://blog.csdn.net/d14665/article/details/112218767
