如何从tf.tensor中获取字符串值,其中dtype是字符串

我想使用tf.data.Dataset.list_files函数提供数据集。

但是因为该文件不是图像,所以我需要手动加载它。

问题是tf.data.Dataset.list_files作为tf.tensor传递变量,而我的python代码无法处理张量。

如何从tf.tensor获取字符串值。dtype是字符串。

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')

train_dataset = train_dataset.map(lambda x: load_audio_file(x))

def load_audio_file(file_path):

print("file_path: ", file_path)

# i want do something like string_path = convert_tensor_to_string(file_path)

文件路径为 Tensor("arg0:0", shape=(), dtype=string)

我使用tensorflow 1.13.1和eager模式。

提前致谢

回答:

您可以使用tf.py_func包装load_audio_file()

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):

# you should decode bytes type to string type

print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))

return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')

train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:

print(one_element)

file_path: clean_4s_val/1.wav <class 'str'>

(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)

file_path: clean_4s_val/3.wav <class 'str'>

(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)

file_path: clean_4s_val/2.wav <class 'str'>

(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)

即使将替换tf.py_functf.py_function,上述解决方案也不适用于TF 2(已在2.2.0中测试)。

InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'

要使其在TF 2中工作,请进行以下更改:

  • 删除tf.enable_eager_execution()(默认情况下,TF 2中启用了渴望,您可以通过tf.executing_eagerly()返回进行验证True
  • 替换tf.py_functf.py_function
  • 替换的所有功能的参考file_pathfile_path.numpy()

以上是 如何从tf.tensor中获取字符串值,其中dtype是字符串 的全部内容, 来源链接: utcz.com/qa/420191.html

回到顶部