|
import collections
import cStringIO
import gzip
import io
import os
import re
import struct
import tarfile

import numpy as np
import tensorflow as tf

from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
| 14 | + |
| 15 | + |
class ModelConverter(object):
    """Base class for converting a trained model's parameters into a
    PaddlePaddle parameter tar archive.

    Subclasses must implement ``arrange_layer_params()`` returning an
    ordered dict mapping layer scope name to a tuple of
    ``(param values, param names, layer type or None)``.
    """

    def __init__(self,
                 paddle_tar_name,
                 param_name_map=None,
                 layer_name_map=None,
                 layer_type_map=None):
        # Target .tar.gz archive that will hold the converted parameters.
        self.tar_name = paddle_tar_name
        # Optional maps translating source parameter / layer names into
        # the names expected on the Paddle side.
        self.param_name_map = param_name_map
        self.layer_name_map = layer_name_map
        # Fallback layer-type lookup for scopes whose type cannot be
        # inferred automatically; values must be "conv", "bn" or "fc".
        self.layer_type_map = layer_type_map
        # Accumulates param file name -> (ParameterConfig, flat numpy data).
        self.params = dict()

    def convert(self):
        """Convert every layer's parameters and write the tar archive."""
        layers_params = self.arrange_layer_params()
        for layer_name, (layer_params, layer_params_names,
                         layer_type) in layers_params.items():
            if layer_params:
                if not layer_type:
                    # BUG FIX: the original referenced the undefined bare
                    # name ``layer_type_map``; it must be the instance
                    # attribute ``self.layer_type_map``.
                    assert self.layer_type_map and (
                        self.layer_type_map.get(layer_name) in
                        ["conv", "bn", "fc"])
                    layer_type = self.layer_type_map[layer_name]
                # Dispatch to convert_conv_layer / convert_bn_layer /
                # convert_fc_layer; each returns the layer name it used.
                self.pre_layer_name = getattr(
                    self, "convert_" + layer_type + "_layer")(
                        layer_params,
                        params_names=[
                            self.param_name_map.get(name)
                            if self.param_name_map else None
                            for name in layer_params_names
                        ],
                        name=None if self.layer_name_map is None else
                        self.layer_name_map.get(layer_name))
        with gzip.open(self.tar_name, 'w') as f:
            self.to_tar(f)
        return

    def to_tar(self, f):
        """Serialize all converted parameters into the tar stream ``f``.

        Each parameter contributes two members: ``<name>.protobuf``
        holding the serialized ParameterConfig, and ``<name>`` holding
        the raw parameter data.
        """
        tar = tarfile.TarFile(fileobj=f, mode='w')
        for param_name, (param_conf, param_data) in self.params.items():
            conf_str = param_conf.SerializeToString()
            tarinfo = tarfile.TarInfo(name="%s.protobuf" % param_name)
            tarinfo.size = len(conf_str)
            # io.BytesIO works on both Python 2 and 3; cStringIO is
            # Python-2 only.
            buf = io.BytesIO(conf_str)
            buf.seek(0)
            tar.addfile(tarinfo, fileobj=buf)

            buf = io.BytesIO()
            self.serialize(param_data, buf)
            tarinfo = tarfile.TarInfo(name=param_name)
            buf.seek(0)
            tarinfo.size = len(buf.getvalue())
            tar.addfile(tarinfo, buf)

    @staticmethod
    def serialize(data, f):
        """Write numpy array ``data`` in Paddle's binary parameter format:
        a (version=0, value_size=4, element_count) header followed by the
        raw array bytes."""
        f.write(struct.pack("IIQ", 0, 4, data.size))
        f.write(data.tobytes())
