The tf.data module contains a collection of classes that make it easy to load data, manipulate it, and feed it into a model.

Take train_input_fn from the Iris example's iris_data.py as an illustration:

import tensorflow as tf

def train_input_fn(features, labels, batch_size):
    # Convert the inputs to a Dataset, one element per (features, label) row.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the (features, labels) tensors for the next batch.
    return dataset.make_one_shot_iterator().get_next()
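
This input function is typically handed to an Estimator through a lambda so the arguments can be bound; a minimal sketch, assuming a classifier Estimator plus train_x, train_y and batch_size have already been set up as in the premade-estimator example:

classifier.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size),
    steps=1000)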

Slice

tf.data.Dataset.from_tensor_slices takes an array and returns a tf.data.Dataset that slices the array along its first dimension, producing one element per row.

For example, in the MNIST example the training data has shape (60000, 28, 28):

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

Printing the dataset shows the shape and type of each element in it:

<TensorSliceDataset shapes: (28,28), types: tf.uint8>
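
Because slicing happens along the first dimension, this dataset holds 60000 elements, each a (28, 28) uint8 image. A minimal sketch of pulling one element out in TF 1.x graph mode (assuming tensorflow is imported as tf):

next_image = mnist_ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    image = sess.run(next_image)
    print(image.shape, image.dtype)  # (28, 28) uint8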

Suppose features is a dict:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)

Sometimes a label is added as well:

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

Manipulation

shuffle randomizes the order of the elements using a buffer of 1000 records, repeat restarts the dataset once it runs out of data so training can continue indefinitely, and batch stacks consecutive elements into batches of batch_size:

dataset = dataset.shuffle(1000).repeat().batch(batch_size)
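
Printing the dataset again after these transformations shows the extra batch dimension on every field; a rough sketch of what to expect (the exact output may differ by TensorFlow version):

print(dataset)
# Roughly:
# <BatchDataset
#     shapes: ({SepalLength: (?,), PetalWidth: (?,),
#               PetalLength: (?,), SepalWidth: (?,)}, (?,)),
#     types: ({SepalLength: tf.float64, PetalWidth: tf.float64,
#              PetalLength: tf.float64, SepalWidth: tf.float64}, tf.int64)>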

Return

make_one_shot_iterator creates an iterator over the dataset, and get_next returns a (features, labels) pair of tensors that evaluates to the next batch each time it is run:

features_result, labels_result = dataset.make_one_shot_iterator().get_next()
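
In graph mode these are tensors rather than values, so one way to inspect an actual batch is to evaluate them in a session; a minimal sketch, assuming the dataset above was built from the Iris feature dict:

with tf.Session() as sess:
    batch_features, batch_labels = sess.run([features_result, labels_result])
    print(batch_labels.shape)                    # (batch_size,)
    print(batch_features['SepalLength'].shape)   # (batch_size,)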

Reading a CSV File

Read the text file and skip the first (header) line:

ds = tf.data.TextLineDataset(train_path).skip(1)

Build a csv line parser:

# Column names and per-column default values; the defaults also fix each
# column's type (float for the four features, int for the label).
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def _parse_line(line):
    # Decode the line into its fields.
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the fields into a dict keyed by column name.
    features = dict(zip(COLUMNS, fields))

    # Separate the label from the features.
    label = features.pop('label')

    return features, label
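
The parser is applied to every line with Dataset.map, and the rest of the pipeline mirrors train_input_fn above; a sketch along the lines of csv_input_fn in iris_data.py:

def csv_input_fn(csv_path, batch_size):
    # One string element per line of the file, skipping the header.
    dataset = tf.data.TextLineDataset(csv_path).skip(1)

    # Parse each line into a (features dict, label) pair.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch as before.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    return dataset.make_one_shot_iterator().get_next()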
