问题描述
说,我们输入了x
和标签y
:
Say, we have input x
and label y
:
iterator = tf.data.Iterator.from_structure((x_type, y_type), (x_shape, y_shape))
tf_x, tf_y = iterator.get_next()
现在我使用 generate 函数来创建数据集:
Now I use generate function to create dataset:
def gen():
for ....: yield (x, y)
ds = tf.data.Dataset.from_generator(gen, (x_type, y_type), (x_shape, y_shape))
在我的图表中,我使用 tf_x
和 tf_y
进行训练,这很好.但现在我想做参考,我没有标签 y
.我提出的一种解决方法是伪造一个 y(如 tf.zeros(y_shape)),然后使用占位符来初始化迭代器.
In my graph, I use tf_x
and tf_y
to do training, that is fine. But now I want to do referring, where I don't have label y
. One workaround I made is to fake a y (like tf.zeros(y_shape)), then I use a placeholder to init the iterator.
x_placeholder = tf.placeholder(...)
y_placeholder = tf.placeholder(...)
ds = tf.data.Dataset.from_tensors((x_placeholder, y_placeholder))
ds_init_op = iterator.make_initializer(ds)
sess.run(ds_init_op, feed_dict={x_placeholder=x, y_placeholder=fake(y))})
我的问题是,有没有更清洁的方法来做到这一点?在推断期间没有伪造 y
?
My question is, is there a cleaner way to do that? without fake a y
during inferring time?
更新:
我实验了一下,貌似少了一个数据集操作unzip
:
I experiment a little bit, looks like there is one dataset operation unzip
missing:
import numpy as np
import tensorflow as tf
x_type = tf.float32
y_type = tf.float32
x_shape = tf.TensorShape([None, 128])
y_shape = tf.TensorShape([None, 10])
x_shape_nobatch = tf.TensorShape([128])
y_shape_nobatch = tf.TensorShape([10])
iterator_x = tf.data.Iterator.from_structure((x_type,), (x_shape,))
iterator_y = tf.data.Iterator.from_structure((y_type,), (y_shape,))
def gen1():
for i in range(100):
yield np.random.randn(128)
ds1 = tf.data.Dataset.from_generator(gen1, (x_type,), (x_shape_nobatch,))
ds1 = ds1.batch(5)
ds1_init_op = iterator_x.make_initializer(ds1)
def gen2():
for i in range(80):
yield np.random.randn(128), np.random.randn(10)
ds2 = tf.data.Dataset.from_generator(gen2, (x_type, y_type), (x_shape_nobatch, y_shape_nobatch))
ds2 = ds2.batch(10)
# my ds2 has two tensors in one element, now the problem is
# how can I unzip this dataset so that I can apply them to iterator_x and iterator_y?
# such as:
ds2_x, ds2_y = tf.data.Dataset.unzip(ds2) #?? missing this unzip operation!
ds2_x_init_op = iterator_x.make_initializer(ds2_x)
ds2_y_init_op = iterator_y.make_initializer(ds2_y)
tf_x = iterator_x.get_next()
tf_y = iterator_y.get_next()
推荐答案
数据集 API 的目的是避免将值直接提供给会话(因为这会导致数据首先流向客户端,然后流向设备).
The purpose of datasets API is to avoid feeding the values directly to session (because that causes the data to flow first to the client, then to a device).
我见过的所有使用数据集 API 的示例也使用 estimator API,您可以在其中可以为训练和推理提供不同的输入函数.
All examples I've seen that use datasets API also use estimator API, where you can provide different input functions for training and inference.
def train_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
data = input_data.read_data_sets(data_dir, one_hot=True).train
return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
def infer_dataset(data_dir):
"""Returns a tf.data.Dataset yielding images for inference."""
data = input_data.read_data_sets(data_dir, one_hot=True).test
return tf.data.Dataset.from_tensors((data.images,))
...
def train_input_fn():
dataset = train_dataset(FLAGS.data_dir)
dataset = dataset.shuffle(buffer_size=50000).batch(1024).repeat(10)
(images, labels) = dataset.make_one_shot_iterator().get_next()
return (images, labels)
mnist_classifier.train(input_fn=train_input_fn)
...
def infer_input_fn():
return infer_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
mnist_classifier.predict(input_fn=infer_input_fn)
这篇关于如何在训练和推理中使用 tf.Dataset 设计?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持跟版网!