ショートカット

便利なフック

MMDetection と MMEngine は、ログフック、NumClassCheckHook など、さまざまな便利なフックを提供しています。このチュートリアルでは、MMDetection に実装されているフックの機能と使用方法を紹介します。MMEngine でのフックの使用方法については、MMEngine の API ドキュメントを参照してください。

CheckInvalidLossHook

NumClassCheckHook

MemoryProfilerHook

メモリプロファイラフック は、仮想メモリ、スワップメモリ、現在のプロセスのメモリなどのメモリ情報を記録します。このフックは、システムのメモリ使用量を把握し、潜在的なメモリリークバグを発見するのに役立ちます。このフックを使用するには、ユーザーは最初に pip install memory_profiler psutil によって memory_profilerpsutil をインストールする必要があります。

使用方法

このフックを使用するには、設定ファイルに次のコードを追加する必要があります。

custom_hooks = [
    dict(type='MemoryProfilerHook', interval=50)
]

結果

トレーニング中に、以下のように MemoryProfilerHook によって記録されたログメッセージが表示されます。

The system has 250 GB (246360 MB + 9407 MB) of memory and 8 GB (5740 MB + 2452 MB) of swap memory in total. Currently 9407 MB (4.4%) of memory and 5740 MB (29.9%) of swap memory were consumed. And the current training process consumed 5434 MB of memory.
2022-04-21 08:49:56,881 - mmengine - INFO - Memory information available_memory: 246360 MB, used_memory: 9407 MB, memory_utilization: 4.4 %, available_swap_memory: 5740 MB, used_swap_memory: 2452 MB, swap_memory_utilization: 29.9 %, current_process_memory: 5434 MB

SetEpochInfoHook

SyncNormHook

SyncRandomSizeHook

YOLOXLrUpdaterHook

YOLOXModeSwitchHook

カスタムフックを実装する方法

一般的に、モデルのトレーニングの開始から終了まで、フックを挿入できるポイントは20箇所あります。ユーザーはカスタムフックを実装し、トレーニングプロセス中のさまざまなポイントに挿入して、必要な処理を実行できます。

  • グローバルポイント: before_runafter_run

  • トレーニングのポイント: before_trainbefore_train_epochbefore_train_iterafter_train_iterafter_train_epochafter_train

  • 検証のポイント: before_valbefore_val_epochbefore_val_iterafter_val_iterafter_val_epochafter_val

  • テストのポイント: before_testbefore_test_epochbefore_test_iterafter_test_iterafter_test_epochafter_test

  • その他のポイント: before_save_checkpointafter_save_checkpoint

たとえば、損失をチェックし、損失が NaN になったときにトレーニングを終了するフックを実装できます。そのためには、3つの手順があります。

  1. MMEngine の Hook クラスを継承する新しいフックを実装し、n 回のトレーニング反復ごとに損失が NaN になるかどうかをチェックする after_train_iter メソッドを実装します。

  2. 実装されたフックは、以下のコードに示すように、@HOOKS.register_module() によって HOOKS に登録する必要があります。

  3. 設定ファイルに custom_hooks = [dict(type='MemoryProfilerHook', interval=50)] を追加します。

from typing import Optional

import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. Defaults to None.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                runner.logger.info('loss become infinite or NaN!')

カスタムフックの実装の詳細については、customize_runtime を参照してください。