megengine.functional.metric.topk_accuracy

topk_accuracy(logits, target, topk=1)[源代码]

根据给定的预测的logits和真实值标签计算分类准确率。

参数
  • logits (Tensor) – 模型预测值,形为 [batch_size, num_classes] ,表示其属于各类别(class)的概率。

  • target (Tensor) – 真实值标签,int32类型的一维张量。

  • topk (Union[int, Iterable[int]]) – 指定前k个值,可以是整型数,也可以是整型数构成的元组。 默认: 1

返回类型

Union[Tensor, Iterable[Tensor]]

返回

表示分类准确率的张量(一个或多个),数值介于0.0到1.0之间。

例如:

import numpy as np
from megengine import tensor
import megengine.functional as F

logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
target = tensor(np.arange(8, dtype=np.int32))
top1, top5 = F.metric.topk_accuracy(logits, target, (1, 5))
print(top1.numpy(), top5.numpy())

输出:

0.0 0.375