# CIFAR-10을 TFRecords로 어떻게 변환합니까?
# (How to convert the CIFAR-10 Python batches to TFRecord files.)
#
# 1268 단어 에세이 (1268-word essay)
import tensorflow as tf

import pickle
import os
# Directory holding the extracted CIFAR-10 python batches
# (data_batch_1 ... data_batch_5). Set the CIFAR10_PATH environment
# variable to point at it without editing this file; otherwise the
# placeholder below is used and must be replaced by hand.
path = os.environ.get('CIFAR10_PATH', r'your cifar10 address')

def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    Args:
        file: path to a pickled batch, e.g. ``data_batch_1``.

    Returns:
        Whatever object the pickle contains. ``encoding='bytes'`` makes
        8-bit strings pickled by Python 2 come back as ``bytes``, which is
        why callers index the result with byte-string keys such as
        ``b'labels'``.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only call this on trusted
        files such as the official CIFAR-10 download.
    """
    with open(file, 'rb') as f:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(f, encoding='bytes')
    return batch

def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose Int64List holds the single ``value``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose BytesList holds the single ``value``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)


def restore(inputfilename, outputfilename):
    """Convert one pickled CIFAR-10 batch file into a TFRecord file.

    Each record written is a ``tf.train.Example`` with two features:
    ``'raw_image'`` (one row of the batch's pixel data, serialized to bytes)
    and ``'label'`` (the matching integer label, stored as int64).

    Args:
        inputfilename: path to a pickled batch, e.g. ``data_batch_1``.
        outputfilename: path of the TFRecord file to create.

    Raises:
        ValueError: if the batch holds a different number of labels and images.
    """
    batch = unpickle(inputfilename)
    # The batch dict also carries b'batch_label' and b'filenames'; only the
    # labels and the pixel data are kept.
    labels = batch[b'labels']
    # Presumably a uint8 ndarray of shape (N, 3072), one flattened image per
    # row (10000 rows for the standard batches). TODO confirm.
    images = batch[b'data']
    if len(labels) != len(images):
        raise ValueError(
            'label/image count mismatch: %d labels vs %d images'
            % (len(labels), len(images)))

    # The context manager closes the writer, which flushes buffered records
    # to disk. tf.io.TFRecordWriter is the name available in both TF 1.x
    # (>= 1.13) and TF 2.x; tf.python_io no longer exists in TF 2.
    with tf.io.TFRecordWriter(outputfilename) as writer:
        # Iterate over the actual data instead of a hard-coded 10000 so
        # batches of any size are converted completely.
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                # tobytes() replaces the deprecated ndarray.tostring().
                'raw_image': _bytes_feature(image.tobytes()),
                'label': _int64_feature(int(label)),
            }))
            writer.write(example.SerializeToString())
	
# Convert the five CIFAR-10 training batches (data_batch_1 .. data_batch_5)
# found under ``path``; each output is written next to its source file
# with a ".tfrecord" suffix appended.
filenames = [
    os.path.join(path, 'data_batch_%d' % batch_no)
    for batch_no in range(1, 6)
]
for filename in filenames:
    output_name = filename + '.tfrecord'
    restore(filename, output_name)

# 좋은 웹페이지 즐겨찾기 (good web page bookmarks)