AssertionError: can only join a child process

처음보는 에러다.

custom한 Dataset과 Transformation이 완전 naiver version이여서 속도가 상당히 느리다.
이상하게 짠 부분이 있다.

어째는 pytorch내장 transformation하고 속도좀 비교해보려고 1-epoch을 읽어 보았다.
그러던중 뭔 이상한 error가 발생했다.

우선 error가 난 code는 다음과 같다.

class custom_dataset(Dataset):
    
    def __init__(self, inputs_dir , targets_dir, transform = None):
        
        self.inputs_dir = inputs_dir 
        self.inputs_list =  os.listdir(inputs_dir)
        
        self.targets_dir = targets_dir
        self.targets_list =  os.listdir(targets_dir)      
        
        self.transform = transform
            
    def __len__(self):
        return len(self.inputs_list)
    
    def __getitem__(self,idx):
    
        os.chdir(self.inputs_dir)
        input_image = Image.open(self.inputs_list[idx])
        
        os.chdir(self.targets_dir)
        target_image = Image.open(self.targets_list[idx])
        
        combine = {'input':input_image, 'target':target_image}
        
        if self.transform:
            combine = self.transform(combine)

            
        return (combine['input'] , combine['target'])
class RandomFlip(object):
    
    def __call__(self, combine):
    
        inputs = combine['input']
        targets = combine['target']

        inputs_np = np.array(inputs) #물론 flip을 먼저하면 np로 바뀌어있긴하다. 하지만 항상 flip을 하는것은 아니니까
        targets_np = np.array(targets)


        if np.random.rand() > 0.5:   # 좌우
            inputs_np = np.fliplr(inputs_np)
            targets_np = np.fliplr(targets_np)

        if np.random.rand() > 0.5:   # 상하
            inputs_np = np.flipud(inputs_np)
            targets_np = np.flipud(targets_np)

        combine = {'input': inputs_np, 'target': targets_np}  #출력은 np type이다.

        return combine
class ToTensor(object):
    def __call__(self, combine):
    
        inputs = combine['input']
        targets = combine['target']

        inputs_np = np.array(inputs) #물론 flip을 먼저하면 np로 바뀌어있긴하다. 하지만 항상 flip을 하는것은 아니니까
        targets_np = np.array(targets)
        
        inputs_tensor = torch.from_numpy(inputs_np)
        targets_tensor = torch.from_numpy(targets_np)
        
        inputs_np_trans = np.transpose(inputs_tensor, (2,0,1))
        targets_np_trans = np.transpose(targets_tensor.unsqueeze(2), (2,0,1))

        
        combine = {'input':inputs_np_trans, 'target':targets_np_trans}

        return combine
class Resize(object):
    
    def __init__(self, output_size, mode):
        
        self.output_size = output_size
        self.mode = mode
    
    def __call__(self, combine):
        
        # ToTensor 이후라고 가정하니까 
        inputs_tensor = combine['input']
        targets_tensor = combine['target']
        
        inputs_rescaled = torch.nn.functional.interpolate(inputs_tensor.unsqueeze(dim=0).float(), size = self.output_size, mode = self.mode).squeeze(dim=0)
        targets_rescaled = torch.nn.functional.interpolate(targets_tensor.unsqueeze(dim=0).float(), size = self.output_size, mode = self.mode).squeeze(dim=0)
        
        combine = {'input':inputs_rescaled, 'target':targets_rescaled }

        return combine
path_train_inputs = '/home/mskang/hyeokjong/cancer/2018/task1/train_input_pad'
path_train_targets = '/home/mskang/hyeokjong/cancer/2018/task1/train_targets_pad'

transformation = transforms.Compose([
                                    RandomFlip(),
                                    ToTensor(), 
                                    Resize((112,112), mode ='nearest')
                                      ])

train_dataset = custom_dataset(path_train_inputs, path_train_targets, transformation  )




batch_size = 16

train_dl = DataLoader(train_dataset, batch_size, shuffle=True,
                      num_workers=4, pin_memory=True)

여기까지가 baseline이다.

아래가 error이다.

from tqdm import tqdm

start_time = time.time()


for x,y in tqdm(train_dl):
    inputs = x
    targets = y
end_time = time.time()

print(end_time - start_time)

여기서 Error가 나왔는데 내용은 다음과 같다.

뭐지 정리하려 하니까 error가 안나온다,..........

아무튼
https://discuss.pytorch.org/t/error-while-multiprocessing-in-dataloader/46845/17

여기 보고 해결했는데 방법은 다음과 같이 code를 수정하는 것이다.

from tqdm.auto import tqdm

위 링크에서 알려준 방법인데 알려만 줬다. 아무도 이유를 모른다.
그리고 보통 error가 나면 for문이 멈춰야 하는데 또 잘 된다.
뭔지 모르겠다.

좋은 웹페이지 즐겨찾기