PyTorch 참고서 Self-Attention GAN 샘플 코드 추가 단편

「만들면서 배운다! PyTorch에 의한 발전 딥 러닝」의 「제5장 GAN에 의한 화상 생성(DCGAN, Self-Attention GAN)」의 「Attention Map」의 가시화에 대해 추가 단편을 작성했습니다.

참조 서적



만들면서 배우십시오! PyTorch의 발전 딥 러닝
오가와 유타로

htps : // 보오 k. my ゔぃ. jp / e c / p 로즈 cts /에서 원하는 l / i d = 104855
htps : // 기주 b. 코 m / 유타로 오가와 / py와 rch_ 아 d ゔ 센세 d
htps : // m / 곧 / ms / 07253d12b1fc72 16 아바

Attention Map 시각화 조각



'5-4_SAGAN.ipynb'의 끝에 셀을 추가하고 아래 코드를 붙여 넣습니다.

# Attentiom Mapを出力

print('1段目 生成した画像データ')
print('2段目 Attentin Map1 中央のピクセル→全ピクセル')
print('3段目 Attentin Map1 全ピクセル→中央のピクセル ⇒ 2段目と同じ結果が得られる')
print('4段目 Attentin Map1 右下のピクセル→全ピクセル ⇒ 7と8で差が出やすいピクセル')
print('5段目 Attentin Map1 左上端のピクセル→全ピクセル')
print('6段目 Attentin Map1 自ピクセル→自ピクセル')

row_num=6

# print('fake_images : ' + str(fake_images.size()))
# print('am1 : ' + str(am1.size()))

fig = plt.figure(figsize=(3*5, 3*row_num))
for i in range(0, 5):

  fake_image = fake_images[i][0]
  am = am1[i].view(16, 16, 16, 16)

  # print('fake_image : ' + str(fake_image.size()))
  # print('am : ' + str(am.size()))

  for j in range(0, row_num):
    plt.subplot(row_num, 5, 5*j+i+1)

    if j == 0:

      # 1段目 生成した画像データ
      plt.imshow(fake_image.cpu().detach().numpy(), 'gray')

    elif (j > 0) and (j < 5):

      if j == 1:
        # 2段目 Attentin Map1 中央のピクセル→全ピクセル
        am_tmp = am[7][7]
      elif j == 2:
        # 3段目 Attentin Map1 全ピクセル→中央のピクセル
        am_tmp = am[:][:][7][7]
      elif j == 3:
        # 4段目 Attentin Map1 右下のピクセル→全ピクセル
        am_tmp = am[11][11]
      elif j == 4:
        # 5段目 Attentin Map1 左上端のピクセル→全ピクセル
        am_tmp = am[0][0]

      am_tmp = am_tmp.cpu().detach().numpy()
      # print('i : ' + str(i) + ', j : ' + str(j) + ', max : ' + str(np.max(am_tmp)) + ', min : ' + str(np.min(am_tmp)))
      plt.imshow(am_tmp, 'Reds', vmin=0, vmax=0.05)

    elif j == 5:
      am_tmp = np.ones((16,16), dtype='float')

      # 6段目 Attentin Map1 自ピクセル→自ピクセル
      for k in range(16):
        for l in range(16):
          am_tmp[k][l] = am[k][l][k][l]

      # print('i : ' + str(i) + ', j : ' + str(j) + ', max : ' + str(np.max(am_tmp)) + ', min : ' + str(np.min(am_tmp)))
      plt.imshow(am_tmp, 'Reds', vmin=0, vmax=0.05)

    else:
      raise ValueError

출력 결과




1~3단째

4~6단째

좋은 웹페이지 즐겨찾기