megengine.functional.nn.ctc_loss

ctc_loss(pred, pred_lengths, label, label_lengths, blank=0, reduction='mean')[源代码]

计算 Connectionist Temporal Classification loss 。

参数
  • pred (Tensor) – 概率张量,其尺寸为 (T, N, C),其中 T 是 input 长度,N 是 batch 个数,C 是类别数量(包括 blank)。

  • pred_lengths (Tensor) – pred 中每个序列的点数,尺寸为 (N, )。

  • label (Tensor) – groundtruth 标签,包含每个序列的每个点的 groundtruth 的位置,blank 不应包含在其中。尺寸是 (N, S) 或者 sum(label_lengths))。

  • label_lengths (Tensor) – groundtruth 的每个序列的点数,尺寸是 (N, )。

  • blank (int) – blank 的个数,默认值为 0。

  • reduction (str) – 计算输出的模式:none | mean | sum。默认值为:mean

返回类型

Tensor

返回

损失值。

实际案例

>>> pred = Tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
>>> pred_lengths = Tensor([2, 2])
>>> label = Tensor([1, 1])
>>> label_lengths = Tensor([1, 1])
>>> F.nn.ctc_loss(pred, pred_lengths, label, label_lengths)
Tensor(0.1504417, device=xpux:0)