The tf.data module contains a collection of classes that let you easily load data, manipulate it, and feed it into your model.
As an example, take the training input function from iris_data.py in the Iris example:
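import tensorflow as tf  # TensorFlow 1.x API

def train_input_fn(features, labels, batch_size):
    # Slice the inputs into a Dataset of (features dict, label) pairs.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat indefinitely, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()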
Slice
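tf.data.Dataset.from_tensor_slices takes an array (or a nested structure of arrays, such as a tuple or dict) and returns a tf.data.Dataset whose elements are slices of that array along its first dimension.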
For example, in the MNIST example the training images form an array of shape (60000, 28, 28):
import tensorflow as tf

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train  # mnist_x: uint8 array of shape (60000, 28, 28)
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
Printing the dataset shows the shape and type of each element (record) in it, here 60000 elements of shape (28, 28):
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
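To make the slicing concrete without TensorFlow, the plain-Python sketch below splits a nested list along its first axis, the way from_tensor_slices turns a (60000, 28, 28) array into 60000 elements of shape (28, 28). It is only an illustrative model, not TensorFlow's implementation; `slice_first_dim` and `shape_of` are hypothetical helpers, and the tiny 3x2x2 array stands in for MNIST.

```python
def shape_of(x):
    """Return the shape of a regular nested list as a tuple (a scalar gives ())."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)

def slice_first_dim(array):
    """Yield array[0], array[1], ...: one element per index of the first axis."""
    for i in range(len(array)):
        yield array[i]

# A tiny stand-in for MNIST: 3 "images" of 2x2 pixels, shape (3, 2, 2).
images = [[[0, 1], [2, 3]],
          [[4, 5], [6, 7]],
          [[8, 9], [10, 11]]]

elements = list(slice_first_dim(images))
print(shape_of(images))       # (3, 2, 2)
print(len(elements))          # 3
print(shape_of(elements[0]))  # (2, 2), like (28, 28) for MNIST
```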
Suppose features is a dict that maps each feature name to an array of values:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
Often the labels are passed in as well, as a tuple together with the features:
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
  shapes: (
    {
      SepalLength: (), PetalWidth: (),
      PetalLength: (), SepalWidth: ()},
    ()),
  types: (
    {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64},
    tf.int64)>
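With a (dict, labels) tuple, every component is sliced along its first dimension in lockstep, so each element is a pair of (dict of scalars, scalar label), which is why every shape above is (). TensorFlow requires all components to have the same size in the first dimension. The plain-Python model below illustrates this; `slice_features_and_labels` is a hypothetical helper and the sample values are made up for illustration.

```python
def slice_features_and_labels(features, labels):
    """Model of from_tensor_slices((dict(features), labels)) on plain lists."""
    n = len(labels)
    # Every feature column must have as many rows as there are labels.
    for name, column in features.items():
        if len(column) != n:
            raise ValueError(f"column {name!r} has {len(column)} rows, expected {n}")
    # Element i pairs row i of every column with label i.
    for i in range(n):
        yield {name: column[i] for name, column in features.items()}, labels[i]

features = {
    "SepalLength": [6.4, 5.0, 4.9],
    "SepalWidth": [2.8, 2.3, 2.5],
    "PetalLength": [5.6, 3.3, 4.5],
    "PetalWidth": [2.2, 1.0, 1.7],
}
labels = [2, 1, 2]

for record in slice_features_and_labels(features, labels):
    print(record)
```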
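Manipulation
The dataset is then transformed: shuffle(1000) shuffles the elements using a buffer of 1000 elements, repeat() with no argument repeats the dataset indefinitely, and batch(batch_size) combines consecutive elements into batches of batch_size: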
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
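The three transformations can be modeled in plain Python to see what each one does to the stream of elements. This is a simplified sketch, not TensorFlow's implementation: `shuffle`, `repeat` and `batch` here are hypothetical generator functions, and `repeat` takes a function that builds each epoch so that every pass gets a fresh shuffle (mimicking the default reshuffle_each_iteration=True of Dataset.shuffle).

```python
import itertools
import random

def shuffle(elements, buffer_size, rng):
    """Buffered shuffle: hold up to buffer_size elements, emit a random one each step."""
    it = iter(elements)
    buffer = list(itertools.islice(it, buffer_size))
    while buffer:
        i = rng.randrange(len(buffer))
        try:
            # Emit buffer[i] and refill its slot with the next input element.
            item, buffer[i] = buffer[i], next(it)
        except StopIteration:
            # Input exhausted: emit buffer[i] and shrink the buffer.
            item = buffer.pop(i)
        yield item

def repeat(make_epoch, count=None):
    """Chain count epochs (forever if count is None); make_epoch builds each pass."""
    epochs = itertools.count() if count is None else range(count)
    for epoch in epochs:
        yield from make_epoch(epoch)

def batch(elements, batch_size):
    """Group consecutive elements into lists; the last one may be shorter."""
    it = iter(elements)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk

# Model of dataset.shuffle(1000).repeat().batch(4) on a 10-element dataset.
rng = random.Random(0)
data = list(range(10))
pipeline = batch(repeat(lambda epoch: shuffle(data, 1000, rng)), 4)
first_five = list(itertools.islice(pipeline, 5))
print(first_five)
```

Because repeat() comes before batch(), the third batch in this run straddles the boundary between the first and second epochs; with a buffer at least as large as the dataset, each epoch is a full random permutation.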
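Return
make_one_shot_iterator() creates an iterator over the dataset that needs no explicit initialization, and get_next() returns the (features, labels) tensors for the next batch: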
features_result, labels_result = dataset.make_one_shot_iterator().get_next()
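Batching keeps the nested structure of each element: for dict features, get_next() yields a features dict in which every key maps to a vector of batch_size values, plus a labels vector. The hypothetical `collate` helper below shows that restructuring in plain Python on a list of (features, label) records; it models the result, not how TensorFlow computes it.

```python
def collate(records):
    """Turn [(features_dict, label), ...] into (dict of lists, list of labels)."""
    features = {}
    labels = []
    for feature_dict, label in records:
        # Append each feature value to the list for its column.
        for name, value in feature_dict.items():
            features.setdefault(name, []).append(value)
        labels.append(label)
    return features, labels

records = [
    ({"SepalLength": 6.4, "PetalWidth": 2.2}, 2),
    ({"SepalLength": 5.0, "PetalWidth": 1.0}, 1),
]
features_result, labels_result = collate(records)
print(features_result)  # {'SepalLength': [6.4, 5.0], 'PetalWidth': [2.2, 1.0]}
print(labels_result)    # [2, 1]
```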
Reading a CSV file
Read the text file line by line, skipping the first (header) line:
ds = tf.data.TextLineDataset(train_path).skip(1)
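TextLineDataset produces one string element per line of the file, with the line terminator removed, and skip(1) drops the first element. A plain-Python equivalent over an in-memory file is sketched below; the CSV contents and the `text_lines` helper are made up for illustration.

```python
import io
import itertools

def text_lines(file_obj):
    """Yield each line with its line terminator removed, like TextLineDataset."""
    for line in file_obj:
        yield line.rstrip("\r\n")

# Made-up CSV contents: one header line followed by two data rows.
csv_text = "header line\n6.4,2.8,5.6,2.2,2\n5.0,2.3,3.3,1.0,1\n"

# Equivalent of .skip(1): drop the first element of the stream.
rows = list(itertools.islice(text_lines(io.StringIO(csv_text)), 1, None))
print(rows)  # ['6.4,2.8,5.6,2.2,2', '5.0,2.3,3.3,1.0,1']
```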
Build a parser for a single CSV line:
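import tensorflow as tf  # TensorFlow 1.x API (tf.io.decode_csv in 2.x)

COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
# One default per column; [0.0] parses as float32, [0] as int32.
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def _parse_line(line):
    # Decode the line into a list of field tensors.
    fields = tf.decode_csv(line, FIELD_DEFAULTS)
    # Pack the fields into a dict keyed by column name.
    features = dict(zip(COLUMNS, fields))
    # Separate the label from the features.
    label = features.pop('label')
    return features, label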
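tf.decode_csv splits a line on commas and converts each field to the type of its entry in FIELD_DEFAULTS, substituting the default when a field is empty. The sketch below models that behavior in plain Python; `decode_csv_line` and `parse_line` are hypothetical stand-ins (they ignore CSV quoting, which TensorFlow handles). In the TensorFlow pipeline the real parser is applied to every line with ds = ds.map(_parse_line).

```python
COLUMNS = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def decode_csv_line(line, record_defaults):
    """Split a CSV line and cast each field to the type of its default."""
    fields = line.split(",")
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    values = []
    for text, (default,) in zip(fields, record_defaults):
        # An empty field falls back to the default; otherwise cast to its type.
        values.append(default if text == "" else type(default)(text))
    return values

def parse_line(line):
    """Plain-Python analogue of _parse_line: (features dict, label)."""
    fields = decode_csv_line(line, FIELD_DEFAULTS)
    features = dict(zip(COLUMNS, fields))
    label = features.pop('label')
    return features, label

print(parse_line("6.4,2.8,5.6,2.2,2"))
print(parse_line("5.0,,3.3,1.0,1"))  # empty SepalWidth -> 0.0
```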