Reading NumPy arrays

Load the training data into two NumPy arrays.

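import numpy as np
import tensorflow as tf
# np.load returns a dict-like object only for an archive such as the one np.savez writes.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]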

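Check that features and labels have the same number of rows, so that each row of features corresponds to the same row of labels.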

assert features.shape[0] == labels.shape[0]
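As a minimal sketch of the loading step, the snippet below writes two small arrays to an archive with np.savez and reads them back by key. The toy data, the temporary path, and the loaded_features and loaded_labels names are assumptions made for illustration; they are not part of the original example.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy data: 4 examples with 3 features each, one label per example.
features = np.arange(12, dtype=np.float32).reshape(4, 3)
labels = np.array([0, 1, 0, 1], dtype=np.int64)

# Write both arrays into one archive; the keyword names become the lookup keys.
path = os.path.join(tempfile.mkdtemp(), "training_data.npz")
np.savez(path, features=features, labels=labels)

# For an archive, np.load returns a dict-like NpzFile usable as a context manager.
with np.load(path) as data:
    loaded_features = data["features"]
    loaded_labels = data["labels"]

# Each row of the features must correspond to the same row of the labels.
assert loaded_features.shape[0] == loaded_labels.shape[0]
```

A plain .npy file written by np.save holds a single array, so indexing by key as above only works on an archive.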

Convert the NumPy arrays into tensors by building a Dataset from them.

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

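This approach works well for small datasets, but it wastes memory: the arrays are embedded in the graph as tf.constant() operations, so their contents are copied multiple times.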
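To judge whether a dataset is small enough for this approach, one rough check is to add up the raw size of the arrays. The sketch below does that with NumPy alone. The embedded_bytes helper and the example shapes are hypothetical, and the 2 GiB threshold is the protocol buffer serialization limit that a serialized graph is subject to. The result is a lower bound, since it ignores the extra copies made at run time.

```python
import numpy as np

# Protocol buffers cannot be serialized beyond 2 GiB, which caps the graph size.
GRAPHDEF_LIMIT_BYTES = 2 ** 31


def embedded_bytes(*arrays):
    """Hypothetical helper: raw payload in bytes if the arrays become graph constants."""
    return sum(np.asarray(a).nbytes for a in arrays)


# Assumed example shapes: 1000 rows of 128 float32 features, one int64 label per row.
features = np.zeros((1000, 128), dtype=np.float32)
labels = np.zeros((1000,), dtype=np.int64)

total = embedded_bytes(features, labels)
fits = total < GRAPHDEF_LIMIT_BYTES
print(total, fits)
```

When the total approaches the limit, the placeholder-based definition avoids embedding the arrays in the graph.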

Alternatively, define the Dataset in terms of placeholders and feed the NumPy arrays when the iterator is initialized.

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()

# Feed the NumPy arrays only when the iterator is initialized.
sess = tf.Session()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
