6 minute read

사전 지식:

Trainer:

native pytorch 코드가 아니라 간단하게 training argument, 사용할 metric 함수, dataset만 던져 넣으면 알아서 학습을 돌려준다. 그러면 내부에서 epoch 마다 돌리고, step 마다 돌리고, loss 계산하고, gradient 계산하고, evaluation 계산하고, 다 해준다 ^^ Customization을 하려면 Trainer을 상속하던지, 제공하는 API에 맞게 custom function, argument을 만들어서 집어넣어주면 된다.

trainer = Trainer( 
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
trainer.tarin()

Callbacks:

기본적인 Trainer 기능에 추가 기능을 넣는 것이라고 생각하면 편함. e.g Ealry Stop Callback, TensorBoard Callback, Wandb Callback 등이 있음.

trainer = Trainer( 
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks = [ ... callbacks list ...],
        compute_metrics=compute_metrics,
    )
trainer.tarin()

TrainerState

Trainer 내부에서 학습의 상태를 저장해두고 있는 object이다. TrainerState class에서 생성 됨. Callback과는 다르게 겉으로 드러나지 않고 Trainer 내부에서만 동작한다. trainer = Traner(…)으로 trainer obejct을 생성할 때 같이 생성 됨.

TrainerControl

Trainer 내부에서 학습을 도는 중간 중간에 특정 조건에 따라 특정 작업을 해야 할때 Control object에 있는 property를 조건문으로 사용하여 제어한다. should_log? should_evaluete? should_stop? 등의 정보를 가지고 있음. trainer = Traner(…)으로 trainer obejct을 생성할 때 TrainerControl class로부터 생성 됨.

Trainer API을 쓸때 사용하는 Class들

Trainer Class

  • 실제 epoch 돌고 학습하는 모든 과정을 가진 클래스. 여기에서 밑의 다른 Class들을 다 import 해서 객체를 만든다. 저것들 활용해서 학습 과정을 제어함.

    TrainerControl Class

  • 위 설명 참고

    TrainerCallback Class

  • 위에서는 필요한 추가기능 Callback을 만들어서 집어넣어준다고 했었음. 그런데 Callback을 만들때도 Trainer가 받아들일 수 있도록 정해진 interface에 따라서 만들어야 함. 그 기준이 되는 interface class가 TrainerCallback Class이다. 뜯어보면 구현이 안되어 있음.

    TranierState Class

  • 위 설명 참고

    CallbackHandler Class

  • Trainer가 받아들이는 Callback들을 init 하고, Trainer가 필요한 때에 불러서 사용할 수 있도록 대 객체로 만들어서 저장하고 있는 object을 만들어내는 class. Trainer Class 내부 구조
  • 의사 코드 형태로 어떻게 돌아가는지 파악해보기
  • 주석만 읽어도 돼요!
  • 위에 있는 callback handler, train control, train state, callback의 의미를 상기하면서 읽으면 조음.
class Trainer():
    def __init__(self, model, train_arg, compute_metrics, callbacks, train_set, eval_set, ...)
        # init하면 argument으로 받은 값들을 calss property으로 저장함. 
         self.model = model
         # 추가 기능 callbacks들은 CallbackHandler에 저장해두고, CallbackHandler 만을 가지고 있는다. 특정 조건과 시점에 callback의 추가기능을 사용할 것임.
         self.callback_handler = CallbackHandler(callbacks, ... )
         self.args = train_arg

        # 현재 학습 상태를 기록해둔다. 
        self.state = TrainerState()
        # state을 살펴보면서 제어를 해야 할 타이밍인지 여부를 기록해둔다.
        self.control = TrainerControl()
    
    def get_train_dataloader(self):
        ...
        train_dataset = self.train_dataset
        ...

        return DataLoader(train_dataset, self.arg의 다양한 조건들, ... )
    
    def get_eval_dataloader(self):
        ...
        eval_dataset = self.eval_dataset
        ...

        return DataLoader(eval_dataset, self.arg의 다양한 조건들, ... )

    def get_test_dataloader(self):
        ...
        test_dataset = self.test_dataset
        ...

        return DataLoader(test_dataset, self.arg의 다양한 조건들, ... )

    def train(self, resume_from_checkpoint, ...):
        # 만약 checkpoint가 있으면 그걸로 model을 대체함.
        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

        ...

        if resume_from_checkpoint is not None:
            ...
            # init에서 self.model이 기존 checkpoint model으로 대체 됨.
            self._load_state_dict_in_model(state_dict) 
            
        train_dataloader = self.get_train_dataloader()

        # 학습에 필요한 변수들, grad 초기화
        self.state.epoch = 0
        tr_loss = torch.tensor(0.0).to(args.device)
        model.zero_grad()

        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloade

        # 학습 직전에 필요한 추가기능을 여기서 실행시킴
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # train arg에서 받은 arg으로 몇번 학습할지 설정
        num_train_epochs = math.ceil(args.num_train_epochs)
        epochs_trained = 0 # if resumed, change to the last epoch

        # epoch 반복
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_iterator = train_dataloader

            # 각 미니배치 반복.
            # 1 step 학습 = 1개의 미니배치 단위 학습(32, 64, ...)
            for step, inputs in enumerate(epoch_iterator):
              tr_loss_step = self.training_step(model, inputs)
              tr_loss += tr_loss_step
              model.zero_grad()

              # 현재 학습 시점 업데이트
              self.state.epoch = epoch + (step + 1) / steps_in_epoch

              # 1 스텝 끝났을 때 필요한 추가기능 실행
              self.control = self.callback_handler.on_step_end(args, self.state, self.control)

              # 현재 스텝이 evaluatoin과 log을 해야하는 상황이면 eval 하고 log 수행
              # 이 이슈 해결하려면 여기를 봐야 함
              self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            # 1 epoch 끝났을 때 필요한 추가기능 실행
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            # 현재 epoch이 evaluatoin과 log을 해야하는 상황이면 eval 하고 log 수행
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

        # 학습 종료 후 결과 저장해서 log 찍음.
        metrics = {}
        metrics["train_loss"] = train_loss
        self.log(metrics)

        # 학습 종료 후 해야 하는 추가 기능 실행
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        
    # 현재 상황이 evaluatoin과 log을 해야하는 상황이면 eval 하고 log 수행
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        # 지금 상황이 log 찍어야 하는 상황인가?
        # train_arguemnt에서 --log_step = 100 등에서 조건 만족되면 True
        if self.control.should_log:
            logs: Dict[str, float] = {}
        
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)

            # 실제 log 찍는 함수. console, wandb 등이랑 모두 연결
            self.log(logs)

        metrics = None
        # 지금 상황이 evaluation 해야 하는 상황인가?
        # train_arguemnt에서 --eval_step = 100 등에서 조건 만족되면 True
        if self.control.should_evaluate:
            # 이 이슈 해결하려면 여기를 봐야 함 
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
    
    # 실제 log 찍는 함수. console, wandb 등이랑 모두 연결
    def log(self, logs: Dict[str, float]) -> None:
        # log 찍어야 하는 추가 기능을 실행.
        # wandb을 켰다면 callback hanlder 내부에서 연결되어 있음.
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # eval_loop이 evaluation의 Trainer라고 보면 됨. eval_data 넣어서 loss 다 구해서 반환.
        # 이 이슈 해결하려면 여기를 봐야 함 
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # 실제 log 찍는 함수. console, wandb 등이랑 모두 연결
        self.log(output.metrics)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        ...
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))

       # 이슈 해결의 핵심. prefix에 eval을 달아준다.
        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

