TensorFlow Queue를 사용한 비동기 정보

※ 저자의 착각이 남아 있을 가능성이 있으므로 주의해서 읽어 주었으면 합니다.

강화 학습 분야에 있는 A3C 라는 비동기로 Actor Critic 를 학습하는 알고리즘을 공부함에 있어서, 어떻게 비동기를 TensorFlow 로 실현하는지 조사했으므로 정리했습니다.

미안 정도에 몇개의 클래스에 대한 설명을 싣고 있습니다만, 그러한 클래스를 이용한 프로그램을 마지막에 싣고 있으므로, 움직이고 나서 조사해 주는 편이 좋을지도 모릅니다.

큐(Queue)



tf.FIFOQueue


TensorFlow 에서 여러 스레드를 비동기적으로 텐서를 계산하는 편리한 대기열입니다.
예를 들어, 비동기로 움직이고 싶은 메소드를 이 큐에 추가 ( queue.enqueue(function) ) 해 갑니다. 그 외에도, 랜덤에 디큐를 하는 큐가 구현되고 있어 tf.RandomShuffleQueue 가 있습니다.

필요한 인수는 다음 두 가지입니다.
capacity: integer型。 キューに追加できる上限数。
dtypes: リストに入るオブジェクトの型。

자세한 내용은 tf.FIFOQueue을 참조하십시오.

tf.train.QueueRunner



오퍼레이션 ( operation )을 enqueue() 한 것을 리스트에 보관 유지해, 각각을 thread내에서 실행합니다.
qr = tf.train.QueueRunner(queue, [queue.enqueue() for _ in [...]])
tf.train.queue_runner.add_queue_runner(qr)

자세한 내용은 tf.train.QueueRunner을 참조하십시오.

tf.train.Coordinator


tf.train.Coordinator 클래스는 개시한 thread를 조정해 줍니다. tf.train.start_queue_runners() 는 세션 내의 그래프로 수집된 모든 tf.QueueRunner() 의 thread를 실행해 줍니다. 이 두 가지를 함께 사용하면 비동기 적으로 실행할 수 있습니다.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)

자세한 내용은 tf.train.Coordinator을 참조하십시오.

구현 예



비동기적으로 변수( loss )에서 균일한 난수( tf.random_uniform )를 빼기만 하면 됩니다.
import logging

import tensorflow as tf


class Group(object):
  def __init__(self, scope):
    with tf.name_scope(scope):
      self.loss = tf.Variable(1000., trainable=False, name="loss")

  def loss_op(self, queue: tf.FIFOQueue):
    loss = tf.assign_sub(self.loss, tf.random_uniform(shape=[], maxval=10.))
    print("loss op: {0}".format(self.loss))
    tf.summary.scalar(self.loss.name, self.loss)
    return queue.enqueue(loss)


def main(_):
  thread_size = 5
  groups = [Group(scope="thread_{0}".format(i)) for i in range(thread_size)]
  queue = tf.FIFOQueue(capacity=thread_size * 10,
                       dtypes=[tf.float32], )
  qr = tf.train.QueueRunner(queue, [g.loss_op(queue) for g in groups])
  tf.train.queue_runner.add_queue_runner(qr)
  loss = queue.dequeue()
  mean_loss = tf.reduce_mean([g.loss for g in groups])

  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    init.run()
    summaries = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('log', sess.graph)
    sess.graph.finalize()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
      for i in range(100):
        if coord.should_stop():
          break
        _loss, _mean, _summaries = sess.run([loss, mean_loss, summaries])
        print("loss: {0:0.2f}, mean loss: {1:0.2f}".format(_loss, _mean))
        summary_writer.add_summary(_summaries, global_step=i)
    except Exception as e:
      coord.request_stop(e)
    finally:
      coord.request_stop()
      coord.join(threads)


if __name__ == '__main__':
tf.app.run()

TensorBoard



좋은 웹페이지 즐겨찾기