便利なフック¶
MMDetection と MMEngine は、ログフック、NumClassCheckHook
など、さまざまな便利なフックを提供しています。このチュートリアルでは、MMDetection に実装されているフックの機能と使用方法を紹介します。MMEngine でのフックの使用方法については、MMEngine の API ドキュメントを参照してください。
CheckInvalidLossHook¶
NumClassCheckHook¶
MemoryProfilerHook¶
メモリプロファイラフック は、仮想メモリ、スワップメモリ、現在のプロセスのメモリなどのメモリ情報を記録します。このフックは、システムのメモリ使用量を把握し、潜在的なメモリリークバグを発見するのに役立ちます。このフックを使用するには、ユーザーは最初に pip install memory_profiler psutil
によって memory_profiler
と psutil
をインストールする必要があります。
使用方法¶
このフックを使用するには、設定ファイルに次のコードを追加する必要があります。
custom_hooks = [
dict(type='MemoryProfilerHook', interval=50)
]
結果¶
トレーニング中に、以下のように MemoryProfilerHook
によって記録されたログメッセージが表示されます。
The system has 250 GB (246360 MB + 9407 MB) of memory and 8 GB (5740 MB + 2452 MB) of swap memory in total. Currently 9407 MB (4.4%) of memory and 5740 MB (29.9%) of swap memory were consumed. And the current training process consumed 5434 MB of memory.
2022-04-21 08:49:56,881 - mmengine - INFO - Memory information available_memory: 246360 MB, used_memory: 9407 MB, memory_utilization: 4.4 %, available_swap_memory: 5740 MB, used_swap_memory: 2452 MB, swap_memory_utilization: 29.9 %, current_process_memory: 5434 MB
SetEpochInfoHook¶
SyncNormHook¶
SyncRandomSizeHook¶
YOLOXLrUpdaterHook¶
YOLOXModeSwitchHook¶
カスタムフックを実装する方法¶
一般的に、モデルのトレーニングの開始から終了まで、フックを挿入できるポイントは20箇所あります。ユーザーはカスタムフックを実装し、トレーニングプロセス中のさまざまなポイントに挿入して、必要な処理を実行できます。
グローバルポイント:
before_run
、after_run
トレーニングのポイント:
before_train
、before_train_epoch
、before_train_iter
、after_train_iter
、after_train_epoch
、after_train
検証のポイント:
before_val
、before_val_epoch
、before_val_iter
、after_val_iter
、after_val_epoch
、after_val
テストのポイント:
before_test
、before_test_epoch
、before_test_iter
、after_test_iter
、after_test_epoch
、after_test
その他のポイント:
before_save_checkpoint
、after_save_checkpoint
たとえば、損失をチェックし、損失が NaN になったときにトレーニングを終了するフックを実装できます。そのためには、3つの手順があります。
MMEngine の
Hook
クラスを継承する新しいフックを実装し、n
回のトレーニング反復ごとに損失が NaN になるかどうかをチェックするafter_train_iter
メソッドを実装します。実装されたフックは、以下のコードに示すように、
@HOOKS.register_module()
によってHOOKS
に登録する必要があります。設定ファイルに
custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]
を追加します。
from typing import Optional
import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
"""Check invalid loss hook.
This hook will regularly check whether the loss is valid
during training.
Args:
interval (int): Checking interval (every k iterations).
Default: 50.
"""
def __init__(self, interval: int = 50) -> None:
self.interval = interval
def after_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[dict] = None) -> None:
"""Regularly check whether the loss is valid every n iterations.
Args:
runner (:obj:`Runner`): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict, Optional): Data from dataloader.
Defaults to None.
outputs (dict, Optional): Outputs from model. Defaults to None.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']), \
runner.logger.info('loss become infinite or NaN!')
カスタムフックの実装の詳細については、customize_runtime を参照してください。