어떻게 log을 찍는가?

# train.py
    def evaluate(self, ... ):
        output = eval_loop(  compute_metrics, metric_key_prefix=metric_key_prefix, ... )
        self.log(output.metrics)

    # 실제 log 찍는 함수. console, wandb 등이랑 모두 연결
    def log(self, logs: Dict[str, float]) -> None:
        # log 찍어야 하는 추가 기능을 실행.
        # wandb을 켰다면 callback hanlder 내부에서 연결되어 있음.
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, log

추가 기능 callback의 내부 구조

를 알아야 함 ㅜㅜ

trainer을 init 할때 callback list을 받음. 받아서 callback_handler에 때려넣으면 callback_handler instance가 각 callback instance을 가지고 있음.

class Trainer():
    def __init__(self, model, train_arg, compute_metrics, callbacks, train_set, eval_set, ...)
         # 추가 기능 callbacks들은 CallbackHandler에 저장해두고, CallbackHandler 만을 가지고 있는다. 특정 조건과 시점에 callback의 추가기능을 사용할 것임.
         self.callback_handler = CallbackHandler(callbacks, ... )

여기서 wandb가 실행되고 있으면 wandb Callback도 같이 생성되어서 callback_handler에 저장되어 있다.

그리고 callback들이 실행되는 시점을 보면 on_ 으로 시작하는 prefix 함수들을 가지고 있다. 모든 Callback class는 TrainerCallback interface을 상속해서 구현하고 있음. 그리고 callback_handler 역시 이 함수들을 가지고 있다. 따라서 callback_handler가 on_ 함수를 실행하면 해당 on_을 가지고 있는 callback들이 모두 실행 됨.

e.g. self.callback_handler.on_train_begin(...) 을 실행하면 callback_handler에 있는 모든 callback instance들이 on_train_begin을 실행 함.

class TrainerCallback:
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of training.
        """
        pass

class CallbackHandler(TrainerCallback):

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_training_stop = False
        return self.call_event("on_train_begin", args, state, control)

    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
            result = getattr(callback, event)(
                args,
                state,
                control,
                model=self.model,
                tokenizer=self.tokenizer,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                **kwargs,
            )
            # A Callback can skip the return of `control` if it doesn't change it.
            if result is not None:
                control = result
        return control


class TensorBoardCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):

class WandbCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, model=None, **kwargs):

on_log

on_log도 마찬가지이다.

# train.py
    def evaluate(self, ... ):
        output = eval_loop(  compute_metrics, metric_key_prefix=metric_key_prefix, ... )
        self.log(output.metrics)

    # 실제 log 찍는 함수. console, wandb 등이랑 모두 연결
    def log(self, logs: Dict[str, float]) -> None:
        # log 찍어야 하는 추가 기능을 실행.
        # wandb을 켰다면 callback hanlder 내부에서 연결되어 있음.
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, log

위 함수에서 log을 실행하고 wandb가 켜져 있으면 wandb callback이 실행 됨.


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d


class WandbCallback(TrainerCallback):
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})

입력으로 들어오는 log dictionary에 prefix으로 eval이 찍혀 있으면 eval으로 구분해서 eval section에 로그 찍음.

따라서

해결 방법

wandb log에 들어가는 self.log(metrics)을 실행할 때 metrics dictionary에 eval prefix을 추가하면 된다.

Comments