Transfer learning with the tf.estimator.Estimator framework

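I'm trying to do transfer learning with an Inception-resnet v2 model pretrained on imagenet, using my own dataset and classes. My original codebase was a modification of a tf.slim sample that I can no longer find, and now I'm trying to rewrite the same code using the tf.estimator.* framework.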

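However, I'm running into trouble loading only some of the weights from the pretrained checkpoint while initializing the remaining layers with their default initializers.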

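Researching the problem, I found this GitHub issue and this question, which both mention the need to use tf.train.init_from_checkpoint in my model_fn. I tried that, but given the lack of examples in both, I guess I got something wrong.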

This is my minimal example:

import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
import numpy as np
import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299


def input_fn(mode, num_classes, batch_size=1):
    # some code that loads images, reshapes them to 299x299x3 and batches them
    return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)


def model_fn(images, labels, num_classes, mode):
    with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2.inception_resnet_v2(images,
                                                                     num_classes,
                                                                     is_training=(mode==tf.estimator.ModeKeys.TRAIN))
    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
    scopes = { os.path.dirname(v.name) for v in variables_to_restore }
    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                  {s+'/':s+'/' for s in scopes})

    tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
        train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op)


def main(unused_argv):
    # Create the Estimator
    classifier = tf.estimator.Estimator(
        model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
        model_dir='model/MCVE')

    # Train the model
    classifier.train(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
        steps=1000)

    # Evaluate the model and print results
    eval_results = classifier.evaluate(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
    print()
    print('Evaluation results:\n %s' % eval_results)


if __name__ == '__main__':
    tf.app.run(main=main, argv=[sys.argv[0]])

where inception_resnet_v2 is the model implementation in Tensorflow's models repository.

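If I run this script, I get a bunch of info logs from init_from_checkpoint, but then, at session creation time, it seems to try to load the Logits weights from the checkpoint and fails because of incompatible shapes. This is the full traceback: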

Traceback (most recent call last):
  File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
    runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]])
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
    steps=1000)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
    return self._sess_creator.create_session()
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
    init_fn=self._scaffold.init_fn)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001]
     [[Node: Assign_1145 = Assign[T=DT_FLOAT, _class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]
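The `lhs shape= [900] rhs shape= [1001]` part says the graph's `Logits/biases` has 900 entries (NUM_CLASSES) while the ImageNet checkpoint stores 1001. A quick way to spot such conflicts before training is to compare the graph's shapes with the checkpoint's. The sketch below is plain Python: the two dictionaries are hypothetical stand-ins built from the values in the traceback and the checkpoint listing further down, and the helper name `find_shape_mismatches` is made up for illustration. In TensorFlow 1.x the checkpoint side could presumably come from `tf.train.list_variables`, which returns (name, shape) pairs, but that wiring is an assumption and is not shown.

```python
# Hypothetical pre-flight check: compare the shapes the graph expects with
# the shapes stored in the checkpoint before restoring anything.

def find_shape_mismatches(graph_shapes, checkpoint_shapes):
    """Return (name, graph_shape, ckpt_shape) for names present in both with different shapes."""
    mismatches = []
    for name, shape in sorted(graph_shapes.items()):
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and list(ckpt_shape) != list(shape):
            mismatches.append((name, list(shape), list(ckpt_shape)))
    return mismatches


# Illustrative values: 900 classes in the graph, 1001 in the ImageNet checkpoint.
graph_shapes = {
    "InceptionResnetV2/Logits/Logits/biases": [900],
    "InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta": [32],
}
checkpoint_shapes = {
    "InceptionResnetV2/Logits/Logits/biases": [1001],
    "InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta": [32],
}

for name, ours, theirs in find_shape_mismatches(graph_shapes, checkpoint_shapes):
    print(f"{name}: graph {ours} vs checkpoint {theirs}")
```

Any name this reports is a variable that must stay out of the assignment map passed to init_from_checkpoint.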

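What am I doing wrong when using init_from_checkpoint? How exactly are we supposed to "use" it in our model_fn? And why does the estimator try to load the Logits weights from the checkpoint when I'm explicitly telling it not to?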

Update:

Following the suggestions in the comments, I tried other ways of calling tf.train.init_from_checkpoint.

Using {v.name: v.name}

If, as suggested in the comments, I replace the call with {v.name:v.name for v in variables_to_restore}, I get this error:

ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.

Using {v.name: v}

If I instead try a name:variable mapping, I get the following error:

ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in inception_resnet_v2_2016_08_30.ckpt checkpoint

{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256],

'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...

The error goes on to list what I think are all the variable names in the checkpoint (or could they be scopes?).

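Update (2)

After inspecting the latest error above, I see that InceptionResnetV2/Conv2d_2a_3x3/weights is in the list of checkpoint variables. The problem is the :0 at the end! I'll now verify whether this actually solves the problem and post an answer if it does.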

Answer:

Thanks to @KathyWu's comment, I got on the right track and found the problem.

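Indeed, the way I was computing scopes would include the InceptionResnetV2/ scope, which triggers loading every variable "under" that scope (i.e., all the variables in the network, including the excluded Logits). Replacing it with the correct dictionary, however, was not trivial.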
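To see why one broad scope is enough to defeat the exclusion, here is a simplified pure-Python model of scope-style matching, where a key like 'scope/' selects every variable whose name starts with that prefix. The variable names are a hypothetical subset of the Inception-ResNet-v2 layout, and `variables_covered_by_scopes` only imitates the prefix rule; it is not TensorFlow's implementation.

```python
# Simplified model (not TensorFlow's code) of which graph variables a set of
# scope-style assignment-map keys ('scope/': 'scope/') would restore.

# Hypothetical variable names modeled on the Inception-ResNet-v2 layout.
GRAPH_VARIABLES = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights",
    "InceptionResnetV2/Conv2d_2a_3x3/weights",
    "InceptionResnetV2/Logits/Logits/weights",
    "InceptionResnetV2/Logits/Logits/biases",
]


def variables_covered_by_scopes(scopes, graph_variables):
    """Return every graph variable living under at least one of the given scopes."""
    covered = set()
    for scope in scopes:
        prefix = scope.rstrip("/") + "/"
        covered.update(name for name in graph_variables if name.startswith(prefix))
    return sorted(covered)


# A narrow scope only touches its own layer...
narrow = variables_covered_by_scopes(["InceptionResnetV2/Conv2d_1a_3x3/"], GRAPH_VARIABLES)

# ...but once the top-level scope is in the set, the excluded Logits come along too.
broad = variables_covered_by_scopes(
    ["InceptionResnetV2/Conv2d_1a_3x3/", "InceptionResnetV2/"], GRAPH_VARIABLES)

print(narrow)
print(broad)
```

This is why filtering variables with `exclude` is not enough when the filtered list is then collapsed back into scopes: a scope key restores whole subtrees, regardless of what was excluded.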

Of the possible scope patterns init_from_checkpoint accepts, the one I had to use was 'scope_variable_name': variable, but without using the actual variable.name attribute.

variable.name looks like 'some_scope/variable_name:0'. The :0 is not part of the variable names stored in the checkpoint, so using scopes = {v.name:v.name for v in variables_to_restore} raises a "variable not found" error.

The trick to make it work was stripping the tensor index from the name:

tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                              {v.name.split(':')[0]: v for v in variables_to_restore})
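The effect of the `split(':')[0]` trick can be checked without TensorFlow. In this sketch `FakeVariable` is a hypothetical stand-in for `tf.Variable` (only its `.name` matters here), `CHECKPOINT_NAMES` is an illustrative subset of the names listed in the error above, and `build_assignment_map` is a made-up helper that reproduces the dictionary comprehension from the fix.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the .name attribute is used here.
FakeVariable = namedtuple("FakeVariable", ["name"])

# Checkpoint entries are stored without the ':0' output index (illustrative subset).
CHECKPOINT_NAMES = {
    "InceptionResnetV2/Conv2d_2a_3x3/weights",
    "InceptionResnetV2/Conv2d_2a_3x3/BatchNorm/beta",
}

variables_to_restore = [
    FakeVariable("InceptionResnetV2/Conv2d_2a_3x3/weights:0"),
    FakeVariable("InceptionResnetV2/Conv2d_2a_3x3/BatchNorm/beta:0"),
]


def build_assignment_map(variables):
    """Map each checkpoint tensor name (':0' stripped) to its variable object."""
    return {v.name.split(":")[0]: v for v in variables}


assignment_map = build_assignment_map(variables_to_restore)

# Raw v.name keys miss the checkpoint; stripped keys all match.
missing_raw = [v.name for v in variables_to_restore if v.name not in CHECKPOINT_NAMES]
missing_stripped = [k for k in assignment_map if k not in CHECKPOINT_NAMES]
print(missing_raw)
print(missing_stripped)
```

In the real model_fn the values are the tf.Variable objects returned by slim.get_variables_to_restore. In TF 1.x, `v.op.name` should also give the name without the `:0` suffix, though the answer above sticks with `split(':')`.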

Source: utcz.com/qa/263030.html