| 75 | + |
| 76 | + |
class TFModelConverter(ModelConverter):
    """Convert a TensorFlow model's variables into a PaddlePaddle
    parameter tar archive.

    ``tf_net`` names an importable module whose ``build_model()`` returns
    a TF session holding the model variables (presumably already
    initialized or restored — confirm against the net definition).
    """

    def __init__(self,
                 tf_net,
                 paddle_tar_name,
                 param_name_map=None,
                 layer_name_map=None,
                 layer_type_map=None):
        super(TFModelConverter, self).__init__(paddle_tar_name, param_name_map,
                                               layer_name_map, layer_type_map)
        self.sess = __import__(tf_net).build_model()

    def arrange_layer_params(self):
        """Group global TF variables by their enclosing variable scope.

        Returns an OrderedDict mapping scope name to
        (list of evaluated values, list of variable names, layer type),
        where the layer type is the first of 'conv'/'bn'/'fc' found in
        the scope name, or None when no match. Variables without a '/'
        scope separator are skipped.
        """
        all_vars = tf.global_variables()
        layers_params = collections.OrderedDict()
        for var in all_vars:
            var_name = var.name
            scope_pos = var_name.rfind('/')
            if scope_pos != -1:
                layer_scope = var_name[:scope_pos]
                # ``in`` works on Python 2 and 3; dict.has_key() is
                # Python-2 only.
                if layer_scope in layers_params:
                    layer_params, layer_params_names, layer_type = layers_params[
                        layer_scope]
                    layer_params.append(var.eval(self.sess))
                    layer_params_names.append(var_name)
                else:
                    layer_type = re.search('conv|bn|fc', layer_scope)
                    layers_params[layer_scope] = ([var.eval(self.sess)],
                                                  [var_name], layer_type.group()
                                                  if layer_type else None)
        return layers_params

    @wrap_name_default("conv")
    def convert_conv_layer(self, params, params_names=None, name=None):
        """Convert a conv layer's kernel (and optional bias) parameters."""
        for i in range(len(params)):
            # 4-D kernels are transposed from TF's (H, W, in, out) layout
            # to (out, in, H, W); 1-D biases pass through unchanged.
            data = np.transpose(params[i], (
                3, 2, 0, 1)) if len(params[i].shape) == 4 else params[i]
            if len(params) == 2:
                suffix = "0" if i == 0 else "bias"
                file_name = "_%s.w%s" % (name, suffix) if not (
                    params_names and params_names[i]) else params_names[i]
            else:
                file_name = "_%s.w%s" % (name, str(i)) if not (
                    params_names and params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            if len(dims) == 1:
                dims.insert(1, 1)
            param_conf.dims.extend(dims)
            # int(np.prod(...)) replaces the Python-2-only builtin reduce().
            param_conf.size = int(np.prod(data.shape))
            self.params[file_name] = (param_conf, data.flatten())
        # Consistency fix: fc/bn converters return the layer name and
        # convert() records it as self.pre_layer_name; conv now does too.
        return name

    @wrap_name_default("fc_layer")
    def convert_fc_layer(self, params, params_names=None, name=None):
        """Convert a fully-connected layer's weight/bias parameters."""
        for i in range(len(params)):
            data = params[i]
            if len(params) == 2:
                suffix = "0" if i == 0 else "bias"
                file_name = "_%s.w%s" % (name, suffix) if not (
                    params_names and params_names[i]) else params_names[i]
            else:
                file_name = "_%s.w%s" % (name, str(i)) if not (
                    params_names and params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            if len(dims) < 2:
                # Paddle parameters are at least 2-D; a 1-D bias of length
                # n becomes shape (1, n).
                dims.insert(0, 1)
            param_conf.size = int(np.prod(dims))
            param_conf.dims.extend(dims)
            self.params[file_name] = (param_conf, data.flatten())
        return name

    @wrap_name_default("batch_norm")
    def convert_bn_layer(self, params, params_names=None, name=None):
        """Convert batch-norm parameters.

        Reorders the incoming params as (0, 2, 3, 1) — presumably mapping
        the TF variable order onto Paddle's expected order, with the last
        entry stored as the bias; confirm against the TF net definition.
        """
        params = [params[i] for i in (0, 2, 3, 1)]
        params_names = [params_names[i]
                        for i in (0, 2, 3, 1)] if params_names else params_names
        for i in range(len(params)):
            data = params[i]
            file_name = "_%s.w%s" % (name, str(i)) if i < 3 else "_%s.w%s" % (
                name, "bias")
            file_name = file_name if not (params_names and
                                          params_names[i]) else params_names[i]
            param_conf = ParameterConfig()
            param_conf.name = file_name
            dims = list(data.shape)
            # Batch-norm parameters must be 1-D vectors.
            assert len(dims) == 1
            dims.insert(0, 1)
            param_conf.size = int(np.prod(dims))
            param_conf.dims.extend(dims)
            self.params[file_name] = (param_conf, data.flatten())
        return name
| 170 | + |
| 171 | + |
| 172 | +if __name__ == "__main__": |
| 173 | + tf_net = "TF_ResNet" |
| 174 | + paddle_tar_name = "Paddle_ResNet50.tar.gz" |
| 175 | + |
| 176 | + converter = TFModelConverter(tf_net=tf_net, paddle_tar_name=paddle_tar_name) |
| 177 | + converter.convert() |
0 commit comments