在Tensorflow中使用预训练的inception_resnet_v2

我一直在尝试使用Google发布的经过预先训练的inception_resnet_v2模型。我正在使用他们的模型定义(https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py)和给定的检查点(http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar

.gz)将模型加载到tensorflow中,如下所示[下载提取检查点文件并下载示例图像dog.jpg和panda.jpg来测试此代码]-

import tensorflow as tf

slim = tf.contrib.slim

from PIL import Image

from inception_resnet_v2 import *

import numpy as np

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'

sample_images = ['dog.jpg', 'panda.jpg']

#Load the model

sess = tf.Session()

arg_scope = inception_resnet_v2_arg_scope()

with slim.arg_scope(arg_scope):

logits, end_points = inception_resnet_v2(input_tensor, is_training=False)

saver = tf.train.Saver()

saver.restore(sess, checkpoint_file)

for image in sample_images:

im = Image.open(image).resize((299,299))

im = np.array(im)

im = im.reshape(-1,299,299,3)

predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})

print (np.max(predict_values), np.max(logit_values))

print (np.argmax(predict_values), np.argmax(logit_values))

但是,此模型代码的结果并未给出预期的结果(与输入图像无关,将预测类别918)。有人可以帮助我了解我要去哪里哪里吗?

回答:

盗梦空间网络期望输入图像具有从[-1,1]缩放的颜色通道。如这里所见。

您可以使用现有的预处理,也可以在示例中自行缩放图像:将图像im = 2*(im/255.0)-1.0馈送到网络之前。

如果不缩放比例,则输入[0-255]会比网络预期的要大得多,并且所有偏差都会非常强烈地预测类别918(漫画书)。

以上是 在Tensorflow中使用预训练的inception_resnet_v2 的全部内容, 来源链接: utcz.com/qa/417232.html

回到顶部