損失のカスタマイズ¶
MMDetection は、ユーザーにさまざまな損失関数を提供します。しかし、デフォルトの構成は、異なるデータセットやモデルには適用できない場合があるため、ユーザーは新しい状況に適応するために特定の損失を変更したい場合があります。
このチュートリアルでは、最初に損失の計算パイプラインについて詳しく説明し、次に各ステップの変更方法についていくつかの指示を与えます。変更は、微調整と重み付けに分類できます。
損失の計算パイプライン¶
入力予測とターゲット、および重みが与えられると、損失関数は入力テンソルを最終的な損失スカラーにマッピングします。マッピングは、次の5つのステップに分けることができます。
正と負のサンプルをサンプリングするためのサンプリング方法を設定します。
損失カーネル関数によって、要素ごとまたはサンプルごとの損失を取得します。
損失に重みテンソルを要素ごとに乗算して重み付けします。
損失テンソルをスカラーに縮小します。
損失にスカラーを乗算して重み付けします。
サンプリング方法の設定 (ステップ 1)¶
一部の損失関数では、正と負のサンプルの間の不均衡を回避するためにサンプリング戦略が必要です。
たとえば、RPN ヘッドで CrossEntropyLoss
を使用する場合、train_cfg
で RandomSampler
を設定する必要があります。
train_cfg=dict(
rpn=dict(
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False))
Focal Loss、GHMC、QualityFocalLoss など、正と負のサンプルのバランスメカニズムを備えた他の損失の場合、サンプラーは不要になります。
損失の微調整¶
損失の微調整はステップ 2、4、5 とより関係があり、ほとんどの変更は構成で指定できます。ここでは、例としてFocal Loss (FL)を使用します。次のコードスニペットは、FL の構築方法と構成をそれぞれ示しており、実際には一対一に対応しています。
@LOSSES.register_module()
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0):
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)
ハイパーパラメータの微調整 (ステップ 2)¶
gamma
と beta
は、Focal Loss の 2 つのハイパーパラメータです。たとえば、gamma
の値を 1.5 に、alpha
の値を 0.5 に変更したい場合は、構成で次のように指定できます。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.5,
alpha=0.5,
loss_weight=1.0)
縮小方法の微調整 (ステップ 3)¶
FL のデフォルトの縮小方法は mean
です。たとえば、縮小を mean
から sum
に変更したい場合は、構成で次のように指定できます。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
reduction='sum')
損失重みの微調整 (ステップ 5)¶
ここでの損失重みは、マルチタスク学習における異なる損失の重みを制御するスカラーです。たとえば、分類損失の損失重みを 0.5 に変更したい場合は、構成で次のように指定できます。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=0.5)
損失の重み付け (ステップ 3)¶
損失の重み付けとは、損失を要素ごとに重み付けし直すことを意味します。より具体的には、損失テンソルに同じ形状の重みテンソルを乗算します。その結果、損失の異なるエントリを異なるようにスケーリングでき、要素ごとと呼ばれます。損失の重みはモデルによって異なり、コンテキストに大きく関連していますが、全体として、分類損失の label_weights
と bbox 回帰損失の bbox_weights
の 2 種類の損失重みがあります。これらは、対応するヘッドの get_target
メソッドで確認できます。ここでは、ATSSHead を例にとります。これは AnchorHead を継承していますが、異なる label_weights
および bbox_weights
を生成する get_targets
メソッドを上書きします。
class ATSSHead(AnchorHead):
...
def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):