๐Ÿ™ˆPytorch lightning์ด ๋ญ์•ผ๐Ÿ™‰!!

3438 ๋‹จ์–ด pytorch lightningPyTorchPyTorch

๐Ÿค” Pytorch lightning์ด๋ž€ ?

Pytorch lightning๋„ ์ €๋ฒˆ์— ๋ฆฌ๋ทฐํ•œ ignite์™€ ๋น„์Šทํ•œ, ๊ทธ๋ž˜์„œ ๋น„๊ต๋˜๋Š” ์˜คํ”ˆ์†Œ์Šค ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ด๋‹ค. ignite๋Š” pytorch์˜ ๊ณต์‹์ ์ธ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ผ๊ณ ๋Š” ํ•˜์ง€๋งŒ, lightning์ด ํ•œ๊ตญ์–ด๋กœ ๋œ ์ž๋ฃŒ๊ฐ€ ๋”์šฑ ๋งŽ์€ ๊ฒƒ ๊ฐ™๋‹ค. ignite์˜ ํ•ต์‹ฌ์ด Engine์ด์—ˆ๋˜ ๊ฒƒ์ฒ˜๋Ÿผ lightning์˜ ํ•ต์‹ฌ์€ Trainer์™€ lightningmodule์ด๋‹ค.

โšกLightning module

lightning module์€ ๋Œ€๋ถ€๋ถ„์˜ ํ•™์Šต ์•Œ๊ณ ๋ฆฌ์ฆ˜์ด ์ •์˜๋˜๋Š” ํด๋ž˜์Šค์ด๋‹ค. lightning module์„ ๊ตฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” LightningModule ํด๋ž˜์Šค๋ฅผ ์ƒ์†๋ฐ›์•„์•ผํ•œ๋‹ค.

class MyClass(pl.LightningModule):
	def __init__(self):
        super().__init__()
        self.model = ๋ชจ๋ธ
    
    def forward(self, x):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass
    
    def configure_optimizers(self):
        pass	

LightningModule์ž์ฒด๊ฐ€ pytorch์˜ nn.module์„ ์ƒ์†๋ฐ›์€ ํด๋ž˜์Šค์ด๊ธฐ ๋•Œ๋ฌธ์— nn.module์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒƒ๋“ค์„ ๋‹ค ์“ธ ์ˆ˜ ์žˆ๋‹ค.

๐Ÿ‘ค Trainer

Trainer๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ optimizer step, backward, logging, ๋ถ„์‚ฐํ•™์Šต๋“ฑ์„ ๋‹ค๋ฃจ๋Š” ๋ถ€๋ถ„์ด๋‹ค. ๊ทธ๋ž˜์„œ ์œ ์ €๊ฐ€ ์ง์ ‘ ์ˆ˜์ •ํ•ด์„œ ์‚ฌ์šฉํ•˜์ง€๋Š” ์•Š๊ณ  ๊ตฌํ˜„๋œ ๋ถ€๋ถ„์„ Trainer์—์„œ ๊ฐ€์ ธ์™€ ์“ด๋‹ค๊ณ  ํ•œ๋‹ค.

์ด ์ •๋ฆฌ

์œ ์ €๋“ค์˜ ์ฝ”๋“œ์Šคํƒ€์ผ์ด ๋น„์Šทํ•ด์ง„๋‹ค๋Š” ์žฅ์ ์ด ๊ต‰์žฅํžˆ ํฐ ๊ฒƒ ๊ฐ™๋‹ค. ์•„๋ฌด๋ž˜๋„ ๊นƒํ—™์—์„œ ๋‹ค๋ฅธ ๊ฐœ๋ฐœ์ž์˜ ์ฝ”๋“œ๋ฅผ ๋ดค์„๋•Œ "์ด ๋ถ€๋ถ„์€ ์–ด๋””์žˆ๋Š”๊ฑฐ์ง€?" "์ด๊ฒŒ ๋ญ์ง€?"์ด๋Ÿฐ ์งˆ๋ฌธ์„ ๋˜์กŒ๋˜ ๋‚˜์—๊ฒŒ๋Š” ignite๋‚˜ lightning๊ฐ™์€ ์˜คํ”ˆ์†Œ์Šค๊ฐ€ ํ™œ์„ฑํ™” ๋ฌ์œผ๋ฉด ์ข‹๊ฒ ๋‹ค๋Š” ๋ฐ”๋žจ์ด ์žˆ๋‹ค. ๋‚ด๊ฐ€ ์ง„ํ–‰ํ•˜๋Š” ํ”„๋กœ์ ํŠธ์—๋„ lightning์„ ์ ์šฉํ•ด ๋ณด์•„์•ผ๊ฒ ๋‹ค.

์ข‹์€ ์›นํŽ˜์ด์ง€ ์ฆ๊ฒจ์ฐพ๊ธฐ