Reading NumPy arrays
Load the training data into NumPy arrays:
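import numpy as np
import tensorflow as tf
# Name-based indexing assumes the file is an .npz-style archive holding "features" and "labels".
with np.load("/var/data/training_data.npy") as data:
    features = data["features"]
    labels = data["labels"]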
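Check that features and labels have the same number of rows, so that each row of features corresponds to the same row of labels: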
assert features.shape[0] == labels.shape[0]
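As a minimal sketch of where such a file could come from, the snippet below writes two arrays with np.savez and reads them back by name. The array contents, the temporary directory and the file name training_data.npz are all hypothetical stand-ins, not the data used in this post.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in arrays: 5 examples with 3 features each.
features_in = np.random.rand(5, 3).astype(np.float32)
labels_in = np.array([0, 1, 0, 1, 1], dtype=np.int64)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "training_data.npz")  # hypothetical path
    # The keyword names become the keys used when loading.
    np.savez(path, features=features_in, labels=labels_in)

    # For an .npz archive, np.load returns an NpzFile, which works as a
    # context manager and is indexed by array name.
    with np.load(path) as data:
        features = data["features"]
        labels = data["labels"]

# Same row-count check as in the text above.
assert features.shape[0] == labels.shape[0]
```

A plain .npy file holds a single array, so np.load returns the array itself in that case and name-based indexing does not apply.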
Convert the NumPy arrays into tensors and wrap them in a Dataset:
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
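This works well for small datasets but wastes memory: the arrays are embedded in the TensorFlow graph as tf.constant() operations, so their contents are copied multiple times, and the graph can hit the 2 GB limit of the tf.GraphDef protocol buffer.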
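To get a feel for when embedding becomes a problem, the raw size of an array can be estimated from its shape and dtype and compared with the roughly 2 GB ceiling on a serialized tf.GraphDef. This is only a rough sketch: the helper constant_bytes, the threshold constant and the example shapes are hypothetical, and the real graph is somewhat larger than the raw array bytes.

```python
import numpy as np

# Approximate ceiling for a serialized tf.GraphDef (about 2 GB).
GRAPHDEF_LIMIT_BYTES = 2 * 1024**3


def constant_bytes(shape, dtype):
    """Raw bytes an array of this shape and dtype occupies (hypothetical helper)."""
    n_elements = int(np.prod(shape, dtype=np.int64))
    return n_elements * np.dtype(dtype).itemsize


# Hypothetical dataset: 1,000,000 examples with 600 float32 features each,
# plus one int64 label per example.
feature_bytes = constant_bytes((1_000_000, 600), np.float32)
label_bytes = constant_bytes((1_000_000,), np.int64)
total_bytes = feature_bytes + label_bytes

# True means the arrays alone would already overflow the graph.
exceeds_limit = total_bytes > GRAPHDEF_LIMIT_BYTES
print(total_bytes, exceeds_limit)
```

For arrays that are already in memory, features.nbytes + labels.nbytes gives the same figure directly.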
Alternatively, define the Dataset in terms of placeholders and feed the NumPy arrays when the iterator is initialized:
# TF 1.x graph-mode API; `sess` is assumed to be an existing tf.Session.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
# The arrays are fed once, when the iterator is initialized, instead of being stored in the graph.
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
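The placeholder approach can be sketched end to end as follows. This version assumes a TensorFlow 2.x installation and therefore goes through tf.compat.v1; on TensorFlow 1.x the same calls are available directly under tf. The small arrays, the tf1 alias and the batch size of 4 are assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # alias for the TF 1.x-style API (assumption: TF 2.x is installed)

# Hypothetical stand-in data; real arrays would come from np.load.
features = np.arange(12, dtype=np.float32).reshape(6, 2)
labels = np.arange(6, dtype=np.int64)

# Building inside an explicit tf.Graph keeps everything in graph mode,
# which placeholders require.
graph = tf.Graph()
with graph.as_default():
    features_placeholder = tf1.placeholder(features.dtype, features.shape)
    labels_placeholder = tf1.placeholder(labels.dtype, labels.shape)
    dataset = tf.data.Dataset.from_tensor_slices(
        (features_placeholder, labels_placeholder)
    ).batch(4)
    iterator = tf1.data.make_initializable_iterator(dataset)
    next_batch = iterator.get_next()

batches = []
with tf1.Session(graph=graph) as sess:
    # The arrays enter the pipeline here rather than being stored in the graph.
    sess.run(iterator.initializer,
             feed_dict={features_placeholder: features,
                        labels_placeholder: labels})
    # Pull batches until the dataset is exhausted.
    while True:
        try:
            batches.append(sess.run(next_batch))
        except tf.errors.OutOfRangeError:
            break
```

Each entry of batches is a (features, labels) pair of NumPy arrays. Running iterator.initializer again with different arrays of the same shape and dtype reuses the same graph on new data.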