tensorflow 학습 노트의 tfrecord 파일 생성 및 읽기

모형을 훈련할 때, 우리는 직접 그림을 모형에 보내는 것이 아니라, 먼저 그림을 tfrecord 파일로 변환한 다음에 tfrecord 파일을 모형에 넣는다.이 예는 tfrecord 파일을 이해하기 위해 6폭의 이미지와 라벨을 tfrecord 파일로 변환한 다음에 tfrecord 파일을 읽고 6폭의 이미지와 라벨을 재현합니다.
1. tfrecord 파일 생성

import os
import numpy as np
import tensorflow as tf
from PIL import Image

filenames = [
'images/cat/1.jpg',
'images/cat/2.jpg',
'images/dog/1.jpg',
'images/dog/2.jpg',
'images/pig/1.jpg',
'images/pig/2.jpg',]

labels = {'cat':0, 'dog':1, 'pig':2}

def int64_feature(values):
	if not isinstance(values, (tuple, list)):
		values = [values]
	return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

with tf.Session() as sess:
	output_filename = os.path.join('images/train.tfrecords')
	with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
		for filename in filenames:
			# 
			image_data = Image.open(filename)
			# 
			image_data = np.array(image_data.convert('L'))
			# bytes
			image_data = image_data.tobytes()
			# label
			label = labels[filename.split('/')[-2]]
			# protocol 
			example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data),
																			'label': int64_feature(label)}))
			tfrecord_writer.write(example.SerializeToString())
2. tfrecord 파일 읽기

import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

#  
filename_queue = tf.train.string_input_producer(['images/train.tfrecords'])
reader = tf.TFRecordReader()
#  
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, 
									features={'image': tf.FixedLenFeature([], tf.string), 
												'label': tf.FixedLenFeature([], tf.int64)})
#  
image = tf.decode_raw(features['image'], tf.uint8)
#  [ , ]
image = tf.reshape(image, [60, 160])
#  label
label = tf.cast(features['label'], tf.int32)

with tf.Session() as sess:
	#  , 
	coord = tf.train.Coordinator()
	#  QueueRunner,  
	threads = tf.train.start_queue_runners(sess=sess, coord=coord)

	for i in range(6):
		image_b, label_b = sess.run([image, label])
		img = Image.fromarray(image_b, 'L')
		plt.imshow(img)
		plt.axis('off')
		plt.show()
		print(label_b)

	#  
	coord.request_stop()
	#  , 
	coord.join(threads)
tensorflow 학습 노트의 tfrecord 파일의 생성과 읽기에 관한 이 글을 소개합니다. 더 많은 tfrecord 파일의 생성과 읽기 내용은 저희 이전의 글을 검색하거나 아래의 관련 글을 계속 훑어보십시오. 앞으로 많은 응원 부탁드립니다!

좋은 웹페이지 즐겨찾기