如何从tf.tensor中获取字符串值,其中dtype是字符串
我想使用tf.data.Dataset.list_files函数提供数据集。
但是因为该文件不是图像,所以我需要手动加载它。
问题是tf.data.Dataset.list_files作为tf.tensor传递变量,而我的python代码无法处理张量。
如何从tf.tensor获取字符串值。dtype是字符串。
train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')train_dataset = train_dataset.map(lambda x: load_audio_file(x))
def load_audio_file(file_path):
print("file_path: ", file_path)
# i want do something like string_path = convert_tensor_to_string(file_path)
文件路径为 Tensor("arg0:0", shape=(), dtype=string)
我使用tensorflow 1.13.1和eager模式。
提前致谢
回答:
您可以使用tf.py_func
包装load_audio_file()
。
import tensorflow as tftf.enable_eager_execution()
def load_audio_file(file_path):
# you should decode bytes type to string type
print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
return file_path
train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
print(one_element)
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
即使将替换tf.py_func
为tf.py_function
,上述解决方案也不适用于TF 2(已在2.2.0中测试)。
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
要使其在TF 2中工作,请进行以下更改:
- 删除
tf.enable_eager_execution()
(默认情况下,TF 2中启用了渴望,您可以通过tf.executing_eagerly()
返回进行验证True
) - 替换
tf.py_func
为tf.py_function
- 替换的所有功能的参考
file_path
用file_path.numpy()
以上是 如何从tf.tensor中获取字符串值,其中dtype是字符串 的全部内容, 来源链接: utcz.com/qa/420191.html