# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union
import numpy as np
from ..tensor import Tensor
from .elemwise import abs, maximum, minimum
from .math import topk as _topk
from .tensor import broadcast_to, transpose
__all__ = [
    "topk_accuracy",
]


def topk_accuracy(
    logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]:
    r"""Calculates the classification accuracy given predicted logits and ground-truth labels.

    A sample counts as correct for a given ``k`` when its target label is among
    the ``k`` highest-scoring classes in ``logits``.

    Args:
        logits: model predictions of shape `[batch_size, num_classes]`,
            representing the probability (likelihood) of each class.
        target: ground-truth labels, 1d tensor of int32.
        topk: specifies the topk values, could be an int or an iterable of ints
            (e.g. a tuple, list, or generator). Default: 1

    Returns:
        a single tensor when ``topk`` is an int or contains exactly one value,
        otherwise a list with one tensor per entry of ``topk`` (in the same
        order). Each tensor holds the classification accuracy between 0.0 and 1.0.

    Raises:
        ValueError: if ``topk`` is an empty iterable.
    """
    if isinstance(topk, int):
        topk = (topk,)
    else:
        # Materialize once: ``max`` below and the loop both iterate ``topk``,
        # so a one-shot iterator (e.g. a generator) would otherwise be exhausted
        # before the loop runs, and ``len`` would fail on it.
        topk = tuple(topk)
    if not topk:
        raise ValueError("topk must contain at least one value")

    # Rank once for the largest k; smaller k values reuse a prefix of it.
    # ``pred`` is compared against class labels below, so it is expected to hold
    # class indices ordered highest score first (assumes ``_topk`` returns
    # ``(values, indices)`` -- confirm against megengine.functional.topk).
    _, pred = _topk(logits, k=max(topk), descending=True)
    batch_size = target.shape[0]
    accs = []
    for k in topk:
        # Lift ``target`` to ``(batch_size, 1)`` and broadcast it to
        # ``(batch_size, k)`` so it lines up element-wise with the top-k
        # predictions. ``detach`` keeps the metric out of the autograd graph.
        correct = pred[:, :k].detach() == broadcast_to(
            transpose(target, (0, "x")), (batch_size, k)
        )
        # Top-k indices within a row are distinct, so at most one column per row
        # can match; the sum therefore counts correctly classified samples.
        accs.append(correct.astype(np.float32).sum() / batch_size)
    if len(topk) == 1:
        return accs[0]
    return accs