We test a distributed TensorFlow cluster using the MNIST dataset.

First, install TensorFlow on 10.110.18.216, 10.110.18.217, and 10.110.18.218.
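The script below uses TF 1.x APIs (tf.app.flags, tf.train.Supervisor), so a 1.x release is needed on each host; assuming pip is available, installation on each machine looks like:

$ pip install 'tensorflow<2'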

Next, write the distributed version of the MNIST trainer, distribute_worker_mnist.py:

import math
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")

tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")

tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")

tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

tf.app.flags.DEFINE_integer("hidden_units", 100, "Number of units in the hidden layer of the NN")

tf.app.flags.DEFINE_string("data_dir", "MNIST_data", "Directory for storing mnist data")

tf.app.flags.DEFINE_integer("batch_size", 100, "Training batch size")

FLAGS = tf.app.flags.FLAGS

IMAGE_PIXELS = 28  # MNIST images are 28x28


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")

    worker_hosts = FLAGS.worker_hosts.split(",")
    # Build the cluster description and start a gRPC server for this task.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # A parameter server only hosts variables; block and serve forever.
        server.join()

    elif FLAGS.job_name == "worker":
        start_time = time.time()
        # replica_device_setter pins variables to the ps job and keeps the
        # compute ops on this worker (between-graph replication).
        with tf.device(
                tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index,
                                               cluster=cluster)):
            hid_w = tf.Variable(
                tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], stddev=1.0 / IMAGE_PIXELS),
                name="hid_w")

            hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

            sm_w = tf.Variable(tf.truncated_normal([FLAGS.hidden_units, 10], stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
                               name="sm_w")

            sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

            x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])

            y_ = tf.placeholder(tf.float32, [None, 10])

            hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)

            hid = tf.nn.relu(hid_lin)

            y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

            loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

            global_step = tf.Variable(0, name="global_step", trainable=False)

            train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            saver = tf.train.Saver()

            # No summaries are defined, so merge_all() returns None and the
            # Supervisor's summary service is effectively disabled.
            summary_op = tf.summary.merge_all()

            init_op = tf.global_variables_initializer()

            # The chief (task 0) initializes variables and periodically
            # checkpoints to logdir; non-chief workers wait for it.
            sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="train_logs", init_op=init_op,
                                     summary_op=summary_op, saver=saver, global_step=global_step, save_model_secs=600)
            mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
            with sv.managed_session(server.target) as sess:
                step = 0

                while not sv.should_stop() and step < 1000:
                    batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
                    train_feed = {x: batch_xs, y_: batch_ys}

                    _, step = sess.run([train_op, global_step], feed_dict=train_feed)

                    if step % 100 == 0:
                        print("global step: {}, accuracy:{}".format(step, sess.run(accuracy,
                                                                                   feed_dict={x: mnist.test.images,
                                                                                              y_: mnist.test.labels})))


        end_time = time.time()
        print("training time: {}s".format(end_time - start_time))
        sv.stop()

if __name__ == "__main__":
    tf.app.run()
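The key to the script is tf.train.replica_device_setter: it places variables on the ps job while leaving compute ops on the local worker, so all workers share the same parameters. The placement can be seen with a minimal sketch (graph construction only, no servers required; the addresses are the ones used above):

import tensorflow as tf

cluster = tf.train.ClusterSpec({"ps": ["10.110.18.216:2222"],
                                "worker": ["10.110.18.217:2222", "10.110.18.218:2222"]})

with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:0", cluster=cluster)):
    w = tf.Variable(tf.zeros([10]), name="w")  # variable -> placed on the ps job
    y = w + 1.0                                # op -> stays on the local worker

print(w.device)  # /job:ps/task:0
print(y.device)  # /job:worker/task:0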

Use 10.110.18.216 as the ps server and run:

$ python distribute_worker_mnist.py \
     --ps_hosts=10.110.18.216:2222 \
     --worker_hosts=10.110.18.217:2222,10.110.18.218:2222 \
     --job_name=ps --task_index=0

Use 10.110.18.217 and 10.110.18.218 as workers: on 10.110.18.217 start task 0, and on 10.110.18.218 start task 1:

$ python distribute_worker_mnist.py \
     --ps_hosts=10.110.18.216:2222 \
     --worker_hosts=10.110.18.217:2222,10.110.18.218:2222 \
     --job_name=worker --task_index=0
$ python distribute_worker_mnist.py \
     --ps_hosts=10.110.18.216:2222 \
     --worker_hosts=10.110.18.217:2222,10.110.18.218:2222 \
     --job_name=worker --task_index=1

The train_logs directory the script writes to is a local path; to keep the cluster consistent, the same directory is copied to all three hosts.
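To check that the chief actually wrote a usable checkpoint, the following sketch restores the latest one from train_logs (an assumption here: training has finished, and the tensor name global_step:0 matches the variable name in the script above):

import tensorflow as tf

# Locate the most recent checkpoint written by the chief's Supervisor.
ckpt = tf.train.latest_checkpoint("train_logs")

with tf.Session() as sess:
    # Rebuild the graph from the saved meta file, then restore the variables.
    saver = tf.train.import_meta_graph(ckpt + ".meta")
    saver.restore(sess, ckpt)
    # Fetch the saved global step by tensor name.
    print(sess.run("global_step:0"))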

The test ran successfully.
