将Pytorch卷积层权重转到Tensorflow中
上面刚刚说了在Pytorch的卷积层中,kernel weights存储格式是[kernel_number, kernel_channel, kernel_height, kernel_width]
,但在Tensorflow的卷积层中kernel weights存储格式是[kernel_height, kernel_width, kernel_channel, kernel_number]
。还有就是在卷积层中如果使用了bias那么bias weights是不需要处理的,因为卷积的bias weights只有一个维度,所以Pytorch和Tensorflow中存储的格式是一样的(后面测试也能验证这个结论)。 在下面代码中:
- 分别使用Pytorch和Tensorflow的Keras模块创建了卷积层
- 获取Pytorch创建卷积层的kernel weight以及bias weight
- 使用numpy对kernel weight的进行transpose处理
- 将转换后的权重载入到tensorflow的卷积层中
- 将之前创建的数据分别传入Pytorch和Tensorflow的卷积层中进行正向传播
- 再使用numpy对Pytorch得到的结果进行transpose处理(保证和tensorflow输出的结果Tensor格式一致)
- 对比两者输出的结果是否一致
1 | def conv_test(torch_image, tf_image): |
将Pytorch DW卷积层权重转到Tensorflow中
在Pytorch的dw卷积层中,dw kernel weights存储格式是[kernel_number, kernel_channel, kernel_height, kernel_width]
,但在Tensorflow的dw卷积层中dw kernel weights存储格式是[kernel_height, kernel_width, kernel_number, kernel_channel]
(注意这里最后两个维度和卷积层有些差异)。同样在dw卷积层中如果使用了bias那么dw bias weights是不需要处理的。
在下面代码中:
分别使用Pytorch和Tensorflow的Keras模块创建了dw卷积层
获取Pytorch创建dw卷积层的dw kernel weight以及dw bias weight
使用numpy对dw kernel weight的进行transpose处理
将转换后的权重载入到tensorflow的dw卷积层中
将之前创建的数据分别传入Pytorch和Tensorflow的dw卷积层中进行正向传播
再使用numpy对Pytorch得到的结果进行transpose处理(保证和tensorflow输出的结果Tensor格式一致)
对比两者输出的结果是否一致
1 | def dw_conv_test(torch_image, tf_image): |
将Pytorch BN层权重转到Tensorflow中
BatchNorm中涉及4个参数:gamma,beta,mean,var
。由于这四个参数的shape都是一维的,所以只要找到对应权重名称关系就行了,不需要对数据进行转换。
在Pytorch中,这四个参数的名称分别对应weight,bias,running_mean,running_var
。
在Tensorflow中,分别对应gamma,beta,moving_mean,moving_variance
。
在下面代码中:
- 分别使用Pytorch和Tensorflow的Keras模块创建了bn层(注意,epsilon要保持一致)
- 随机初始化Pytorch创建bn层的权重信息(默认初始化weight都是1,bias都是0)
- 获取Pytorch随机初始化后bn的weight,bias,running_mean以及running_var
- 将对应的权重载入到tensorflow的bn层中
- 将之前创建的数据分别传入Pytorch和Tensorflow的bn层中进行正向传播
- 再使用numpy对Pytorch得到的结果进行transpose处理(保证和tensorflow输出的结果Tensor格式一致)
- 对比两者输出的结果是否一致
1 | def bn_test(torch_image, tf_image): |
将Pytorch全连接层权重转到Tensorflow中
在全连接层中涉及两个参数:输入节点个数,和输出节点个数。转换权重时只用转换fc weight
即可,fc bias
不用做任何处理。 在下面代码中:
- 对输入的特征矩阵在height以及width维度上进行全局平均池化
- 分别使用Pytorch和Tensorflow的Keras模块创建了fc层
- 获取Pytorch创建fc层的fc weight以及fc bias
- 使用numpy对fc weight的进行transpose处理
- 将转换后的权重载入到tensorflow的fc层中
- 将之前创建的数据分别传入Pytorch和Tensorflow的卷积层中进行正向传播
- 对比两者输出的结果是否一致
1 | def fc_test(torch_image, tf_image): |
完整代码
1 | import tensorflow as tf |