
I don't understand the implementation of EfficientNMSFilter in the EfficientNMS plugin #3786

Open
demuxin opened this issue Apr 9, 2024 · 3 comments
Labels
triaged Issue has been triaged by maintainers


@demuxin

demuxin commented Apr 9, 2024

Description

I need to write a CPU implementation of NMS, and I'd like to use the EfficientNMS code as a reference.

As I understand it, EfficientNMSFilter in the EfficientNMS plugin takes the category with the highest confidence for each anchor and then filters the results by the score threshold.
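For comparison, the behavior described here (one best class per anchor, then a score threshold) can be sketched on the CPU as below. This is a minimal sketch, assuming float scores for a single image laid out row-major as [numAnchors, numClasses] and a background class that is never chosen as the winner; `Candidate` and `filterArgmaxPerAnchor` are hypothetical names, not TensorRT APIs.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical record for one kept anchor (not a TensorRT type).
struct Candidate {
    int anchorIdx;
    int classIdx;
    float score;
};

// Per-anchor argmax followed by a threshold, for a single image.
// Assumption: scores is row-major [numAnchors, numClasses].
std::vector<Candidate> filterArgmaxPerAnchor(const std::vector<float>& scores,
                                             int numAnchors, int numClasses,
                                             float scoreThreshold, int backgroundClass)
{
    std::vector<Candidate> kept;
    for (int a = 0; a < numAnchors; ++a) {
        int bestClass = -1;
        float bestScore = 0.f;
        // Find the highest-scoring non-background class for this anchor.
        for (int c = 0; c < numClasses; ++c) {
            if (c == backgroundClass) continue;
            float s = scores[static_cast<std::size_t>(a) * numClasses + c];
            if (bestClass < 0 || s > bestScore) {
                bestClass = c;
                bestScore = s;
            }
        }
        // Keep the anchor only if its best class passes the threshold.
        if (bestClass >= 0 && bestScore >= scoreThreshold) {
            kept.push_back({a, bestClass, bestScore});
        }
    }
    return kept;
}
```

With this layout, each anchor contributes at most one candidate, carrying its argmax class id.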

This is the code of EfficientNMSFilter:

```cpp
template <typename T>
__global__ void EfficientNMSFilter(EfficientNMSParameters param, const T* __restrict__ scoresInput,
    int* __restrict__ topNumData, int* __restrict__ topIndexData, int* __restrict__ topAnchorsData,
    T* __restrict__ topScoresData, int* __restrict__ topClassData)
{
    int elementIdx = blockDim.x * blockIdx.x + threadIdx.x;
    int imageIdx = blockDim.y * blockIdx.y + threadIdx.y;

    // Boundary Conditions
    if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize)
    {
        return;
    }

    // Shape of scoresInput: [batchSize, numAnchors, numClasses]
    int scoresInputIdx = imageIdx * param.numScoreElements + elementIdx;

    // For each class, check its corresponding score if it crosses the threshold, and if so select this anchor,
    // and keep track of the maximum score and the corresponding (argmax) class id
    T score = scoresInput[scoresInputIdx];
    if (gte_mp(score, (T) param.scoreThreshold))
    {
        // Unpack the class and anchor index from the element index
        int classIdx = elementIdx % param.numClasses;
        int anchorIdx = elementIdx / param.numClasses;

        // If this is a background class, ignore it.
        if (classIdx == param.backgroundClass)
        {
            return;
        }

        // Use an atomic to find an open slot where to write the selected anchor data.
        if (topNumData[imageIdx] >= param.numScoreElements)
        {
            return;
        }
        int selectedIdx = atomicAdd((unsigned int*) &topNumData[imageIdx], 1);
        if (selectedIdx >= param.numScoreElements)
        {
            topNumData[imageIdx] = param.numScoreElements;
            return;
        }

        // Shape of topScoresData / topClassData: [batchSize, numScoreElements]
        int topIdx = imageIdx * param.numScoreElements + selectedIdx;

        if (param.scoreBits > 0)
        {
            score = add_mp(score, (T) 1);
            if (gt_mp(score, (T) (2.f - 1.f / 1024.f)))
            {
                // Ensure the incremented score fits in the mantissa without changing the exponent
                score = (2.f - 1.f / 1024.f);
            }
        }

        topIndexData[topIdx] = selectedIdx;
        topAnchorsData[topIdx] = anchorIdx;
        topScoresData[topIdx] = score;
        topClassData[topIdx] = classIdx;
    }
}
```
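Since the goal is a CPU port, here is a minimal sequential C++ sketch that mirrors the kernel above step by step for a single image, assuming float scores and the [numAnchors, numClasses] layout from its comment. `FilterOutput` and `cpuEfficientNMSFilter` are hypothetical names, not part of TensorRT. Read literally, the kernel as quoted keeps every (anchor, class) pair whose score passes the threshold and is not the background class; nothing in it computes a per-anchor maximum, so the same anchor can appear more than once in the output.

```cpp
#include <vector>

// Hypothetical CPU stand-ins for the kernel's output buffers (one image).
struct FilterOutput {
    std::vector<int> topIndex;     // mirrors topIndexData
    std::vector<int> topAnchors;   // mirrors topAnchorsData
    std::vector<float> topScores;  // mirrors topScoresData
    std::vector<int> topClass;     // mirrors topClassData
    int topNum = 0;                // mirrors topNumData[imageIdx]
};

// Sequential version of the filter for a single image.
// Assumption: scores is row-major [numAnchors, numClasses], i.e. numScoreElements long.
FilterOutput cpuEfficientNMSFilter(const std::vector<float>& scores, int numClasses,
                                   float scoreThreshold, int backgroundClass, int scoreBits)
{
    FilterOutput out;
    const int numScoreElements = static_cast<int>(scores.size());
    for (int elementIdx = 0; elementIdx < numScoreElements; ++elementIdx) {
        float score = scores[elementIdx];
        // Same test as gte_mp(score, threshold): skip unless score >= threshold.
        if (!(score >= scoreThreshold)) continue;

        // Unpack the class and anchor index from the flat element index.
        int classIdx = elementIdx % numClasses;
        int anchorIdx = elementIdx / numClasses;
        if (classIdx == backgroundClass) continue;  // background class is ignored

        // The GPU claims a slot with atomicAdd; sequentially a counter is enough,
        // and the capacity check cannot trigger because each element is visited once.
        int selectedIdx = out.topNum++;

        if (scoreBits > 0) {
            // Shift the score into [1, 2) and clamp just below 2, as the kernel does.
            score += 1.f;
            if (score > 2.f - 1.f / 1024.f) score = 2.f - 1.f / 1024.f;
        }

        out.topIndex.push_back(selectedIdx);
        out.topAnchors.push_back(anchorIdx);
        out.topScores.push_back(score);
        out.topClass.push_back(classIdx);
    }
    return out;
}
```

On the GPU, the slot order depends on which thread wins each `atomicAdd`, so it is not deterministic; the sequential loop simply emits entries in element order. The `scoreBits` branch moves every kept score into [1, 2) so they all share one exponent, presumably so that a later sort only has to compare mantissa bits; that reading is an assumption based on the kernel's own comment.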

Can you explain how EfficientNMSFilter selects the category with the highest confidence for each anchor?

@zerollzeng
Collaborator

@samurdhikaru ^ ^

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Apr 12, 2024
@zerollzeng
Collaborator

I think if you read the code and the README carefully, you will find the answer :-)

@demuxin
Author

demuxin commented Apr 16, 2024

@zerollzeng Hi, could you give me a little hint? I've read the source code and still can't understand it; I'm new to this.
