import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", " ", "启动服务的类型ps or worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的哪一台服务器以task:0,task:1")
def main(argv):
# 定义一个全局计数的op,给钩子列表中的训练步数使用
global_step = tf.contrib.framework.get_or_create_global_step()
# 指定集群描述对象,ps worker,多台worker或者ps的定位规则,第一台:/job:worker/task:0,第二台:/job:worker/task:1,ps也是如此
cluster = tf.train.ClusterSpec({"ps":["192.168.0.4:2222",], "worker":["192.168.109.128:2323",]})
# 创建不同的服务 ps worker,job_name指定是ps还是worker,task_index,指定启动哪台服务器
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 根据不同的服务器做不同的事情,ps保存参数,worker指定设备运行模型计算
if FLAGS.job_name == 'ps':
# 参数服务器只需接受参数
server.join()
else:
worker_device = "/job:worker/task:0/cpu:0/"
# 指定设备去运行
with tf.device(tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster)):
# 演示一个矩阵乘法运算
x = tf.Variable([[1, 2, 3, 4]])
w = tf.Variable([[2], [4], [5], [7]])
mat = tf.matmul(x, w)
# 创建分布式会话
with tf.train.MonitoredTrainingSession(
master="grpc://192.168.0.1:2222", # 指定是否是主work
is_chief=(FLAGS.task_index==0), # 判断书否是主worker
config=tf.ConfigProto(log_device_placement =True), # 打印设备信息
hooks=[tf.train.StopAtStepHook(last_step=1000)] # 指定训练步数,指定步数需要定义一个全局计数的op
) as mon_sess:
while not mon_sess.should_stop():
# should_stops是否异常停止
mon_sess.run(mat)
if __name__ == "__main__":
tf.app.run()