Skip to content

Commit 08c8532

Browse files
authored
Merge pull request #483 from guoshengCS/add-tf2paddle
Add tf2paddle to convert TensorFlow models to PaddlePaddle models.
2 parents 2e13e06 + de95c53 commit 08c8532

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
## 使用说明
2+
3+
`tf2paddle.py`脚本中的工具类`TFModelConverter`实现了将TensorFlow训练好的模型文件转换为PaddlePaddle可加载的模型文件。目前能够支持图像领域常用的:卷积(`Convolution`)层、`Batch Normalization`层和全连接(`Full Connection`)层。图像领域常用的 `ResNet`、`VGG` 网络都以这些层为基础,使用TensorFlow训练的`ResNet`、`VGG`模型能够被转换为PaddlePaddle可加载的模型,进一步用于预训练或是预测服务的开发等。
4+
5+
模型转换的基本流程是:
6+
1. 将TensorFlow模型等价地使用PaddlePaddle Python API接口进行改写。
7+
1. 在TensorFlow中可学习参数用 `Variable` 表示,基于TensorFlow的Python API获取网络中的 Variable。
8+
1. 确定TensorFlow模型中`Variable`与PaddlePaddle中`paddle.layer`的可学习参数的对应关系。
9+
1. 对TensorFlow中的`Variable`进行一定的适配(详见下文),转化为PaddlePaddle中的参数存储格式并进行序列化保存。
10+
11+
### 需要遵守的约定
12+
13+
为使TensorFlow模型中的`Variable`能够正确对应到`paddle.layer`中的可学习参数,目前版本在使用时有如下约束需要遵守:
14+
15+
1. 目前仅支持将TensorFlow中 `conv2d`、`batchnorm`、`fc` 这三种带有可学习`Variable`的Operator训练出的参数向PaddlePaddle模型参数转换。
16+
1. TensorFlow网络配置中同一Operator内的`Variable`属于相同的scope,以此为依据将`Variable`划分到不同的`paddle.layer`
17+
1. `conv2d`、`batchnorm`、`fc`的scope需分别包含`conv`、`bn`、`fc`,以此获取对应`paddle.layer`的类型。也可以通过为`TFModelConverter`传入`layer_type_map`(Python `dict`),将scope映射到对应的`paddle.layer`的type来规避此项约束。
18+
1. `conv2d``fc``Variable`的顺序为:先可学习`Weight``Bias``batchnorm``Variable`的顺序为:`scale``shift``mean``var`,请注意参数存储的顺序将`Variable`对应到`paddle.layer.batch_norm`相应位置的参数。
19+
1. TensorFlow网络拓扑顺序需和PaddlePaddle网络拓扑顺序一致,尤其注意网络包含分支结构时分支定义的先后顺序,如ResNet的bottleneck模块中两分支定义的先后顺序。这是针对模型转换和PaddlePaddle网络配置均使用PaddlePaddle默认参数命名的情况,此时将根据拓扑顺序进行参数命名。
20+
1. 若PaddlePaddle网络配置中需要通过调用`param_attr=paddle.attr.Param(name="XX")`显式地设置可学习参数名字,这时可通过为`TFModelConverter`传入`layer_name_map`或`param_name_map`字典(类型为Python `dict`),在模型转换时将`Variable`的名字映射为所对应的`paddle.layer.XX`中可学习参数的名字。
21+
1. 要求提供`build_model`接口以从此构建TensorFlow网络,加载模型并返回session。可参照如下示例进行编写:
22+
23+
```python
24+
def build_model():
25+
build_graph()
26+
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
27+
sess.run(tf.tables_initializer())
28+
saver = tf.train.Saver()
29+
saver.restore(sess, 'model/model.ckpt')
30+
return sess
31+
```
32+
33+
### 使用说明
34+
35+
按照以上原则操作后,`tf2paddle.py` 脚本的`main`函数提供了一个调用示例,将TensorFlow训练的`ResNet50`模型转换为PaddlePaddle可加载模型。若要对其它各种自定义的模型进行转换,只需修改相关变量的值,在终端执行`python tf2paddle.py`即可。
36+
37+
下面是一个简单的调用示例:
38+
39+
```python
40+
# 定义相关变量
41+
tf_net = "TF_ResNet50" # 提供build_model的module名
42+
paddle_tar_name = "Paddle_ResNet50.tar.gz" # 输出的Paddle模型的文件名
43+
44+
# 初始化并加载模型
45+
converter = TFModelConverter(tf_net=tf_net,
46+
paddle_tar_name=paddle_tar_name)
47+
# 进行模型转换
48+
converter.convert()
49+
```
50+
51+
### 注意事项
52+
53+
1. 由于TensorFlow中的padding机制较为特殊,在编写PaddlePaddle网络配置时,对`paddle.layer.conv`这种需要padding的层可能需要推算size后在`paddle.layer.conv`外使用`paddle.layer.pad`进行padding。
54+
1. 与TensorFlow图像输入多使用NHWC的数据组织格式有所不同,PaddlePaddle按照NCHW的格式组织图像输入数据。
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import os
2+
import re
3+
import collections
4+
import struct
5+
import gzip
6+
import tarfile
7+
import cStringIO
8+
import numpy as np
9+
10+
import tensorflow as tf
11+
12+
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
13+
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
14+
15+
16+
class ModelConverter(object):
    """Base class that serializes layer parameters into a PaddlePaddle
    model tarball (a gzip-compressed tar holding, per parameter, a
    ``ParameterConfig`` protobuf plus the raw parameter data).

    Subclasses must implement ``arrange_layer_params()`` returning an
    ordered mapping ``layer_scope -> (params, param_names, layer_type)``
    and the per-type ``convert_<type>_layer`` methods, which register
    entries in ``self.params``.

    NOTE(review): this file is Python 2 code (``cStringIO``); kept as-is.
    """

    def __init__(self,
                 paddle_tar_name,
                 param_name_map=None,
                 layer_name_map=None,
                 layer_type_map=None):
        # Output file name of the converted model, e.g. "Paddle_ResNet50.tar.gz".
        self.tar_name = paddle_tar_name
        # Optional dicts mapping source variable / layer-scope names to
        # the names expected on the PaddlePaddle side.
        self.param_name_map = param_name_map
        self.layer_name_map = layer_name_map
        self.layer_type_map = layer_type_map
        # param file name -> (ParameterConfig, flattened ndarray);
        # filled by the convert_*_layer methods, consumed by to_tar().
        self.params = dict()

    def convert(self):
        """Convert every layer's parameters and write the model tarball.

        Layer type is taken from the scope name when available; otherwise
        it must be supplied through ``layer_type_map``.
        """
        layers_params = self.arrange_layer_params()
        for layer_name in layers_params.keys():
            layer_params, layer_params_names, layer_type = layers_params[
                layer_name]
            if len(layer_params) > 0:
                if not layer_type:
                    # Scope name gave no 'conv'/'bn'/'fc' hint; fall back to
                    # the user-supplied map. (Bug fix: previously referenced
                    # the undefined local name `layer_type_map` instead of
                    # `self.layer_type_map`, raising NameError here.)
                    assert self.layer_type_map and (
                        self.layer_type_map.get(layer_name) in
                        ["conv", "bn", "fc"])
                    layer_type = self.layer_type_map[layer_name]
                self.pre_layer_name = getattr(
                    self, "convert_" + layer_type + "_layer")(
                        layer_params,
                        params_names=[
                            self.param_name_map.get(name)
                            if self.param_name_map else None
                            for name in layer_params_names
                        ],
                        name=None if self.layer_name_map is None else
                        self.layer_name_map.get(layer_name))
        with gzip.open(self.tar_name, 'w') as f:
            self.to_tar(f)
        return

    def to_tar(self, f):
        """Write all converted parameters into a tar stream over file
        object *f*: one ``<name>.protobuf`` member (the serialized
        ParameterConfig) and one ``<name>`` member (the binary data)."""
        tar = tarfile.TarFile(fileobj=f, mode='w')
        for param_name in self.params.keys():
            param_conf, param_data = self.params[param_name]

            confStr = param_conf.SerializeToString()
            tarinfo = tarfile.TarInfo(name="%s.protobuf" % param_name)
            tarinfo.size = len(confStr)
            buf = cStringIO.StringIO(confStr)
            buf.seek(0)
            tar.addfile(tarinfo, fileobj=buf)

            buf = cStringIO.StringIO()
            self.serialize(param_data, buf)
            tarinfo = tarfile.TarInfo(name=param_name)
            buf.seek(0)
            tarinfo.size = len(buf.getvalue())
            tar.addfile(tarinfo, buf)

    @staticmethod
    def serialize(data, f):
        """Write numpy array *data* to file object *f* in Paddle's binary
        parameter format: header (version=0, value-size=4 bytes,
        element count as uint64) followed by the raw array bytes."""
        f.write(struct.pack("IIQ", 0, 4, data.size))
        f.write(data.tobytes())
