[210118 TIL] Class 상속

Today I Learned

written by 602



Class 상속 ( super().__init__)

aka. tensorflow에서 custom callback 만들기

class MyCallback(tf.keras.callbacks.Callback):
    
    def __init__(self, name, cfg, X, y, decoder_input_array,
                 output_file, output_gloss, output_skels):
        super().__init__()
        
        self.cfg = cfg
        self.X = X
        self.y = y
        self.decoder_input_array = decoder_input_array
        self.output_file = output_file
        self.output_gloss = output_gloss
        self.output_skels = output_skels
        
    def on_epoch_end(self, epoch, logs=None):
        #10
        if epoch > 0 and epoch % 5 == 0:
            lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
            print('learning rate : ', lr)
        #100
        if epoch > 0 and epoch % 10 == 0:
            model_path = self.cfg["model_path"]
            self.model.save(model_path+"model.h5")
            result_path = self.cfg["result_path"]
            make_predict(self.cfg, self.model, self.X, self.y, self.decoder_input_array,
                         self.output_file, self.output_gloss, self.output_skels,
                         result_path, epoch, best=False)

수어 생성 프로젝트 준비하면서 가장 애를 먹었던 부분 중 하나인 커스텀 콜백 지정 파트이다.

커스텀 콜백은 대부분 class를 새로 만들어주어야 하는데, 이때 tf.keras.callback.Callback을 부모 클래스로 상속받아와서 추가로 필요한 인자들을 self로 받아온다. 상속받을때는 super().__init__(self, 인자들)로 부모 클래스의 인자들과 추가로 불러올 인자들을 지정해주면 된다.

5 epoch마다 learning rate를 출력해보고, 10 epoch 마다 학습이 끊기는 것을 방지하기 위하여 모델을 저장하고 미리 만들어둔 make_predict 함수를 활용하여 결과를 시각화할 수 있는 나만의 콜백이다.

좋은 웹페이지 즐겨찾기