๐Pytorch ignite๊ฐ ๋ญ์ผ๐!!
๐ค Boilerplate๋?
ํ
ํ๋ฆฟ(Template)๊ฐ์ ๋๋์ผ๋ก, ๋ฐ๋ณต์ ์ธ ์ฝ๋๋ฅผ ํ์ดํํ ํ์์์ด ๋ฐ๋ณต์ ์ธ ์ผ๋ค์ ํ์ง ์๋๋ก ๋์์ฃผ๋ ๊ฒ์ด ๋ฐ๋ก ๋ณด์ผ๋ฌํ๋ ์ดํธ(Boilerplate)์ด๋ค. pytorch ignite๋ ๋ฅ๋ฌ๋ ๋ถ์ผ์์ ๋ชจ๋ธ์ ์ฝ๋ฉํ๋ ์๊ฐ๋ณด๋ค ๋ถ์์ ์ธ ์์์ ์ฝ๋ฉ(trainer, dataset ๋ฑ)์ ๋ ๋ง์ ์๊ฐ์ด ์์๋๊ธฐ ๋๋ฌธ์ ์ฌ์ฌ์ฉ ๊ฐ๋ฅํ ์ฝ๋๋ฅผ ๋ง๋ค์๋ ์๋ฏธ์ด๋ค.
(์ข) ignite๋ฅผ ์ฌ์ฉํ ์ฝ๋ (์ฐ) ์ผ๋ฐ์ ์ธ ์ฝ๋
๐คฉ IGNITE YOUR NETWORKS!
Pytorch-ignite๋ model์ ํ๋ จ์ํค๊ณ ์
๋ฐ์ดํธํ๋ ๋ชจ๋ ๊ณผ์ ์ ์ด์ ๊ด๋ จ๋ ๋ฉ์๋๋ฅผ ์ ๊ณตํ์ฌ ๊น๋ํ๊ณ ์ฌ์ฌ์ฉํ ์ ์๊ฒ ๋์์ฃผ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ด๋ค. ignite์ ๋ณธ์ง์ ignite.engine.Engine
์ผ๋ก ์ด๋ฃจ์ด์ ธ์๋๋ฐ, Engine
์ ์
๋ ฅ ๋ฐ์ ์ฐ์ฐ์ ๊ณ์ํด์ ๋ฐ๋ณต ์ํํ๋ ์ญํ ์ ํ๋ค.
while epoch < max_epochs:
# run an epoch on data
data_iter = iter(data)
while True:
try:
batch = next(data_iter)
output = process_function(batch)
iter_counter += 1
except StopIteration:
data_iter = iter(data)
if iter_counter == epoch_length:
break
์์ ์ฝ๋๋ Engine
์ ์๋ฏธ๋ฅผ ๋ํ๋ด๋ ์ฝ๋์ด๋ค.
๐ Engine
Engine
์ ์ฝ๊ฒ ์๊ฐํด ํ์ตํ๋ ๋ถ๋ถ์ ๊ณ์ ๋๋ ค์ฃผ๋ ๊ฒ์ด๋ค. ๋ง์น ์๋์ฐจ ์์ง์ฒ๋ผ ๋ง์ด๋ค. Engine
์ ์ฌ์ฉ๋ฒ์ ignite.engine.engine.Engine(process_function)
์ด๋ค. process_function
์ ์ฌ์ฉ์๊ฐ ์ง์ ์ฝ๋ฉํ๋ ๋ถ๋ถ์ feed-forward, loss ๊ณ์ฐ, ์ญ์ ํ ๊ณ์ฐ, Gradient Descent ์ํ ๋ฑ ์ด๋ค. ์๋ ์ฝ๋๋ ๊ธฐ๋ณธ ํธ๋ ์ด๋๋ฅผ ๋ง๋ ์์์ด๋ค.
def update_model(engine, batch): # process function
inputs, targets = batch
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(update_model) # ignite.engine.engine.Engine()์ ํ์
@trainer.on(Events.ITERATION_COMPLETED(every=100)) # event
def log_training(engine):
batch_loss = engine.state.output
lr = optimizer.param_groups[0]['lr']
e = engine.state.epoch
n = engine.state.max_epochs
i = engine.state.iteration
print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")
trainer.run(data_loader, max_epochs=5) # Engine์ ๊ฐ๋จํ๊ฒ .run()์ผ๋ก ๋๋ฆด ์ ์๋ค.
> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01
์๋ ์ฝ๋๋ evaluator ์์ ์ฝ๋์ด๋ค.
from ignite.metrics import Accuracy
def predict_on_batch(engine, batch)
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
return y_pred, y
evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)
๐ฅณ EVENT
๊ธฐ๋ณธ์ ์ธ train process
Pytorch ignite์๋ Engine
์ ํจ์จ, ์ ์ฐ์ฑ์ ํฅ์์ํค๊ธฐ ์ํด EVENT ์์คํ
์ด ๋์
๋ฌ๋ค. ์๋ฅผ๋ค๋ฉด
- STARTED : ์์ง ์คํ์ด ์์๋ ๋ ๋ฐ์ํ๋ ์ด๋ฒคํธ
- EPOCH_STARTED : Epoch๊ฐ ์์๋ ๋ ๋ฐ์ํ๋ ์ด๋ฒคํธ
- GET_BATCH_STARTED : ๋ค์ ๋ฐฐ์น๋ฅผ ๊ฐ์ ธ์ค๊ธฐ ์ ์ ๋ฐ์ํ๋ ์ด๋ฒคํธ
๋ฑ, ๋ง์ ์ด๋ฒคํธ๊ฐ ์กด์ฌํ๋ค. ๊ทธ๋์ ์ฌ์ฉ์๋ ์ฌ์ฉ์๊ฐ ์ ์ํ ์ฝ๋๋ฅผ Event handler๋ก ์คํํ ์ ์๋ค. handler๋ lambda, function, class method ๋ฑ๊ณผ ๊ฐ์ ๋ชจ๋ ํจ์๊ฐ ๋ ์ ์๋ค. Pytorch ignite๋ ๋ง์ ์ด๋ฒคํธ๊ฐ ์กด์ฌํ๊ธฐ ๋๋ฌธ์ ํจ์๋ฅผ ๋ฑ๋ก๋ง ํด์ฃผ๋ฉด ์ฝ๊ฒ ์งํํ ์ ์๋ค.
Event Handler๋ add_event_handler
ํจ์๋ฅผ ์ฌ์ฉํด์ ์์ฑํ ์ ์๋ค. ์๋๋ ์์ ์ฝ๋์ด๋ค.
def run_validation(engine, validation_engine, valid_loader):
validation_engine.run(valid_loader, max_epoch=1)
train_engine.add_event_handler(
Events.EPOCH_COMPLETED,
run_validation,
validation_engine,
valid_loader,
์๋๋ decorator๋ฅผ ํ์ฉํ์ฌ Event call-back ํจ์๋ฅผ ์์ฑํ ์์์ด๋ค.
@train_engine.on(Events.EPOCH_COMPLETED)
def print_train_logs(engine):
avg_p_norm = engine.state.metrics['|param|']
avg_g_norm = engine.state.metrics['|g_param|']
avg_loss = engine.state.metrics['loss']
avg_accuracy = engine.state.metrics['accuracy']
print('Epoch {} - |param|={:.2e} |g_param|={:.2e} loss={}, accuracy={}'
engine.state.epoch,
avg_p_norm,
avg_g_norm,
avg_loss,
avg_accuracy,
))
๋ค์์๋ ๋น์ทํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ธ pytorch lighting์ ๋ํด ์์๋ณด๋๋ก ํ๊ฒ ๋ค.
Author And Source
์ด ๋ฌธ์ ์ ๊ดํ์ฌ(๐Pytorch ignite๊ฐ ๋ญ์ผ๐!!), ์ฐ๋ฆฌ๋ ์ด๊ณณ์์ ๋ ๋ง์ ์๋ฃ๋ฅผ ๋ฐ๊ฒฌํ๊ณ ๋งํฌ๋ฅผ ํด๋ฆญํ์ฌ ๋ณด์๋ค https://velog.io/@sanha9999/Pytorch-ignite๊ฐ-๋ญ์ผ์ ์ ๊ท์: ์์์ ์ ๋ณด๊ฐ ์์์ URL์ ํฌํจ๋์ด ์์ผ๋ฉฐ ์ ์๊ถ์ ์์์ ์์ ์ ๋๋ค.
์ฐ์ํ ๊ฐ๋ฐ์ ์ฝํ ์ธ ๋ฐ๊ฒฌ์ ์ ๋ (Collection and Share based on the CC Protocol.)
์ข์ ์นํ์ด์ง ์ฆ๊ฒจ์ฐพ๊ธฐ
๊ฐ๋ฐ์ ์ฐ์ ์ฌ์ดํธ ์์ง
๊ฐ๋ฐ์๊ฐ ์์์ผ ํ ํ์ ์ฌ์ดํธ 100์ ์ถ์ฒ ์ฐ๋ฆฌ๋ ๋น์ ์ ์ํด 100๊ฐ์ ์์ฃผ ์ฌ์ฉํ๋ ๊ฐ๋ฐ์ ํ์ต ์ฌ์ดํธ๋ฅผ ์ ๋ฆฌํ์ต๋๋ค