使用浪潮软件集团技术研究中心的IDSW工作台,运行MNIST任务
配置TFoS(TensorFlowOnSpark)
IDSW的notebook启动时会启动Spark集群,并在退出时关闭SparkContext
引入依赖
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import argparse
import subprocess
from tensorflowonspark import TFCluster
配置基本参数
# Command-line options shared by the training and inference runs below.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("--batch_size", help="number of examples per batch", type=int, default=100)
parser.add_argument("--mode", help="train|inference", default="train")
# BUG FIX: the original used default=False with no type/action, so any value
# supplied on the command line (even "false") arrived as a non-empty, truthy
# string. store_true makes --rdma a proper boolean flag while keeping the
# False default, so existing callers (which never pass --rdma) are unaffected.
parser.add_argument("--rdma", help="use rdma connection", action="store_true")

# Number of Spark executors; must match the size passed to TFCluster.run.
num_executors = 3
训练任务
配置训练参数
# Training configuration: HDFS locations of the CSV-formatted MNIST training
# split, plus the argument list handed to the shared parser.
train_images_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/train/images"
train_labels_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/train/labels"

# Parse arguments for training: at most 3000 steps over a single epoch.
_train_argv = [
    '--mode', 'train',
    '--steps', '3000',
    '--epochs', '1',
    '--images', train_images_files,
    '--labels', train_labels_files,
]
args = parser.parse_args(_train_argv)
加载python文件,并引用
# Ship the distributed-training module to the Spark cluster from HDFS, then
# import it on the driver. addPyFile must run BEFORE the import so the module
# is on the Python path (for the driver and the executors alike).
sc.addPyFile("hdfs://10.110.18.18:8020/user/root/mnist/mnist_dist.py")
import mnist_dist
启动TFoS
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, num_executors, 1, True, TFCluster.InputMode.SPARK)
训练并导出模型
# train model
# Each image line is a comma-separated row of pixel values (ints); each label
# line is a comma-separated vector of floats (presumably one-hot — confirm
# against mnist_dist.map_fun).
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
# Pair images with labels positionally; RDD.zip requires both RDDs to have
# the same number of partitions and elements per partition.
dataRDD = images.zip(labels)
# Feed the paired RDD into the TF cluster for args.epochs epochs.
cluster.train(dataRDD, args.epochs)
关闭TFoS
cluster.shutdown()
预测过程
配置预测参数
# Inference configuration: HDFS locations of the CSV-formatted MNIST test
# split, parsed with the same shared parser.
test_images_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/test/images"
test_labels_files = "hdfs://10.110.18.18:8020/user/root/mnist/csv/test/labels"

# Parse arguments for inference (mode switches the workers into predict mode).
_infer_argv = [
    '--mode', 'inference',
    '--images', test_images_files,
    '--labels', test_labels_files,
]
args = parser.parse_args(_infer_argv)
启动TFoS
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, num_executors, 1, False, TFCluster.InputMode.SPARK)
预测
# Prepare the test data as a Spark RDD, mirroring the training preprocessing:
# ints for pixel values, floats for labels.
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
dataRDD = images.zip(labels)
# Feed data for inference; results come back as an RDD of prediction strings.
prediction_results = cluster.inference(dataRDD)
# Materialize the first 20 predictions (displayed by the notebook; the return
# value is otherwise discarded).
prediction_results.take(20)
关闭TFoS
cluster.shutdown()