(Original) A simple example of training on MNIST with TensorFlow eager execution
Published: 2019-06-27


Please credit the source when reposting:

Code:

Reference:

Overall workflow

When using eager execution in TensorFlow, the following lines are needed (if you omit the third line, you can still use the static graph):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

Once eager mode is enabled, TensorFlow feels about as convenient as PyTorch. With eager, tf.placeholder is no longer needed, which makes the library noticeably easier to use.
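For example, with eager enabled, operations run immediately and their results can be inspected directly, with no placeholder or session (a minimal sketch, using the imports and enable_eager_execution() call above):

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.matmul(x, x)   # executes immediately, no Session.run() needed
print(y.numpy())      # the result is already available as a NumPy array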

At the moment, tf.keras.layers and tf.layers appear to support eager execution, while slim does not.

The overall workflow is as follows:

initialize optimizer
for epoch in range(epochs):
    for imgs, targets in training_data:
        with tf.GradientTape() as tape:
            logits = model(imgs, training=True)
            loss_value = calc_loss(logits, targets)
        grads = tape.gradient(loss_value, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)
        update training_accuracy, total_loss
    test model
    save model
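A concrete version of this loop might look like the sketch below. The model is built as described in the next section; the names `dataset` and `epochs` are assumptions for illustration:

# A hedged sketch of the workflow above; assumes TF 1.x with eager enabled,
# a tf.keras model `model`, and a tf.data `dataset` yielding (imgs, labs) batches.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
step_counter = tf.train.get_or_create_global_step()
epochs = 10  # assumed hyperparameter

for epoch in range(epochs):
    for imgs, labs in dataset:
        with tf.GradientTape() as tape:
            logits = model(imgs, training=True)
            loss_value = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labs))
        # the tape records the forward pass; gradients are applied to the model variables
        grads = tape.gradient(loss_value, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables),
                                  global_step=step_counter)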

Creating the model

The model can be created in any of the three ways below.

1. The PyTorch-like approach

First define the layers to be used in __init__, then override the call function to build the network. The call function is invoked during the model's forward pass, as shown in the code below:

class simpleModel(tf.keras.Model):
    def __init__(self, num_classes):
        super(simpleModel, self).__init__()

        input_shape = [28, 28, 1]
        data_format = 'channels_last'
        self.reshape = tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],))

        self.conv1 = tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu')
        self.batch1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv2 = tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu')
        self.batch2 = tf.keras.layers.BatchNormalization()
        self.pool2 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv3 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
        self.batch3 = tf.keras.layers.BatchNormalization()
        self.pool3 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv4 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
        self.batch4 = tf.keras.layers.BatchNormalization()
        self.pool4 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.flat = tf.keras.layers.Flatten()
        self.fc5 = tf.keras.layers.Dense(1024, activation='relu')
        self.batch5 = tf.keras.layers.BatchNormalization()

        self.fc6 = tf.keras.layers.Dense(num_classes)
        self.batch6 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        x = self.reshape(inputs)

        x = self.conv1(x)
        x = self.batch1(x, training=training)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.batch2(x, training=training)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.batch3(x, training=training)
        x = self.pool3(x)

        x = self.conv4(x)
        x = self.batch4(x, training=training)
        x = self.pool4(x)

        x = self.flat(x)
        x = self.fc5(x)
        x = self.batch5(x, training=training)

        x = self.fc6(x)
        x = self.batch6(x, training=training)
        # x = tf.layers.dropout(x, rate=0.3, training=training)
        return x

    def get_acc(self, target):
        correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(target, 1))
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        return acc

    def get_loss(self):
        return self.loss

    def loss_fn(self, images, target, training):
        self.logits = self(images, training)  # invokes call(self, inputs, training=None)
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=target))
        return self.loss

    def grads_fn(self, images, target, training):  # do not return loss and acc if unnecessary
        with tfe.GradientTape() as tape:
            loss = self.loss_fn(images, target, training)
        return tape.gradient(loss, self.variables)
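A usage sketch for this class (the batch size and dummy data are illustrative assumptions): build the model, run one batch to create the variables, then compute gradients with grads_fn:

model = simpleModel(num_classes=10)

dummy_imgs = tf.zeros((8, 28 * 28))                        # batch of 8 flattened MNIST images
dummy_labs = tf.one_hot(tf.zeros(8, dtype=tf.int32), 10)   # one-hot targets

logits = model(dummy_imgs, training=False)    # forward pass; invokes call()
grads = model.grads_fn(dummy_imgs, dummy_labs, training=True)
print(logits.shape, len(grads))               # (8, 10) and one gradient per variable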

2. Using tf.keras.Sequential directly

As shown in the code below:

def create_model1():
    data_format = 'channels_last'
    input_shape = [28, 28, 1]
    l = tf.keras.layers
    max_pool = l.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)
    # The model consists of a sequential chain of layers, so tf.keras.Sequential
    # (a subclass of tf.keras.Model) makes for a compact description.
    return tf.keras.Sequential(
        [
            l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
            l.Conv2D(16, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Flatten(),
            l.Dense(1024, activation=tf.nn.relu),
            l.BatchNormalization(),

            # l.Dropout(0.4),
            l.Dense(10),
            l.BatchNormalization()
        ])

3. Using tf.keras.Sequential() and the add function

As shown in the code below:

def create_model2():
    data_format = 'channels_last'
    input_shape = [28, 28, 1]

    model = tf.keras.Sequential()

    model.add(tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],)))

    model.add(tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1024, activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())

    model.add(tf.keras.layers.Dense(10))
    model.add(tf.keras.layers.BatchNormalization())

    return model
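Since all three constructions describe the same architecture, a quick sanity check (a sketch, assuming eager is enabled) is to feed each model a dummy batch and compare output shapes:

models = [simpleModel(num_classes=10), create_model1(), create_model2()]
for m in models:
    out = m(tf.zeros((4, 28 * 28)), training=False)
    print(out.shape)   # expected: (4, 10) for all three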

Updating gradients with the dynamic graph

When updating the gradients, the following lines are needed:

1 with tf.GradientTape() as tape:
2     logits = model(imgs, training=True)
3     loss_value = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labs))
4 grads = tape.gradient(loss_value, model.variables)
5 optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)

Line 2 obtains the logits, line 3 computes the loss, line 4 computes the gradients, and line 5 applies the gradients to the model, updating its parameters.
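The pseudocode in the overall workflow also updates a running accuracy and loss. One way to do this in TF 1.x eager (a sketch using the tfe.metrics accumulators) is:

avg_loss = tfe.metrics.Mean('loss')          # running mean of the batch losses
accuracy = tfe.metrics.Accuracy('accuracy')  # running classification accuracy

# inside the batch loop, after computing logits and loss_value:
avg_loss(loss_value)
accuracy(tf.argmax(logits, axis=1), tf.argmax(labs, axis=1))

print('loss: %f, acc: %f' % (avg_loss.result(), accuracy.result()))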

Saving and loading the model

1. Using tfe.Saver

The code is as follows:

import os

def saveModelV1(model_dir, model, global_step, modelname='model1'):
    tfe.Saver(model.variables).save(os.path.join(model_dir, modelname), global_step=global_step)

def restoreModelV1(model_dir, model):
    dummy_input = tf.constant(tf.zeros((1, 28, 28, 1)))  # run the model once to initialize variables
    dummy_pred = model(dummy_input, training=False)

    saver = tfe.Saver(model.variables)  # restore the variables of the model
    saver.restore(tf.train.latest_checkpoint(model_dir))
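A usage sketch (the checkpoint directory is an assumption): save after training, then restore into a freshly built model:

saveModelV1('./ckpt_v1', model, global_step=step_counter)

new_model = simpleModel(num_classes=10)
restoreModelV1('./ckpt_v1', new_model)   # runs a dummy forward pass internally to create variables first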

2. Using tf.train.Checkpoint

The code is as follows:

step_counter = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, step_counter=step_counter)

def saveModelV2(model_dir, checkpoint, modelname='model2'):
    checkpoint_prefix = os.path.join(model_dir, modelname)
    checkpoint.save(checkpoint_prefix)

def restoreModelV2(model_dir, checkpoint):
    checkpoint.restore(tf.train.latest_checkpoint(model_dir))
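A usage sketch (the directory name is an assumption). Note that unlike tfe.Saver above, which saves only the model variables, tf.train.Checkpoint also tracks the optimizer state and the step counter:

saveModelV2('./ckpt_v2', checkpoint)      # writes numbered checkpoints under ./ckpt_v2
restoreModelV2('./ckpt_v2', checkpoint)   # restores model, optimizer and step_counter together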

Full code

The code does not strictly follow the steps of the overall workflow and is for reference only; see the code link above.

In the code, eagerFlag selects how eager is used: 0 means eager is not used (static graph), 1 means the V1 approach, and 2 means the V2 approach. When using the static graph, do not call tfe.enable_eager_execution(), otherwise an error is raised. See the code for details.
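The switch itself might look like the following hypothetical sketch (the actual script may differ):

# Hypothetical sketch of the eagerFlag switch; 0 = static graph,
# 1 = V1 (tfe.Saver), 2 = V2 (tf.train.Checkpoint).
eagerFlag = 2
if eagerFlag > 0:
    tfe.enable_eager_execution()   # must NOT be called in static-graph mode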
