Solving the problem of TensorFlow queue threads on a server not being truly coordinated and scrambling the data order

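With a QueueRunner, setting shuffle=False does not seem to be enough: when several threads run the loop, whichever thread finishes first delivers its result first, so the order appears to be asynchronous.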
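The effect can be mimicked in plain Python, without TensorFlow. The sketch below is only an analogy, not TensorFlow's actual QueueRunner, and `run_workers` and the delays are hypothetical: file names are enqueued in list order (as with shuffle=False), but two threads process them and results are collected as they finish.

```python
import queue
import threading
import time


def run_workers(items, delays, num_workers=2):
    """Process items with several threads; return them in completion order."""
    in_q = queue.Queue()
    for item in items:              # enqueued in list order, like shuffle=False
        in_q.put(item)

    done = []                       # filled in the order work finishes
    lock = threading.Lock()

    def worker():
        while True:
            try:
                item = in_q.get_nowait()
            except queue.Empty:
                return              # queue drained: this thread exits
            time.sleep(delays[item])  # simulated per-file processing time
            with lock:
                done.append(item)

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return done


if __name__ == "__main__":
    order = run_workers(["a.bmp", "b.bmp"], {"a.bmp": 0.3, "b.bmp": 0.0})
    print(order)
```

With a slow first file, `a.bmp` typically finishes last, so naming outputs by their position in the result stream would give it the wrong name.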
 
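So the workaround is to fetch the file name together with the image in the same sess.run call, so that each output is saved under the name of the file it came from. This fixes the mismatch between input and output file names.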
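A minimal sketch of the same fix in plain Python, assuming a thread pool stands in for the queue threads: each task returns its file name together with its output, so the name travels with the data regardless of the order in which tasks complete. `fake_translate` and `translate_all` are hypothetical placeholders, not part of the script.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_translate(path):
    """Hypothetical stand-in for the model: returns (name, output bytes)."""
    output = path.encode("utf8")[::-1]   # any deterministic transformation
    return path, output                  # the name travels with its result


def translate_all(paths, output_dir, workers=4):
    """Run all files concurrently and save each result under its own name."""
    os.makedirs(output_dir, exist_ok=True)
    written = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(fake_translate, p) for p in paths]
        for fut in as_completed(futures):    # completion order, not input order
            name, output = fut.result()
            out_name = os.path.basename(name)
            with open(os.path.join(output_dir, out_name), "wb") as f:
                f.write(output)
            written[out_name] = output
    return written


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        result = translate_all(["in/a.bmp", "in/b.bmp"], os.path.join(tmp, "out"))
        print(sorted(result))  # → ['a.bmp', 'b.bmp']
```

This mirrors fetching `(output_image, filename)` in one `sess.run` call: the pairing is fixed when the result is produced, not reconstructed afterwards from a loop index.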
"""Translate every .bmp image in a directory with an exported CycleGAN graph.
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input_dir samples/input_n2v \
                       --output_dir samples/output_n2v \
                       --image_size 128
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils
from glob import glob

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('model', '', 'model path (.pb)')
tf.flags.DEFINE_string('input_dir', 'samples/input_n2v', 'directory of input images (.bmp)')
tf.flags.DEFINE_string('output_dir', 'samples/output_n2v', 'directory for output images')
tf.flags.DEFINE_integer('image_size', 128, 'image size, default: 128')

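# Expose only GPU 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'   # best-fit-with-coalescing allocator
config.gpu_options.allow_growth = True      # allocate GPU memory on demand, not all at once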

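def get_all_org_files(file_dir):
    """Recursively collect all .bmp files under file_dir (currently unused)."""
    L = []
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            # os.path.splitext() splits a file name into (root, extension)
            if os.path.splitext(file)[1] == '.bmp':
                L.append(os.path.join(root, file))
    return L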

def inference():
  graph = tf.Graph()
  # old_img_file_path = get_all_org_files(FLAGS.input_dir)
  
  # Choose the input shape, channel count and input tensor name for this model.
  model_name = os.path.basename(FLAGS.model)
  if model_name == 'n2v.pb':
    input_shape = [FLAGS.image_size, FLAGS.image_size, 1]
    input_channel = 1
    input_name = 'input_image_x'
  else:
    input_shape = [FLAGS.image_size, FLAGS.image_size, 3]
    input_channel = 3
    input_name = 'input_image_y'
  
  with graph.as_default():
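      # Collect the input paths. glob() returns files in arbitrary order,
      # so sort them to make the processing order reproducible.
      paths = sorted(glob(os.path.join(FLAGS.input_dir, '*.bmp')))

      # The inputs are .bmp files; for .jpg inputs use tf.image.decode_jpeg instead.
      tf_decode = tf.image.decode_bmp
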
      # Build a queue of file names from the path list.
      # shuffle defaults to True; set it to False so names are enqueued in list order.
      filename_queue = tf.train.string_input_producer(paths, shuffle=False)

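      # A reader that takes file names from the queue and returns whole files.
      reader = tf.WholeFileReader()

      # read() returns (key, value) tensors: the file name and the raw file bytes.
      # Nothing is actually read until these tensors are evaluated in sess.run().
      filename, data = reader.read(filename_queue)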

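      # Decode the raw bytes into a uint8 image tensor (also evaluated lazily).
      # channels=0 would keep the channel count stored in the BMP file.
      image = tf_decode(data, channels=input_channel)

      # set_shape() only declares the static shape; it does not resize or reshape,
      # so every input image must already be image_size x image_size.
      image.set_shape(input_shape)

      # Scale to float32 in [0, 1]. This must match the preprocessing used when the
      # model was trained (the commented-out version below used utils.convert2float).
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)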

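    # Earlier single-image version, kept for reference. It does not work with a
    # placeholder: tf.gfile.FastGFile needs a Python path string, not a Tensor.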
    # input = tf.placeholder(dtype=tf.string)
    # with tf.gfile.FastGFile(input, 'rb') as f:
    #   image_data = f.read()
    #   input_image = tf.image.decode_jpeg(image_data, channels=1)
    #   input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
    #   input_image = utils.convert2float(input_image)
    #   input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 1])

      with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())
      # Note: input_name and 'output_image:0' must match the tensor names set in export_graph.py.
      [output_image] = tf.import_graph_def(graph_def,
                              input_map={input_name: image},
                              return_elements=['output_image:0'],
                              name='output')

  with tf.Session(graph=graph,config=config) as sess:
    coord = tf.train.Coordinator()  # coordinates stopping the queue threads
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)  # start the threads that fill filename_queue

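    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    for _ in range(len(paths)):
        # Fetch the result and its source file name in the same sess.run() call,
        # so each output is saved under the file it was computed from.
        generated, it_path = sess.run((output_image, filename))
        it_name = os.path.basename(it_path.decode('utf8'))

        # Assumes the exported graph returns JPEG-encoded bytes (the older loop
        # below saved .jpg files): keep the input stem, change the extension.
        out_name = os.path.splitext(it_name)[0] + '.jpg'
        with open(os.path.join(FLAGS.output_dir, out_name), 'wb') as f:
            f.write(generated)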
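    # Previous version (kept for reference): outputs were named by loop index,
    # so the saved names did not correspond to the input files.
    # for i in range(len(paths)):
    #     generated = output_image.eval()
    #     with open('{}/{:05d}.jpg'.format(FLAGS.output_dir, i+1), 'wb') as f:
    #         f.write(generated)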

    coord.request_stop()  # ask the queue threads to stop
    coord.join(threads)   # wait until they have exited

def main(unused_argv):
  inference()

if __name__ == '__main__':
  tf.app.run()
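Two details around this script affect reproducibility: glob() does not guarantee any particular file order, and, assuming the exported graph returns JPEG-encoded bytes, outputs are better saved with a .jpg extension. The hypothetical helpers `list_images` and `output_path_for` below sketch both steps in plain Python, outside TensorFlow:

```python
import os
from glob import glob


def list_images(input_dir, ext="bmp"):
    """Return the files in input_dir with the given extension, sorted by path."""
    return sorted(glob(os.path.join(input_dir, "*." + ext)))


def output_path_for(input_path, output_dir, out_ext=".jpg"):
    """Keep the input file's stem but save it with out_ext in output_dir."""
    stem = os.path.splitext(os.path.basename(input_path))[0]
    return os.path.join(output_dir, stem + out_ext)


if __name__ == "__main__":
    # → samples/output_n2v/0001.jpg on Linux
    print(output_path_for("samples/input_n2v/0001.bmp", "samples/output_n2v"))
```

Sorting the list does not by itself fix the mismatch described above; it only makes runs repeatable, while fetching the file name with each result keeps the pairing correct.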
