[한번 해봤는데] PyTorch의 MNIST.
가능하면 매주 투고하고 싶은데 유혹이 너무 심해요.
정신을 차리고 매일 컨트롤러를 잡고 있어요.
다시 정신을 가다듬고 모델의 정밀도를 평가합시다!
지난번 회고
MNIST에서 마지막으로 읽은 데이터는 모델을 먼저 구성합니다.
그러나
val_loss
가장 작을 때 무게를 보존하지 않고 공부를 좀 한 모델을 평가했다.이번에는
val_loss
가 가장 시간이 지나면 모델 저장네 가지 노력!
val_loss를 최소화할 때 모델 저장
저번에 왜 안 이루어졌는지 깜짝 놀랐어요.
기본적으로
val_loss
가장 낮은 부분은 정밀도가 높다고 여겨지기 때문에 코드를 수정하여 모델을 저장해야 한다.best_loss = None
for e in range(EPOCH):
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if (best_loss is None) or (best_loss > val_total_loss):
best_loss = val_total_loss
model_path = 'model.pth'
torch.save(model.state_dict(), model_path)
print()
이렇게 되면 모형이 보존되어 정말 안심이 된다그럼 지난번에 제대로 진행되지 못했던 정밀도를 평가해 봅시다!
평가 모델
어쨌든 저장된 모델부터 평가를 해보도록 하겠습니다.
test_total_acc = 0
model.eval()
model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))
with torch.no_grad():
for n,(data,label) in enumerate(test_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
test_total_acc += cal_acc(label,output)
print(f"test acc:{test_total_acc/len(test_data)*100}")
결과test acc:98.65999603271484
오, 선이 좋네요.반대로 뭘 잘못했을까?일단 공부한 그래프를 확인해 보도록 하겠습니다.
도표 그리기
학습 과정을 표시하려면 아래 코드를 실행하십시오
참조사이트 축소판 그림
fig, ax = plt.subplots()
t = np.linspace(1,10,10)
ax.set_xlabel('Epoch')
ax.set_ylabel('loss')
ax.grid()
ax.plot(t,train_loss,color = 'red',label='train')
ax.plot(t,val_loss,color = 'green',label='val')
ax.legend(loc = 0)
fig.tight_layout()
plt.show()
5~6Epoch 근처부터 과잉 학습을 했네요.다음은 정확도를 봅시다.
import numpy as np
fig, ax = plt.subplots()
t = np.linspace(1,10,10)
ax.set_xlabel('Epoch')
ax.set_ylabel('Acc')
ax.grid()
ax.plot(t,train_acc,color = 'red',label='train')
ax.plot(t,val_acc,color = 'green',label='val')
ax.legend(loc = 0)
fig.tight_layout()
plt.show()
큰 괴리가 없네.역시 MNIST는 쉽네요.
이제 퓨전 팀 정해서 확인해주세요!
혼합대를 짓다
sklearn
의 모듈을 사용하여 혼합 행렬을 출력합니다.실현
from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, y_pred))
이런 느낌이지만 지금은 저장이 안 돼서output
저장할 수 있도록list
코드를 변경했습니다.test_total_acc = 0
model.eval()
model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))
pred_list = []
true_list = []
with torch.no_grad():
for n,(data,label) in enumerate(test_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
test_total_acc += cal_acc(label,output)
pred = torch.argmax(output , dim =1)
pred_list += pred.detach().cpu().numpy().tolist()
true_list += label.detach().cpu().numpy().tolist()
print(f"test acc:{test_total_acc/len(test_data)*100}")
오세요!평범한 경기![[ 976 0 1 0 0 0 1 0 1 1]
[ 0 1126 2 0 1 0 1 1 4 0]
[ 1 2 1006 11 1 1 0 4 6 0]
[ 0 0 0 1004 0 2 0 1 2 1]
[ 0 0 0 0 975 0 0 0 0 7]
[ 2 0 0 12 1 873 1 0 2 1]
[ 4 1 0 1 5 4 940 0 3 0]
[ 0 2 2 5 0 0 0 1011 1 7]
[ 2 0 2 6 2 1 0 2 956 3]
[ 2 2 0 3 1 1 0 1 0 999]]
···어느 것이 어느 숫자인가!!!잘 모르니까 라벨을 붙여야 돼요.
참고사이트 축소판 그림 라벨을 붙여봤어요.
import pandas as pd
def add_label(matrix,columns):
columns_num = len(columns)
act = ['正解データ'] * columns_num
pred = ['予測データ'] * columns_num
cm = pd.DataFrame(matrix,columns = [pred,columns],index = [act,columns])
return cm
cm = add_label(confusion_matrix(true_list, pred_list),[x for x in range(10)])
display(cm)
결과는 다음과 같다.아까 데이터보다 더 보기 쉬워요!
결과를 보고 좀 오인했더라고요.
보아하니
0
유공 타입의 숫자2
로 오인된 것 같은데1
취미가 있는 타입인가요?)2
와 8
는 자주 틀렸다.7
면 알겠는데 왜2
와8
잘못된 숫자를 시각화하다
평가 코드 수정
test_total_acc = 0
model.eval()
model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))
pred_list = []
true_list = []
data_list = []
with torch.no_grad():
for n,(data,label) in enumerate(test_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
test_total_acc += cal_acc(label,output)
pred = torch.argmax(output , dim =1)
pred_list += pred.detach().cpu().numpy().tolist()
true_list += label.detach().cpu().numpy().tolist()
data_list.append(data.cpu())
print(f"test acc:{test_total_acc/len(test_data)*100}")
이후 다음 작업 수행fig = plt.figure(figsize = (20,5))
data_block = torch.cat(data_list,dim = 0)
idx_list = [n for n,(x,y) in enumerate(zip(true_list,pred_list)) if x!=y ]
len(idx_list)
for i,idx in enumerate(idx_list[:20]):
ax = fig.add_subplot(2,10,1+i)
ax.axis("off")
ax.set_title(f'true:{true_list[idx]} pred:{pred_list[idx]}')
ax.imshow(data_block[idx,0])
좌하 따위도 모르는군.전체적으로 틀린 디지털 이미지는 많은 사람들이 잘 몰라서 어쩔 수 없다.
창구 업무라면 다시 한 번 쓸 수 있다
시각화했더라도 이번엔 끝내고 싶다.
최초의 정밀도는 98% 였지만, 도대체 여기서 얼마나 높아질 수 있을까?
Reference
이 문제에 관하여([한번 해봤는데] PyTorch의 MNIST.), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://zenn.dev/opamp/articles/ddcfd32bcd1315텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)