Caffe 멀티 라벨(multi-label) LMDB 생성

2491 단어
#encoding:utf-8
import numpy as np
import os
import lmdb
from PIL import Image 
import numpy as np 
import sys
# ---------------------------------------------------------------------------
# User configuration -- fill these in before running the script.
# ---------------------------------------------------------------------------
# Root of the Caffe checkout; pycaffe lives in its `python/` sub-directory.
caffe_root = ''
# Output LMDB that receives the training images.
TRAIN_LMDB = ''
# Output LMDB that receives the per-image label vectors. Despite the name,
# this script writes the *training* labels here, not a validation set.
VAL_LMDB = ''
# Directory holding the raw images named in the label file.
# (The misspelled name is kept because the rest of the script refers to it.)
ORGINAL_IMAGES_PATH = ''
# Make sure pycaffe is importable. os.path.join avoids the doubled/missing
# separator of plain string concatenation and, when caffe_root is left empty,
# yields the relative 'python' directory instead of the filesystem-root
# path '/python'.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
####################pre-treatment############################
# Label file: one sample per line, "<image file> <label1> <label2>",
# e.g. "0001.jpg 2 5".
LABEL_TXT = 'your label txt'  # TODO: path to your label txt file

img_list = []     # image file names, joined onto ORGINAL_IMAGES_PATH later
label1_list = []  # first label of each image, kept as the raw string
label2_list = []  # second label of each image, kept as the raw string

with open(LABEL_TXT, 'r') as file_input:
    for line_no, line in enumerate(file_input, 1):
        # split() with no argument drops the newline and collapses runs of
        # spaces/tabs, unlike strip().split(' ').
        content = line.split()
        if not content:
            continue  # tolerate blank lines, e.g. a trailing empty line
        if len(content) < 3:
            raise ValueError(
                '{}:{}: expected "<image> <label1> <label2>", got {!r}'.format(
                    LABEL_TXT, line_no, line))
        img_list.append(content[0])
        label1_list.append(content[1])
        label2_list.append(content[2])
####################train data(images)############################
# Size (width, height) every image is resized to before storage. The original
# script referenced undefined names `w`/`h` here and crashed with NameError.
# TODO: set these to your network's input size (224x224 is only a placeholder).
IMG_WIDTH = 224
IMG_HEIGHT = 224

# NOTE(review): lmdb.open() reuses an existing directory, so entries from an
# earlier, longer run may survive; remove TRAIN_LMDB first when rebuilding.
in_db = lmdb.open(TRAIN_LMDB, map_size=int(1e12))
n_images = len(img_list)
try:
    with in_db.begin(write=True) as in_txn:
        for in_idx, in_ in enumerate(img_list):
            # os.path.join works whether or not the directory ends in '/'.
            im_file = os.path.join(ORGINAL_IMAGES_PATH, in_)
            # Normalise to 3-channel RGB first: a grayscale image would give a
            # 2-D array (the channel flip below would raise), and RGBA/palette
            # images would otherwise store the wrong channels.
            im = Image.open(im_file).convert('RGB')
            im = im.resize((IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR)
            im = np.array(im)               # PIL (w, h) RGB -> array (h, w, 3) RGB
            im = im[:, :, ::-1]             # RGB -> BGR channel order
            im = im.transpose((2, 0, 1))    # height*width*channel -> channel*height*width
            im_dat = caffe.io.array_to_datum(im)
            # Zero-padded index keys; the label LMDB uses the same keys so the
            # two databases line up entry by entry. encode() yields the bytes
            # key py-lmdb requires on Python 3 (a plain str on Python 2).
            key = '{:0>10d}'.format(in_idx).encode('ascii')
            in_txn.put(key, im_dat.SerializeToString())
            print('data train: {} [{}/{}]'.format(in_, in_idx + 1, n_images))
finally:
    # Close the environment even if an image fails to load.
    in_db.close()
print('train data(images) are done!')
######train data of label################
# Writes one label datum per image into VAL_LMDB (despite the name, these are
# the training labels matching TRAIN_LMDB).
in_db = lmdb.open(VAL_LMDB, map_size=int(1e12))
n_labels = len(img_list)
try:
    with in_db.begin(write=True) as in_txn:
        for in_idx, in_ in enumerate(img_list):
            # Both labels of an image packed into one array of shape (2, 1, 1):
            # one slot per label.
            target_label = np.zeros((2, 1, 1))
            # The label lists hold strings parsed from the text file; convert
            # explicitly so a malformed value fails with a clear ValueError.
            target_label[0, 0, 0] = float(label1_list[in_idx])
            target_label[1, 0, 0] = float(label2_list[in_idx])
            label_data = caffe.io.array_to_datum(target_label)
            # Same zero-padded keys as the image LMDB so entries stay aligned;
            # encode() gives the bytes key py-lmdb needs on Python 3.
            key = '{:0>10d}'.format(in_idx).encode('ascii')
            in_txn.put(key, label_data.SerializeToString())
            print('label train: {} [{}/{}]'.format(in_, in_idx + 1, n_labels))
finally:
    # Close the environment even if a label fails to convert.
    in_db.close()
print('train labels are done!')

참고


https://blog.csdn.net/u013010889/article/details/53098346

좋은 웹페이지 즐겨찾기