75+
76+
77+
class TFModelConverter(ModelConverter):
    """Convert a trained TensorFlow model into a PaddlePaddle tarball.

    *tf_net* names an importable module that must expose ``build_model()``,
    which builds the TF graph, restores the checkpoint and returns the
    live ``tf.Session``.
    """

    def __init__(self,
                 tf_net,
                 paddle_tar_name,
                 param_name_map=None,
                 layer_name_map=None,
                 layer_type_map=None):
        super(TFModelConverter, self).__init__(paddle_tar_name, param_name_map,
                                               layer_name_map, layer_type_map)
        # Import the user module lazily by name and obtain the session
        # with all variables restored.
        self.sess = __import__(tf_net).build_model()

    def arrange_layer_params(self):
        """Group the graph's Variables by scope (one scope == one layer).

        Returns an OrderedDict: scope -> ([param ndarrays], [var names],
        layer type). The type is inferred from 'conv'/'bn'/'fc' appearing
        in the scope name, or None when the scope gives no hint (then
        convert() falls back to layer_type_map).
        """
        all_vars = tf.global_variables()
        layers_params = collections.OrderedDict()
        for var in all_vars:
            var_name = var.name
            scope_pos = var_name.rfind('/')
            if scope_pos != -1:
                layer_scope = var_name[:scope_pos]
                # `in` instead of Py2-only dict.has_key().
                if layer_scope in layers_params:
                    layer_params, layer_params_names, layer_type = layers_params[
                        layer_scope]
                    layer_params.append(var.eval(self.sess))
                    layer_params_names.append(var_name)
                else:
                    layer_type = re.search('conv|bn|fc', layer_scope)
                    layers_params[layer_scope] = ([var.eval(self.sess)],
                                                  [var_name], layer_type.group()
                                                  if layer_type else None)
        return layers_params

    @wrap_name_default("conv")
    def convert_conv_layer(self, params, params_names=None, name=None):
        """Register a convolution layer's parameters.

        4-D kernels are transposed from TF's (H, W, in_C, out_C) layout to
        Paddle's (out_C, in_C, H, W); 1-D biases pass through unchanged.
        """
        for i in range(len(params)):
            data = np.transpose(params[i], (
                3, 2, 0, 1)) if len(params[i].shape) == 4 else params[i]
            if len(params) == 2:
                # Weight gets suffix "0", bias gets "bias" — Paddle's
                # default parameter naming.
                suffix = "0" if i == 0 else "bias"
                file_name = "_%s.w%s" % (name, suffix) if not (
                    params_names and params_names[i]) else params_names[i]
            else:
                file_name = "_%s.w%s" % (name, str(i)) if not (
                    params_names and params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            if len(dims) == 1:
                # NOTE(review): conv pads a 1-D bias to (n, 1) while the
                # fc/bn converters pad to (1, n) — looks intentional but
                # worth confirming against Paddle's parameter layout.
                dims.insert(1, 1)
            param_conf.dims.extend(dims)
            param_conf.size = reduce(lambda a, b: a * b, data.shape)
            self.params[file_name] = (param_conf, data.flatten())
        # Consistency fix: return the layer name like convert_fc_layer /
        # convert_bn_layer do, so convert() records a meaningful
        # self.pre_layer_name (previously returned None implicitly).
        return name

    @wrap_name_default("fc_layer")
    def convert_fc_layer(self, params, params_names=None, name=None):
        """Register a fully-connected layer's parameters (weight, bias)."""
        for i in range(len(params)):
            data = params[i]
            if len(params) == 2:
                suffix = "0" if i == 0 else "bias"
                file_name = "_%s.w%s" % (name, suffix) if not (
                    params_names and params_names[i]) else params_names[i]
            else:
                file_name = "_%s.w%s" % (name, str(i)) if not (
                    params_names and params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            if len(dims) < 2:
                # Pad 1-D bias shape to (1, n).
                dims.insert(0, 1)
            param_conf.size = reduce(lambda a, b: a * b, dims)
            param_conf.dims.extend(dims)
            self.params[file_name] = (param_conf, data.flatten())
        return name

    @wrap_name_default("batch_norm")
    def convert_bn_layer(self, params, params_names=None, name=None):
        """Register a batch-norm layer's parameters.

        Incoming TF order is (scale, shift, mean, var); reorder to
        (scale, mean, var, shift) so the shift lands in the "bias" slot
        expected by paddle.layer.batch_norm.
        """
        params = [params[i] for i in (0, 2, 3, 1)]
        params_names = [params_names[i]
                        for i in (0, 2, 3, 1)] if params_names else params_names
        for i in range(len(params)):
            data = params[i]
            # First three entries are w0/w1/w2; the last (shift) is the bias.
            file_name = "_%s.w%s" % (name, str(i)) if i < 3 else "_%s.w%s" % (
                name, "bias")
            file_name = file_name if not (params_names and
                                          params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            assert len(dims) == 1
            dims.insert(0, 1)
            param_conf.size = reduce(lambda a, b: a * b, dims)
            param_conf.dims.extend(dims)
            self.params[file_name] = (param_conf, data.flatten())
        return name
170+
171+
172+
if __name__ == "__main__":
    # Example driver: convert the TF ResNet model into a Paddle-loadable
    # tarball. The module "TF_ResNet" must provide build_model().
    converter = TFModelConverter(
        tf_net="TF_ResNet", paddle_tar_name="Paddle_ResNet50.tar.gz")
    converter.convert()

0 commit comments

Comments
 (0)