【Tensorflow】tf.app.run () 와 명령행 매개 변수 분석

6992 단어 TensorFlow

tf.app.run()


먼저 일반적인 코드를 제공합니다.
if __name__ == '__main__':
    tf.app.run()

이전 함수run()에 대한 Tensorflow의 소스를 찾습니다.
def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
    'run',
    # Allowed submodule.
    'flags',
]

remove_undocumented(__name__, _allowed_symbols)

원본 코드의 과정은 먼저 flags의 매개 변수 항목을 불러온 다음에 main 함수를 실행하는 것을 볼 수 있다.여기서 매개변수는 tf.app.flags.FLAGS로 정의됩니다.

tf.app.flags.FLAGS

tf.app.flags.FLAGS 사용 정보:
# fila_name: temp.py
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

print('string: ', FLAGS.string)
print('learning_rate: ', FLAGS.learning_rate)
print('flag: ', FLAGS.flag)

출력:
string:  train
learning_rate:  0.001
flag:  True

명령줄에서 python3 temp.py --help을 실행하면 다음과 같이 출력됩니다.
usage: temp.py [-h] [--string STRING] [--learning_rate LEARNING_RATE]
               [--flag [FLAG]] [--noflag]

optional arguments:
  -h, --help            show this help message and exit
  --string STRING       This is a string
  --learning_rate LEARNING_RATE
                        This is the rate in training
  --flag [FLAG]         This is a flag
  --noflag
FLAGS의 기본값을 수정하려면 명령을 입력하면 됩니다.
python3 temp.py --string 'test' --learning_rate 0.2 --flag False

공동 사용

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

def main(unuse_args):
    print('string: ', FLAGS.string)
    print('learning_rate: ', FLAGS.learning_rate)
    print('flag: ', FLAGS.flag)

if __name__ == '__main__':
    tf.app.run()


주함수 중의 tf.app.run()main를 호출하고 파라미터를 전달하기 때문에 main 함수에 파라미터의 위치를 설정해야 한다.main 이름을 바꾸려면 tf.app.run()에 지정한 함수 이름을 입력하면 됩니다.
def test(args):
    # test
    ...
if __name__ == '__main__':
    tf.app.run(test)

좋은 웹페이지 즐겨찾기