TensorFlow:如何从SavedModel进行预测?
我已经导出了SavedModel
,现在我可以将其加载回并进行预测。经过培训,具有以下功能和标签:
F1 : FLOAT32F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32
所以说我要输入的值得20.9, 1.8, 0.9
到一个FLOAT32
预测。我该怎么做?我已经成功地加载了模型,但是我不确定如何访问它以进行预测调用。
with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
"/job/export/Servo/1503723455"
)
# How can I predict from here?
# I want to do something like prediction = model.predict([20.9, 1.8, 0.9])
该问题不是此处发布的问题的重复。这个问题集中于在SavedModel
任何模型类(不仅仅限于tf.estimator
)上进行推理的最小示例,以及指定输入和输出节点名称的语法。
回答:
加载图形后,它就可以在当前上下文中使用,您可以通过它馈入输入数据以获得预测。每个用例都有很大的不同,但是在代码中添加的内容如下所示:
with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
"/job/export/Servo/1503723455"
)
prediction = sess.run(
'prefix/predictions/Identity:0',
feed_dict={
'Placeholder:0': [20.9],
'Placeholder_1:0': [1.8],
'Placeholder_2:0': [0.9]
}
)
print(prediction)
在这里,您需要知道预测输入的名称。如果您没有给他们带来天真serving_fn
,则它们默认为Placeholder_n
,这n
是第n个功能。
的第一个字符串参数sess.run
是预测目标的名称。这将根据您的用例而有所不同。
以上是 TensorFlow:如何从SavedModel进行预测? 的全部内容, 来源链接: utcz.com/qa/405812.html