使用浪潮软件集团技术研究中心的IDSW工作台,运行MNIST任务

配置TFoS

IDSW的notebook启动时,启动Spark集群,在退出时关闭SparkContext

引入依赖

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import argparse
import subprocess
from tensorflowonspark import TFCluster

配置基本参数

# Argument definitions for the MNIST TensorFlowOnSpark job.
# parse_args() is called twice below with explicit argv lists
# (once for training, once for inference).
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("--batch_size", help="number of examples per batch", type=int, default=100)
parser.add_argument("--mode", help="train|inference", default="train")
# FIX: the original declared --rdma with `default=False` and no type/action, so
# any value supplied on the command line arrived as a non-empty string -- truthy
# even for "--rdma false". A proper boolean flag keeps the same default (False)
# and no caller in this file passes --rdma, so existing behavior is unchanged.
parser.add_argument("--rdma", help="use rdma connection", action="store_true")
# Number of Spark executors requested for the TFoS cluster (passed to
# TFCluster.run below).
num_executors = 3

训练任务

配置训练参数

# Training configuration: point the job at the HDFS training data and parse a
# fixed argv (mode=train, 3000 max steps, 1 epoch; other options keep defaults).
hdfs_mnist_root = "hdfs://10.110.18.18:8020/user/root/mnist"
train_images_files = hdfs_mnist_root + "/csv/train/images"
train_labels_files = hdfs_mnist_root + "/csv/train/labels"

train_argv = [
    "--mode", "train",
    "--steps", "3000",
    "--epochs", "1",
    "--images", train_images_files,
    "--labels", train_labels_files,
]
args = parser.parse_args(train_argv)

加载python文件,并引用

# Ship the distributed-training module to every executor, then import it on the
# driver so its map_fun can be handed to TFCluster.run.
mnist_dist_script = "hdfs://10.110.18.18:8020/user/root/mnist/mnist_dist.py"
sc.addPyFile(mnist_dist_script)
import mnist_dist

启动TFoS

# Start the TFoS cluster in Spark input mode. Positional arguments follow the
# TFCluster.run signature: (sc, map_fun, tf_args, num_executors, num_ps,
# tensorboard, input_mode) -- here 1 parameter server and TensorBoard enabled.
num_ps = 1
with_tensorboard = True
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, num_executors,
                        num_ps, with_tensorboard, TFCluster.InputMode.SPARK)

训练并导出模型

# Train the model: each image line is a CSV row of integer pixel values and
# each label line a CSV row of floats; zip them into (image, label) pairs and
# feed the cluster for args.epochs epochs.
def _csv_to_ints(line):
    # One CSV row -> list of ints (pixel values).
    return [int(field) for field in line.split(',')]

def _csv_to_floats(line):
    # One CSV row -> list of floats (label encoding).
    return [float(field) for field in line.split(',')]

images = sc.textFile(args.images).map(_csv_to_ints)
labels = sc.textFile(args.labels).map(_csv_to_floats)
dataRDD = images.zip(labels)
cluster.train(dataRDD, args.epochs)

关闭TFoS

# Stop the TFoS cluster and free its executors once training has finished.
cluster.shutdown()

预测过程

配置预测参数

# Inference configuration: re-parse the arguments against the HDFS test set.
# Mode switches to "inference"; every other option falls back to its default.
test_images_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/test/images"
test_labels_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/test/labels"

inference_argv = [
    "--mode", "inference",
    "--images", test_images_files,
    "--labels", test_labels_files,
]
args = parser.parse_args(inference_argv)

启动TFoS

# Bring the TFoS cluster back up for inference, again in Spark input mode.
cluster = TFCluster.run(
    sc,
    mnist_dist.map_fun,
    args,
    num_executors,
    1,      # one parameter server, as for training
    False,  # TensorBoard disabled for the inference run
    TFCluster.InputMode.SPARK,
)

预测

# Rebuild the (image, label) RDD from the test set -- same CSV parsing as in
# training -- run distributed inference, and pull the first 20 prediction
# records back to the driver (displayed by the notebook).
images = sc.textFile(args.images).map(lambda row: list(map(int, row.split(','))))
labels = sc.textFile(args.labels).map(lambda row: list(map(float, row.split(','))))
dataRDD = images.zip(labels)
prediction_results = cluster.inference(dataRDD)
prediction_results.take(20)

关闭TFoS

# Tear the cluster down again after inference completes.
cluster.shutdown()

results matching ""

    No results matching ""