如何解决这个问题呢?TensorFlow 下无法运行

图片描述

import tensorflow as tf

import numpy as np

import os

import matplotlib.pyplot as plt

# Restrict CUDA to the device with index 1 (standard CUDA_VISIBLE_DEVICES
# semantics). NOTE(review): on a single-GPU machine this hides the only GPU
# (index 0) -- confirm that a second GPU is really intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# In[3]:

# Folder that holds the apple images; it is handed to get_files() further down.
train_dir = "G:/苹果测试数据/"

def get_files(file_dir):
    """Collect image paths from *file_dir* and derive a class label per file.

    The label comes from the part of the file name before the first ".":
    "good" -> 0, "medium" -> 1, anything else -> 2 (treated as "bad").
    Sub-directories are skipped because they cannot be read as images.

    Args:
        file_dir: Directory to scan. A trailing path separator is optional;
            paths are built with ``os.path.join``.

    Returns:
        A tuple ``(image_list, label_list)`` of equal length, shuffled in
        unison with NumPy's global random state. ``image_list`` holds the file
        paths as strings and ``label_list`` the matching labels as floats
        (0.0, 1.0 or 2.0).
    """
    label_by_prefix = {"good": 0, "medium": 1}
    bad_label = 2

    samples = []
    counts = [0, 0, 0]  # indexed by label: good, medium, bad
    for file in os.listdir(file_dir):
        path = os.path.join(file_dir, file)
        if not os.path.isfile(path):
            # A directory entry would only fail later when decoded as a JPEG.
            continue
        prefix = file.split(sep=".")[0]
        label = label_by_prefix.get(prefix, bad_label)
        samples.append((path, label))
        counts[label] += 1

    print("There are %d good apples\nThere are %d medium apples\nThere are %d bad apples" % (counts[0], counts[1], counts[2]))

    # Shuffle paths and labels together through one index permutation; this
    # avoids packing both into a single string-typed array and parsing the
    # labels back out of it.
    order = np.random.permutation(len(samples))
    image_list = [samples[k][0] for k in order]
    label_list = [float(samples[k][1]) for k in order]

    return image_list, label_list

def get_batch(image, label, image_W, image_H, batch_size, capacity):
    """Build a queue-based input pipeline that yields image/label batches.

    The pipeline relies on the TensorFlow 1.x queue API
    (``tf.train.slice_input_producer`` / ``tf.train.batch``), so queue runners
    must be started in the session before the returned tensors are evaluated.

    Args:
        image: Sequence of image file paths (cast to ``tf.string``).
        label: Sequence of labels, one per path (cast to ``tf.int32``).
        image_W: Target image width in pixels.
        image_H: Target image height in pixels.
        batch_size: Number of examples per batch.
        capacity: Maximum number of elements held in the batching queue.

    Returns:
        A tuple ``(image_batch, label_batch)``: the standardized images
        batched as ``[batch_size, image_H, image_W, 3]`` and the labels
        reshaped to ``[batch_size]``.
    """
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # Produces one (path, label) pair at a time from the two lists.
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]

    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # Data augmentation could be added here.

    # The API expects (image, target_height, target_width); the height must
    # come first, otherwise non-square targets end up transposed.
    image = tf.image.resize_image_with_crop_or_pad(image, image_H, image_W)
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=capacity,
    )
    label_batch = tf.reshape(label_batch, [batch_size])

    return image_batch, label_batch

# Smoke test for the input pipeline: pull a couple of batches and display
# every image together with its label.
BATCH_SIZE = 2
CAPACITY = 256
IMG_W = 208
IMG_H = 208

train_dir = "G:/苹果测试数据/"

image_list, label_list = get_files(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        # Only two batches are fetched; this is a visual check, not training.
        while not coord.should_stop() and i < 2:
            img, label = sess.run([image_batch, label_batch])

            for j in np.arange(BATCH_SIZE):
                print("label:%d" % label[j])

                # get_batch standardizes each image (zero mean, unit variance),
                # so values fall outside [0, 1]; rescale for display because
                # imshow would otherwise clip float RGB data.
                frame = img[j, :, :, :]
                low = frame.min()
                span = frame.max() - low
                shown = (frame - low) / span if span > 0 else np.zeros_like(frame)

                plt.imshow(shown)
                plt.show()

            i += 1
    except tf.errors.OutOfRangeError:
        print("done!")
    finally:
        # Stop and join the reader threads on every exit path so they do not
        # outlive the session when an unexpected error is raised.
        coord.request_stop()
        coord.join(threads)

图片描述

以上是 如何解决这个问题呢,tensorflow下无法运行 的全部内容, 来源链接: utcz.com/a/162211.html

回到顶部