diff --git a/CHANGELOG.md b/CHANGELOG.md index d7ce23a39e..138161065b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,31 @@ # NVIDIA CUTLASS Changelog -# CUTLASS 2.0 +# CUTLASS 2.x + +## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) + * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) + * Fast Tensor Core operations: + * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) + * Tensor Float 32, BFloat16, and double-precision data types + * Mixed integer data types (int8, int4, bin1) + * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) + * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required) + * Features: + * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM + * Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32 + * Gaussian complex GEMMs using 3m complex multiply algorithm + * Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions + * Policy updates: + * [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features + * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON` + +## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06) + * BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) + * API to launch compiled kernel instances for GEMM and planar complex GEMM + * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores + * Computes complex matrix products on matrices stored as disjoint real and imaginary parts + * [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu) + * Minor enhancements and bug fixes ## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19) * Substantially refactored for @@ -22,7 +47,7 @@ * Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions * Batched GEMM operations * Complex-valued GEMMs - * Note: a host compiler supporting C++11 or greater is required. + * **Note: a host compiler supporting C++11 or greater is required.** # CUTLASS 1.x @@ -76,7 +101,7 @@ ## Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100644 new mode 100755 index 85d395939f..b6583747c6 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -32,15 +32,14 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.0.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.2.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) find_package(Doxygen QUIET) # -# CUTLASS 2.0 requires C++11 +# CUTLASS 2.x requires C++11 # - set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -49,7 +48,7 @@ if(CUTLASS_NATIVE_CUDA) set(CMAKE_CUDA_STANDARD 11) set(CMAKE_CUDA_STANDARD_REQUIRED ON) else() - string(APPEND NVCC_FLAGS " --std=c++11") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11) endif() if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) @@ -58,13 +57,26 @@ endif() message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}") -if(${CMAKE_PROJECT_NAME} MATCHES ${PROJECT_NAME}) - set(_CUTLASS_ENABLE_TESTS ON) +set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") + +if(CUTLASS_ENABLE_HEADERS_ONLY) + set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) + set(CUTLASS_ENABLE_TOOLS_INIT OFF) else() - set(_CUTLASS_ENABLE_TESTS OFF) + set(CUTLASS_ENABLE_EXAMPLES_INIT ON) + set(CUTLASS_ENABLE_TOOLS_INIT ON) endif() -set(CUTLASS_ENABLE_TESTS ${_CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS Tests") +set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples") +set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") + +if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) + set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT}) +else() + set(CUTLASS_ENABLE_TESTS_INIT OFF) +endif() + +set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") if (CUTLASS_ENABLE_TESTS) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) @@ -72,7 +84,7 @@ endif() set(CUTLASS_NVCC_ARCHS_SUPPORTED "") if (NOT CUDA_VERSION VERSION_LESS 7.5) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 50) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53) endif() if (NOT CUDA_VERSION VERSION_LESS 8.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61) @@ -86,31 +98,25 @@ endif() if (NOT CUDA_VERSION VERSION_LESS 10.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75) endif() - -if(CUDA_COMPILER MATCHES "[Cc]lang") - if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) - message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) - endif() - if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) - message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") - endif() +if (NOT CUDA_VERSION VERSION_LESS 11.0) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80) endif() - set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") # Special policy introduced in CMake 3.13 if (POLICY CMP0076) cmake_policy(SET CMP0076 NEW) -endif() +endif() -# check if the configuration is supported -if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) +if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 ) message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!") endif() include(GNUInstallDirs) +link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + ################################################################################################### # # Configure CMake variables @@ -120,11 +126,14 @@ include(GNUInstallDirs) message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}") if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES)) - # By default we want to build in Release mode to ensure that we're getting best performance. + # By default we want to build in Release mode to ensure that we're getting best performance. set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release") endif() +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries") + if(WIN32) # On Windows we link against the shared (DLL) runtime. Change gtest settings to match this. set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE) @@ -132,29 +141,35 @@ endif() if (WIN32) # Enable more warnings and treat as errors - string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX) # Disable warning on Unicode characters - string(APPEND NVCC_FLAGS " -Xcompiler /wd4819") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819) # Disable excess x86 floating point precision that can lead to results being labeled incorrectly - string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict) endif(WIN32) if (${CUTLASS_NVCC_VERBOSE}) - string(APPEND NVCC_FLAGS " -v") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v) endif() set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.") set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") -set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.") -set(CUTLASS_LIBRARY_KERNELS "128x128" CACHE STRING "Comma delimited list of kernel name filters. Default '' means all kernels are enabled.") +set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") + +# +# CUTLASS generator cmake configuration +# +set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.") +set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.") + # Test Levels L0, L1, L2 set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2) -string(APPEND NVCC_FLAGS " -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}") +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) # # CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. @@ -166,7 +181,7 @@ else() set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON) endif() -set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL +set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL "Enable PTX mma instruction for collective matrix multiply operations.") # @@ -182,7 +197,7 @@ set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CAC # ... # if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer - string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer) string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address") endif() @@ -192,85 +207,127 @@ endif() # ################################################################################################### -foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) - if(CUTLASS_NVCC_EMBED_CUBIN) - string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=sm_${ARCH}") - endif() - if(CUTLASS_NVCC_EMBED_PTX) - string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=compute_${ARCH}") - endif() - string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}") -endforeach() - if(CUTLASS_NVCC_EMBED_PTX) - string(APPEND CLANG_FLAGS " --cuda-include-ptx=all") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all) endif() if (CUTLASS_ENABLE_TENSOR_CORE_MMA) - string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1") + list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR}) - string(APPEND NVCC_FLAGS " --keep") # --keep-dir may not work with nvcc for some directories. - string(APPEND CLANG_FLAGS " -save-temps=${CUTLASS_NVCC_KEEP_DIR}") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories. + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR}) endif() -if (CUTLASS_ENABLE_F16C) - string(APPEND COMPILER_FLAGS " -DCUTLASS_ENABLE_F16C=1") +if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) + list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1) if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) - string(APPEND NVCC_FLAGS " -Xcompiler -mf16c") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c) elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC")) - string(APPEND NVCC_FLAGS " -Xcompiler /arch:AVX2") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2) endif() endif() -string(APPEND NVCC_FLAGS " -lineinfo") - -string(APPEND CLANG_FLAGS " -gmlt") +list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-Wconversion>) +list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-fno-strict-aliasing>) -if (UNIX) - string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion") - string(APPEND NVCC_FLAGS " -Xcompiler -fno-strict-aliasing") +# Don't leak lineinfo in release builds +if (NOT CMAKE_BUILD_TYPE MATCHES "Release") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo) endif() if(CUDA_COMPILER MATCHES "[Cc]lang") - string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}") - string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000") - string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000") - string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument") + if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) + message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) + endif() + + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) + message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") + endif() + + list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION}) list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR) list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR) - string(APPEND CLANG_FLAGS " -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}) + # needed for libcublasLt.so in case it's installed in the same location as libcudart.so # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags) # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. - string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags) link_libraries(nvidia::cudart) endif() -if(CUDA_COMPILER MATCHES "[Cc]lang") - string(APPEND CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}") - string(APPEND CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${CLANG_FLAGS_RELEASE}") - string(APPEND CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${CLANG_FLAGS_RELWITHDEBINFO}") - string(APPEND CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${CLANG_FLAGS_DEBUG}") -elseif (CUTLASS_NATIVE_CUDA) - string(APPEND CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}") - string(APPEND CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}") - string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}") - string(APPEND CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}") -else() - string(APPEND CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}") - string(APPEND CUDA_NVCC_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}") - string(APPEND CUDA_NVCC_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}") - string(APPEND CUDA_NVCC_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}") -endif() +function(cutlass_apply_cuda_gencode_flags TARGET) + + set(NVCC_FLAGS) + set(CLANG_FLAGS) + foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) + list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) + set(CODES) + if(CUTLASS_NVCC_EMBED_CUBIN) + list(APPEND CODES sm_${ARCH}) + endif() + if(CUTLASS_NVCC_EMBED_PTX) + list(APPEND CODES compute_${ARCH}) + endif() + list(JOIN CODES "," CODES_STR) + list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) + endforeach() + + if (CUDA_COMPILER MATCHES "[Cc]lang") + target_compile_options( + ${TARGET} + PRIVATE + $<$:${CLANG_FLAGS}> + ) + else() + target_compile_options( + ${TARGET} + PRIVATE + $<$:${NVCC_FLAGS}> + ) + endif() + +endfunction() + +function(cutlass_apply_standard_compile_options TARGET) + + if(CUDA_COMPILER MATCHES "[Cc]lang") + set(CUDA_COMPILE_LANGUAGE CXX) + set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS}) + set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE}) + set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO}) + set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG}) + else() + set(CUDA_COMPILE_LANGUAGE CUDA) + set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS}) + set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE}) + set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO}) + set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG}) + endif() + + target_compile_options( + ${TARGET} + PRIVATE + $<$:${_FLAGS}> + $<$:$<$:${_FLAGS_RELEASE}>> + $<$:$<$:${_FLAGS_RELWITHDEBINFO}>> + $<$:$<$:${_FLAGS_DEBUG}>> + ) + +endfunction() # # The following items should eventually be pushed into cutlass/CMakeLists.txt @@ -324,8 +381,8 @@ if (NOT DEFINED CUTLASS_REVISION) endif() configure_file( - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in - ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in + ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h @ONLY) target_include_directories( @@ -338,8 +395,8 @@ target_include_directories( ) install( - DIRECTORY - ${CUTLASS_INCLUDE_DIR}/ + DIRECTORY + ${CUTLASS_INCLUDE_DIR}/ ${CMAKE_CURRENT_BINARY_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) @@ -399,27 +456,6 @@ endif() ################################################################################ -set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") - -if(CUTLASS_ENABLE_HEADERS_ONLY) - set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) - set(CUTLASS_ENABLE_TOOLS_INIT OFF) -else() - set(CUTLASS_ENABLE_EXAMPLES_INIT ON) - set(CUTLASS_ENABLE_TOOLS_INIT ON) -endif() - -set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples") -set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") - -if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) - set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT}) -else() - set(CUTLASS_ENABLE_TESTS_INIT OFF) -endif() - -set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") - if(CUTLASS_ENABLE_TOOLS) add_subdirectory(tools) endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index fc95674d2a..f8778b80e6 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -9,15 +9,17 @@ This is the official list of CUTLASS developers and contributors. ## DEVELOPERS Andrew Kerr Haicheng Wu -Naila Farooqui +Manish Gupta Dustyn Blasig Pradeep Ramani -Manish Gupta -Aditya Atluri +Naila Farooqui +Piotr Majcher Paul Springer -David Tanner -Scott Yokim Jin Wang +Scott Yokim +Markus Hohnerbach +Aditya Atluri +David Tanner ## CONTRIBUTORS Timothy Costa @@ -25,12 +27,10 @@ Julien Demouth Brian Fahs Michael Goldfarb Mostafa Hagog -Markus Hohnerbach Fei Hu Alan Kaatz Tina Li Timmy Liu -Piotr Majcher Duane Merrill Kevin Siu Markus Tavenrath diff --git a/CUDA.cmake b/CUDA.cmake index 6978a5187a..b8b343a723 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -39,23 +39,27 @@ if(CUTLASS_NATIVE_CUDA) enable_language(CUDA) + if(NOT CUDA_VERSION) + set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) + endif() + if(NOT CUDA_TOOLKIT_ROOT_DIR) + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) + endif() + else() find_package(CUDA REQUIRED) + # We workaround missing variables with the native flow by also finding the CUDA toolkit the old way. -endif() + if(NOT CMAKE_CUDA_COMPILER_VERSION) + set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION}) + endif() -if(NOT CUDA_VERSION) - set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) -endif() -if(NOT CUDA_TOOLKIT_ROOT_DIR) - get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) endif() if (CUDA_VERSION VERSION_LESS 9.2) message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.") endif() - if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang") set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") @@ -74,7 +78,7 @@ find_library( # in the CUDA toolkit we're building against. ) -if(CUDART_LIBRARY) +if(NOT TARGET cudart AND CUDART_LIBRARY) message(STATUS "CUDART: ${CUDART_LIBRARY}") @@ -95,6 +99,10 @@ if(CUDART_LIBRARY) ${CUDART_LIBRARY} ) +elseif(TARGET cudart) + + message(STATUS "CUDART: Already Found") + else() message(STATUS "CUDART: Not Found") @@ -116,7 +124,7 @@ find_library( # in the CUDA toolkit we're building against. ) -if(CUDA_DRIVER_LIBRARY) +if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY) message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}") @@ -137,6 +145,10 @@ if(CUDA_DRIVER_LIBRARY) ${CUDA_DRIVER_LIBRARY} ) +elseif(TARGET cuda_driver) + + message(STATUS "CUDA Driver: Already Found") + else() message(STATUS "CUDA Driver: Not Found") @@ -156,7 +168,7 @@ find_library( # in the CUDA toolkit we're building against. ) -if(NVRTC_LIBRARY) +if(NOT TARGET nvrtc AND NVRTC_LIBRARY) message(STATUS "NVRTC: ${NVRTC_LIBRARY}") @@ -177,6 +189,10 @@ if(NVRTC_LIBRARY) ${NVRTC_LIBRARY} ) +elseif(TARGET nvrtc) + + message(STATUS "NVRTC: Already Found") + else() message(STATUS "NVRTC: Not Found") @@ -190,55 +206,144 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) function(cutlass_correct_source_file_language_property) if(CUDA_COMPILER MATCHES "clang") foreach(File ${ARGN}) - if(${File} MATCHES ".*\.cu$") + if(File MATCHES ".*\.cu$") set_source_files_properties(${File} PROPERTIES LANGUAGE CXX) endif() endforeach() endif() endfunction() -function(cutlass_add_library) +set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation") +set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files") - set(options INTERFACE STATIC SHARED OBJECT) - set(oneValueArgs) +function(cutlass_unify_source_files TARGET_ARGS_VAR) + + set(options) + set(oneValueArgs BATCH_SOURCES BATCH_SIZE) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang" OR __INTERFACE) - cutlass_correct_source_file_language_property(${ARGN}) - add_library(${ARGN}) + if (NOT DEFINED TARGET_ARGS_VAR) + message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required") + endif() + + if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE) + set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE}) + endif() + + if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1) + + set(CUDA_FILE_ARGS) + set(TARGET_SOURCE_ARGS) + + foreach(ARG ${__UNPARSED_ARGUMENTS}) + if(${ARG} MATCHES ".*\.cu$") + list(APPEND CUDA_FILE_ARGS ${ARG}) + else() + list(APPEND TARGET_SOURCE_ARGS ${ARG}) + endif() + endforeach() + + list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS) + while(NUM_CUDA_FILE_ARGS GREATER 0) + list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH) + string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}") + string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH) + set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu) + message(STATUS "Generating ${BATCH_FILE}") + file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n") + foreach(CUDA_FILE ${CUDA_FILE_BATCH}) + get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE) + file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n") + endforeach() + list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE}) + if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE) + break() + endif() + list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS) + list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS) + endwhile() + + else() + + set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + + endif() + + set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE) + +endfunction() + +function(cutlass_add_library NAME) + + set(options) + set(oneValueArgs EXPORT_NAME) + set(multiValueArgs) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + + if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") + cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + add_library(${NAME} ${TARGET_SOURCE_ARGS}) else() set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${ARGN}) + cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS}) + endif() + + cutlass_apply_standard_compile_options(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) + + target_compile_features( + ${NAME} + INTERFACE + cxx_std_11 + ) + + if(__EXPORT_NAME) + add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME}) + set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME}) endif() endfunction() -function(cutlass_add_executable) +function(cutlass_add_executable NAME) set(options) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${ARGN}) - add_executable(${ARGN}) + cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + add_executable(${NAME} ${TARGET_SOURCE_ARGS}) else() set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_executable(${ARGN}) + cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS}) endif() + cutlass_apply_standard_compile_options(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) + + target_compile_features( + ${NAME} + INTERFACE + cxx_std_11 + ) + endfunction() -function(cutlass_target_sources) +function(cutlass_target_sources NAME) set(options) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cutlass_correct_source_file_language_property(${ARGN}) - target_sources(${ARGN}) + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + target_sources(${NAME} ${TARGET_SOURCE_ARGS}) endfunction() diff --git a/LICENSE.txt b/LICENSE.txt index 283345b553..64a49d680b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2017 - 2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md index 3b5f472892..b0a91e77c6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.0 +# CUTLASS 2.2 -_CUTLASS 2.0 - November 2019_ +_CUTLASS 2.2 - June 2020_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -17,14 +17,36 @@ and applications. To support a wide variety of applications, CUTLASS provides extensive support for mixed-precision computations, providing specialized data-movement and multiply-accumulate abstractions for half-precision floating -point (FP16), single-precision floating point (FP32), double-precision floating +point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32), +single-precision floating point (FP32), double-precision floating point (FP64) types, integer data types (4b and 8b), and binary data types (1b). -Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations for + +Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations targeting the programmable, high-throughput _Tensor Cores_ implemented by -NVIDIA's Volta and Turing architectures. +NVIDIA's Volta, Turing, and Ampere architectures. See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. +See the [functionality listing](media/docs/functionality.md) for the list of operations +supported at each level of the execution model hierarchy. + +# What's New in CUTLASS 2.2 + +CUTLASS 2.2 is a significant update to CUTLASS adding: + +- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) +- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types +- Deep software pipelines using asynchronous copy +- Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) +- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) + +# What's New in CUTLASS 2.1 + +CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding: + +- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores +- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) + # What's New in CUTLASS 2.0 CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer: @@ -33,10 +55,7 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to - Robust and durable templates that reliably span the design space - Encapsulated functionality that may be reusable in other contexts -See the [CHANGELOG](CHANGELOG.md) for more details. - -See the [functionality listing](media/docs/functionality.md) for the list of operations -supported at each level of the execution model hierarchy. +**See the [CHANGELOG](CHANGELOG.md) for more details.** # Performance @@ -45,15 +64,15 @@ supported at each level of the execution model hierarchy. CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, they exhibit performance comparable to cuBLAS for scalar GEMM computations. The above figure shows CUTLASS performance relative to cuBLAS -for large matrix dimensions on an NVIDIA GeForce 2080 Ti and an NVIDIA TitanV -using CUDA 10.2. Tensor Core operations are implemented using CUDA's +for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV +using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's [mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). # Compatibility CUTLASS requires a C++11 host compiler and -performs best when built with the [CUDA 10.2 Toolkit](https://developer.nvidia.com/cuda-toolkit). -It is compatible with CUDA 9.2, CUDA 10.0, and CUDA 10.1. +performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit). +It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2. We have tested the following environments. @@ -62,27 +81,28 @@ We have tested the following environments. | Windows 10 | Microsoft Visual Studio 2015| | | Microsoft Visual Studio 2017| | Ubuntu 16.04 | GCC 5.4.0 | -| Ubuntu 18.04 | GCC 7.3.0 | +| Ubuntu 18.04 | GCC 7.5.0 | Additionally, CUTLASS may be built with clang. See [these instructions](media/docs/quickstart.md#clang) for more details. CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on -any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU. - -|**GPU**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**| -|---|---|---| -|NVIDIA GeForce 1080|9.2| | -|NVIDIA TitanXP|9.2| | -|NVIDIA Tesla P100|9.2| | -|NVIDIA Tesla V100|9.2|10.1| -|NVIDIA TitanV|9.2|10.1| -|NVIDIA GeForce RTX 2080 TI, 2080, 2070|10.0|10.2| -|NVIDIA Tesla T4|10.0|10.2| +any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. + +|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**| +|---|---|---|---| +|NVIDIA Tesla P100|6.0|9.2| | +|NVIDIA GeForce 1080|6.1|9.2| | +|NVIDIA TitanXP|6.1|9.2| | +|NVIDIA Tesla V100|7.0|9.2|10.1| +|NVIDIA TitanV|7.0|9.2|10.1| +|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2| +|NVIDIA Tesla T4|7.5|10.0|10.2| +|NVIDIA A100|8.0|11.0|11.0| # Documentation -CUTLASS 2.0 is described in the following documents and the accompanying +CUTLASS 2.2 is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). - [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS @@ -116,7 +136,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc ``` Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels -for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify +for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify the architectures to build CUTLASS for by changing the CMake configuration setting `CUTLASS_NVCC_ARCHS`. @@ -177,7 +197,7 @@ include/ # client applications should target this directory ### CUTLASS SDK Examples -CUTLASS SDK examples apply CUTLASS templates to implement basic computations. +[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations. ``` examples/ @@ -198,12 +218,23 @@ examples/ 07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores 08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores + + 10_planar_complex/ # example demonstrating planar complex GEMM kernels + + 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes + + 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu + + 13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel ``` ### Tools ``` tools/ library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates + include/ + cutlass/ + library/ profiler/ # CUTLASS Profiler - command-line utility for executing operations in the # CUTLASS Library @@ -240,29 +271,32 @@ $ make cutlass_profiler -j Example command line for profiling SGEMM kernels is as follows: ``` -$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096 +$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096 ============================= Problem ID: 1 - Provider: CUTLASS - Operation: cutlass_simt_sgemm_128x128_nn + Provider: CUTLASS + OperationKind: gemm + Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1 + + Status: Success + Verification: ON + Disposition: Passed - Disposition: Passed - Status: Success + cuBLAS: Passed - Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \ - --split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \ - --stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \ - --max_cc=1024 + Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \ + --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \ + --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024 - Bytes: 52428800 bytes - FLOPs: 146064539648 flops + Bytes: 180355072 bytes + FLOPs: 115992428544 flops - Runtime: 10.5424 ms - Memory: 4.63158 GiB/s + Runtime: 6.73655 ms + Memory: 24.934 GiB/s - Math: 13854.9 GFLOP/s + Math: 17218.4 GFLOP/s ``` [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md) @@ -279,7 +313,7 @@ The official list of CUTLASS developers and contributors is available here: [CON # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/cmake/nop.cu b/cmake/nop.cu index 571c6c7c05..518a582b89 100644 --- a/cmake/nop.cu +++ b/cmake/nop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cuBLAS.cmake b/cuBLAS.cmake index 60a56ca5fc..4c73a1db4c 100644 --- a/cuBLAS.cmake +++ b/cuBLAS.cmake @@ -1,7 +1,8 @@ message(STATUS "Configuring cublas ...") -if(DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) +if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR + (DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED)) # Don't add cuBLAS if it's defined and false, assume it's not found. @@ -9,28 +10,35 @@ if(DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) message(STATUS "cuBLAS Disabled.") elseif(NOT TARGET cublas) - + find_path( - _CUBLAS_INCLUDE_DIR cublas.h - PATHS - ${CUDA_TOOLKIT_ROOT_DIR}/include - $ENV{CUBLAS_PATH}/include - $ENV{CUDA_PATH}/include - ${CUBLAS_PATH}/include - /usr/include) + _CUBLAS_INCLUDE_DIR + NAMES cublas.h + HINTS + ${CUBLAS_INCLUDE_PATH} + ENV CUBLAS_INCLUDE_PATH + ${CUBLAS_PATH} + ENV CUBLAS_PATH + ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES + include + ) find_library( - _CUBLAS_LIBRARY cublas + _CUBLAS_LIBRARY + NAMES cublas HINTS - ${CUDA_TOOLKIT_ROOT_DIR}/lib64 - ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 - $ENV{CUBLAS_PATH}/lib64 - $ENV{CUBLAS_PATH}/lib/x64 - $ENV{CUDA_PATH}/lib64 - $ENV{CUDA_PATH}/lib/x64 - ${CUBLAS_PATH}/lib64 - ${CUBLAS_PATH}/lib/x64 - /usr/lib/x86_64-linux-gnu) + ${CUBLAS_LIBRARY_PATH} + ENV CUBLAS_LIBRARY_PATH + ${_CUBLAS_INCLUDE_DIR}/.. + ${CUBLAS_PATH} + ENV CUBLAS_PATH + ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES + lib64 + lib/x64 + lib + ) if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY) @@ -59,11 +67,13 @@ endif() if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas) if(WIN32) - add_library(cublas STATIC IMPORTED) + add_library(cublas STATIC IMPORTED GLOBAL) else() - add_library(cublas SHARED IMPORTED) + add_library(cublas SHARED IMPORTED GLOBAL) endif() + add_library(nvidia::cublas ALIAS cublas) + set_property( TARGET cublas PROPERTY IMPORTED_LOCATION @@ -76,35 +86,37 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas) $) find_library( - _CUBLASLT_LIBRARY cublasLt + _CUBLASLT_LIBRARY + NAMES cublasLt HINTS - ${CUDA_TOOLKIT_ROOT_DIR}/lib64 - ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 - $ENV{CUBLAS_PATH}/lib64 - $ENV{CUBLAS_PATH}/lib/x64 - $ENV{CUDA_PATH}/lib64 - $ENV{CUDA_PATH}/lib/x64 - ${CUBLAS_PATH}/lib64 - ${CUBLAS_PATH}/lib/x64 - /usr/lib/x86_64-linux-gnu) - - if(_CUBLASLT_LIBRARY) + ${CUBLAS_LIBRARY_PATH} + ENV CUBLAS_LIBRARY_PATH + ${_CUBLAS_INCLUDE_DIR}/.. + ${CUBLAS_PATH} + ENV CUBLAS_PATH + ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES + lib64 + lib/x64 + lib + ) + + if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt) if(WIN32) - add_library(cublasLt STATIC IMPORTED) + add_library(cublasLt STATIC IMPORTED GLOBAL) else() - add_library(cublasLt SHARED IMPORTED) + add_library(cublasLt SHARED IMPORTED GLOBAL) endif() set_property( TARGET cublasLt PROPERTY IMPORTED_LOCATION ${_CUBLASLT_LIBRARY}) - - target_link_libraries( - cublas - INTERFACE - cublasLt) + + add_library(nvidia::cublasLt ALIAS cublasLt) + + target_link_libraries(cublas INTERFACE cublasLt) endif() diff --git a/examples/00_basic_gemm/CMakeLists.txt b/examples/00_basic_gemm/CMakeLists.txt index 5b833b85dc..9ae257d9ab 100644 --- a/examples/00_basic_gemm/CMakeLists.txt +++ b/examples/00_basic_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/00_basic_gemm/basic_gemm.cu b/examples/00_basic_gemm/basic_gemm.cu index 415646327f..bda012abee 100644 --- a/examples/00_basic_gemm/basic_gemm.cu +++ b/examples/00_basic_gemm/basic_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/01_cutlass_utilities/CMakeLists.txt b/examples/01_cutlass_utilities/CMakeLists.txt index 2dfa083c1f..5f22b7b1cf 100644 --- a/examples/01_cutlass_utilities/CMakeLists.txt +++ b/examples/01_cutlass_utilities/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/01_cutlass_utilities/cutlass_utilities.cu b/examples/01_cutlass_utilities/cutlass_utilities.cu index 0b6aa38671..d1eaa57fe7 100644 --- a/examples/01_cutlass_utilities/cutlass_utilities.cu +++ b/examples/01_cutlass_utilities/cutlass_utilities.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/02_dump_reg_shmem/CMakeLists.txt b/examples/02_dump_reg_shmem/CMakeLists.txt index 4e9af4fbb7..5e6112e026 100644 --- a/examples/02_dump_reg_shmem/CMakeLists.txt +++ b/examples/02_dump_reg_shmem/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/examples/02_dump_reg_shmem/dump_reg_shmem.cu index 39d58db87d..ed712aa84e 100644 --- a/examples/02_dump_reg_shmem/dump_reg_shmem.cu +++ b/examples/02_dump_reg_shmem/dump_reg_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without *modification, are permitted provided that the following conditions are met: diff --git a/examples/03_visualize_layout/CMakeLists.txt b/examples/03_visualize_layout/CMakeLists.txt index 81211df90e..e2bb283489 100644 --- a/examples/03_visualize_layout/CMakeLists.txt +++ b/examples/03_visualize_layout/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -20,15 +20,9 @@ # STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cutlass_add_executable( +cutlass_example_add_executable( 03_visualize_layout visualize_layout.cpp register_layout.cu ) -target_link_libraries( - 03_visualize_layout - PRIVATE - CUTLASS - cutlass_tools_util_includes - ) diff --git a/examples/03_visualize_layout/options.h b/examples/03_visualize_layout/options.h index c72b1228f6..dd7de198a4 100644 --- a/examples/03_visualize_layout/options.h +++ b/examples/03_visualize_layout/options.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/03_visualize_layout/register_layout.cu b/examples/03_visualize_layout/register_layout.cu index 655d1f37dc..0d2b25eb30 100644 --- a/examples/03_visualize_layout/register_layout.cu +++ b/examples/03_visualize_layout/register_layout.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -34,6 +34,8 @@ #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/tensor_op_multiplicand_sm70.h" #include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + #include "visualize_layout.h" #include "register_layout.h" @@ -59,18 +61,40 @@ void RegisterLayouts(std::map // Integer matrix multiply.int4 8832 TN kblock128 {"TensorOpMultiplicand<4,128>", new VisualizeLayout>}, + // Integer matrix multiply.int4 16864 TN kblock256 + {"TensorOpMultiplicand<4,256>", + new VisualizeLayout>}, // Integer matrix multiply 8816 Interleaved-32 {"TensorOpMultiplicand<8,32>", new VisualizeLayout>}, // Integer matrix multiply 8816 TN kblock64 {"TensorOpMultiplicand<8,64>", new VisualizeLayout>}, + {"TensorOpMultiplicand<8,128>", + new VisualizeLayout>}, // Matrix Multiply 1688 TN kblock32 {"TensorOpMultiplicand<16,32>", new VisualizeLayout>}, // Matrix multiply 1688 NT {"TensorOpMultiplicand<16,64>", new VisualizeLayout>}, + // Matrix multiply 1688.TF32 TN kblock16 + {"TensorOpMultiplicand<32,16>", + new VisualizeLayout>}, + // Matrix multiply 1688.TF32 TN kblock32 + {"TensorOpMultiplicand<32,32>", + new VisualizeLayout>}, + // Matrix multiply 1688 NT + {"TensorOpMultiplicandCongruous<32,32>", + new VisualizeLayout< + cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>}, + // Matrix multiply 884 NT + {"TensorOpMultiplicandCongruous<64,16>", + new VisualizeLayout< + cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>}, + // Matrix multiply 884 TN + {"TensorOpMultiplicand64bCrosswise", + new VisualizeLayout}, {"TensorOpMultiplicandCongruous<128,4>", new VisualizeLayout< cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>}, @@ -82,7 +106,7 @@ void RegisterLayouts(std::map cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>}, {"VoltaTensorOpMultiplicandCrosswise<16,32>", new VisualizeLayout< - cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>}, + cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>} }; for (auto layout : layout_pairs) { diff --git a/examples/03_visualize_layout/register_layout.h b/examples/03_visualize_layout/register_layout.h index fee911f798..1518e433c8 100644 --- a/examples/03_visualize_layout/register_layout.h +++ b/examples/03_visualize_layout/register_layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/03_visualize_layout/visualize_layout.cpp b/examples/03_visualize_layout/visualize_layout.cpp index 8908d2c1fd..a0f2718122 100644 --- a/examples/03_visualize_layout/visualize_layout.cpp +++ b/examples/03_visualize_layout/visualize_layout.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,14 +65,26 @@ void print_usage(std::ostream &out) { "--extent=64,64 --vectorize=32 --output-shape=256,4\n" << "$ 03_visualize_layout \"TensorOpMultiplicand<4,128>\" " "--extent=128,32 --vectorize=32 --output-shape=256,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicand<4,256>\" " + "--extent=256,16 --vectorize=32 --output-shape=256,4\n" << "$ 03_visualize_layout \"TensorOpMultiplicand<8,32>\" " "--extent=32,64 --vectorize=16 --output-shape=128,4\n" << "$ 03_visualize_layout \"TensorOpMultiplicand<8,64>\" " "--extent=64,32 --vectorize=16 --output-shape=128,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicand<8,128>\" " + "--extent=128,16 --vectorize=16 --output-shape=128,4\n" << "$ 03_visualize_layout \"TensorOpMultiplicand<16,32>\" " "--extent=32,32 --vectorize=8 --output-shape=64,4\n" << "$ 03_visualize_layout \"TensorOpMultiplicand<16,64>\" " "--extent=64,16 --vectorize=8 --output-shape=64,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicand<32,16>\" " + "--extent=16,32 --vectorize=4 --output-shape=32,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicand<32,32>\" " + "--extent=32,16 --vectorize=4 --output-shape=32,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<32,32>\" " + "--extent=32,16 --vectorize=4 --output-shape=32,4\n" + << "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<64, 16>\" " + "--extent=16,16 --vectorize=2 --output-shape=16,4\n" << "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCrosswise<16,32>\" " "--extent=32,64 --vectorize=4 --output-shape=64,4\n" << "$ 03_visualize_layout \"VotlaTensorOpMultiplicandCongruous<16>\" " diff --git a/examples/03_visualize_layout/visualize_layout.h b/examples/03_visualize_layout/visualize_layout.h index 031916c746..4093d27721 100644 --- a/examples/03_visualize_layout/visualize_layout.h +++ b/examples/03_visualize_layout/visualize_layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/04_tile_iterator/CMakeLists.txt b/examples/04_tile_iterator/CMakeLists.txt index cef156249d..cd32e2287a 100644 --- a/examples/04_tile_iterator/CMakeLists.txt +++ b/examples/04_tile_iterator/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu index e63157608c..5c56f33bd8 100644 --- a/examples/04_tile_iterator/tile_iterator.cu +++ b/examples/04_tile_iterator/tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/05_batched_gemm/CMakeLists.txt b/examples/05_batched_gemm/CMakeLists.txt index 6c9bf50468..6cd0ca8dba 100644 --- a/examples/05_batched_gemm/CMakeLists.txt +++ b/examples/05_batched_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/05_batched_gemm/batched_gemm.cu b/examples/05_batched_gemm/batched_gemm.cu index d1fecda6e6..a9d8a9c680 100644 --- a/examples/05_batched_gemm/batched_gemm.cu +++ b/examples/05_batched_gemm/batched_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/06_splitK_gemm/CMakeLists.txt b/examples/06_splitK_gemm/CMakeLists.txt index 750c6205bb..7b30ae1668 100644 --- a/examples/06_splitK_gemm/CMakeLists.txt +++ b/examples/06_splitK_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/06_splitK_gemm/splitk_gemm.cu b/examples/06_splitK_gemm/splitk_gemm.cu index f0ce98258c..6e01a10162 100644 --- a/examples/06_splitK_gemm/splitk_gemm.cu +++ b/examples/06_splitK_gemm/splitk_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,7 +39,7 @@ inner product (1/16th of output), they accumulate to single output matrix. Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose +really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU easily. @@ -144,7 +144,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? // This code section describes ? using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -172,15 +172,28 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel; -int main() { +int run() { + cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major >= 7)) { - std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70." + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major != 7) { + std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." << std::endl; + + // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. return 0; } + // + // Define problem size + // + const int length_m = 5120; const int length_n = 4096; const int length_k = 4096; @@ -295,11 +308,30 @@ int main() { tensor_ref_d.sync_host(); // Check if output from CUTLASS kernel and reference kernel are equal or not - std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), - tensor_ref_d.host_view()) - ? "Passed" - : "Failed") - << std::endl; + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); - CUTLASS_CHECK(status); + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); } + +int main() { + + // + // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + + // Returning zero, so this test passes when built with older CUDA Toolkits. Its action are no-op. + return 0; + } + else { + return run(); + } +} + diff --git a/examples/07_volta_tensorop_gemm/CMakeLists.txt b/examples/07_volta_tensorop_gemm/CMakeLists.txt index 56dfce9ece..82e8172271 100644 --- a/examples/07_volta_tensorop_gemm/CMakeLists.txt +++ b/examples/07_volta_tensorop_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index 424c90fc02..ac27fa177d 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU. Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose +really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU easily. @@ -156,7 +156,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? // This code section describes ? using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -188,13 +188,21 @@ using Gemm = cutlass::gemm::device::Gemm; -int main() { +int run() { + cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major >= 7)) { - std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70." + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major != 7) { + std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." << std::endl; + + // Return 0 so tests are considered passing if run on unsupported architectures or CUDA Toolkits. return 0; } @@ -209,7 +217,7 @@ int main() { cutlass::HostTensor tensor_a( problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b( - problem_size.nk()); // <- Create matrix B with dimensions N x K + problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c( problem_size.mn()); // <- Create matrix C with dimensions M x N cutlass::HostTensor tensor_d( @@ -312,12 +320,28 @@ int main() { tensor_ref_d.sync_host(); // Check if output from CUTLASS kernel and reference kernel are equal or not - std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), - tensor_ref_d.host_view()) - ? "Passed" - : "Failed") - << std::endl; + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); - CUTLASS_CHECK(status); - return 0; + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + + // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + + // Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op. + return 0; + } + else { + return run(); + } } + diff --git a/examples/08_turing_tensorop_gemm/CMakeLists.txt b/examples/08_turing_tensorop_gemm/CMakeLists.txt index 9e011a1ed2..b4e4fe82f6 100644 --- a/examples/08_turing_tensorop_gemm/CMakeLists.txt +++ b/examples/08_turing_tensorop_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index 1628ce0a1a..d18a4e6ab7 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose +really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU easily. @@ -150,12 +150,12 @@ using SmArch = cutlass::arch::Sm75; using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64 // This code section describes tile size a warp will compute -using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 16 +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64 // This code section describes the size of MMA op using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -186,13 +186,30 @@ using Gemm = cutlass::gemm::device::Gemm; -int main() { +int run() { + + // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + return -1; + } + cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major >= 7 && props.minor >= 5)) { - std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 75)) { + std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75." << std::endl; + + // Return 0 so tests are considered passing if run on unsupported platforms. return 0; } @@ -207,7 +224,7 @@ int main() { cutlass::HostTensor tensor_a( problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b( - problem_size.nk()); // <- Create matrix B with dimensions N x K + problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c( problem_size.mn()); // <- Create matrix C with dimensions M x N cutlass::HostTensor tensor_d( @@ -310,12 +327,28 @@ int main() { tensor_ref_d.sync_host(); // Check if output from CUTLASS kernel and reference kernel are equal or not - std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), - tensor_ref_d.host_view()) - ? "Passed" - : "Failed") - << std::endl; + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); - CUTLASS_CHECK(status); - return 0; + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + + // Returning zero so this test passes when built on older Toolkits. + return 0; + } + else { + return run(); + } } + diff --git a/examples/10_planar_complex/CMakeLists.txt b/examples/10_planar_complex/CMakeLists.txt new file mode 100644 index 0000000000..555836aebf --- /dev/null +++ b/examples/10_planar_complex/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +# Planar Complex GEMM example +cutlass_example_add_executable( + 10_planar_complex + planar_complex.cu +) + + +# +# This example depends on the CUTLASS Library +# + +target_link_libraries( + 10_planar_complex + PRIVATE + cutlass_lib + cutlass_tools_util_includes +) + diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu new file mode 100644 index 0000000000..b7318b99c2 --- /dev/null +++ b/examples/10_planar_complex/planar_complex.cu @@ -0,0 +1,557 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Planar Complex GEMM + + This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting + the batched strided mode. + + These kernels represent complex matrices by storing the real and imaginary parts of the matrix in + disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts + as either column-major or row-major layouts with a single leading dimension indicating the stride + between columns or rows. + + The CUTLASS Library collects multiple template instantiations in a data structure and offers + a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. + + CUTLASS decouples matrix layout from complex transformation, so four possible transformations + are possible on the A and B operands: + + n: column-major + c: column-major complex conjugate + t: row-major + h: row-major complex conjugate + + The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile + size, and alignment. This can result in long compile times. + + To build strictly the planar complex kernels needed for general application, execute the following + CMake command in an empty build directory. + + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex + + This builds all planar complex GEMM variants for Volta and Turing architectures. + + To build strictly the kernels needed for this example, an even narrower filter string may be + specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for + the 'CN' layout configuration (conjugate A operand with both A and B as column-major). + + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn + + $ make 10_planar_complex + + $ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10 +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" + +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "cutlass/util/reference/device/gemm_planar_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/library/handle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::complex alpha; + cutlass::complex beta; + + bool reference_check; + int iterations; + + Options(): + help(false), + problem_size({1024, 1024, 1024}), + batch_count(1), + reference_check(true), + iterations(20), + alpha(1), + beta() { } + + bool valid() { + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch", batch_count); + + cmd.get_cmd_line_argument("alpha", alpha.real()); + cmd.get_cmd_line_argument("alpha_i", alpha.imag()); + cmd.get_cmd_line_argument("beta", beta.real()); + cmd.get_cmd_line_argument("beta_i", beta.imag()); + + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "10_planar_complex example\n\n" + << " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --batch Number of GEMM operations executed in one batch\n" + << " --alpha Epilogue scalar alpha (real part)\n" + << " --alpha_i Epilogue scalar alpha (imaginary part)\n" + << " --beta Epilogue scalar beta (real part)\n\n" + << " --beta_i Epilogue scalar beta (imaginary part)\n\n" + << " --iterations Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n" + << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count * 4; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance test environment for planar complex +class TestbedPlanarComplex { +public: + + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementCompute = float; + using ElementAccumulator = float; + + // + // Data members + // + + cutlass::library::Handle handle; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::DeviceAllocation tensor_A; + cutlass::DeviceAllocation tensor_B; + cutlass::DeviceAllocation tensor_C; + cutlass::DeviceAllocation tensor_D; + cutlass::DeviceAllocation tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex( + Options const &options + ): + problem_size(options.problem_size), batch_count(options.batch_count) { + + // Allocate device memory for batched strided GEMM + tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); + tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); + tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + } + + void initialize() { + + uint64_t seed = 1073; + + // Use small integers to simplify correctness checking + int scope_max = 6; + int scope_min = -6; + + cutlass::reference::device::BlockFillRandomUniform( + tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0); + } + + Result profile(Options const &options) { + + Result result; + + initialize(); + + ElementA *ptr_A = tensor_A.get(); + ElementB *ptr_B = tensor_B.get(); + ElementC *ptr_C = tensor_C.get(); + ElementC *ptr_D = tensor_D.get(); + + int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; + int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; + int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; + int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; + + int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + + int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); + int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); + int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); + int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMMs + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + + // + // Execute the planar complex GEMM kernel via the CUTLASS Library's + // dispatch routines. + // + // Note, for planar complex GEMM kernels, all numeric type arguments + // specify the data type of the base real types. These are understood to + // apply to planar complex representations of matrices in memory and to complex + // structures for scalars. + // + // See tools/library/include/cutlass/library/handle.h for more details. + // + + result.status = handle.gemm_planar_complex( + problem_size.m(), // GEMM M dimension + problem_size.n(), // GEMM N dimension + problem_size.k(), // GEMM K dimension + + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars + + &options.alpha, // Pointer to alpha scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix + cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand + ptr_A, // Pointer to real part of A matrix + ptr_A + imag_stride_A, // Pointer to imaginary part of A matrix + lda, // Leading dimension of real part of A matrix + lda, // Leading dimension of imaginary part of A matrix + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix + cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand + ptr_B, // Pointer to real part of B matrix + ptr_B + imag_stride_B, // Pointer to imaginary part of B matrix + ldb, // Leading dimension of real part of B matrix + ldb, // Leading dimension of imaginary part of B matrix + + &options.beta, // Pointer to beta scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices + + ptr_C, // Pointer to real part of C matrix + ptr_C + imag_stride_C, // Pointer to imaginary part of C matrix + ldc, // Leading dimension of real part of C matrix + ldc, // Leading dimension of imaginary part of C matrix + + ptr_D, // Pointer to real part of D matrix + ptr_D + imag_stride_D, // Pointer to imaginary part of D matrix + ldd, // Leading dimension of real part of D matrix + ldd, // Leading dimension of imaginary part of D matrix + + batch_count, // Number of batched elements + + batch_stride_A, // Stride between batches of real parts of A matrix + batch_stride_A, // Stride between batches of imaginary parts of A matrix + + batch_stride_B, // Stride between batches of real parts of B matrix + batch_stride_B, // Stride between batches of imaginary parts of B matrix + + batch_stride_C, // Stride between batches of real parts of C matrix + batch_stride_C, // Stride between batches of imaginary parts of C matrix + + batch_stride_D, // Stride between batches of real parts of D matrix + batch_stride_D // Stride between batches of imaginary parts of D matrix + ); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; + return result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + if (handle.get_last_operation()) { + std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; + } + + // + // Compute reference in device code + // + + if (options.reference_check) { + + result.passed = true; + + for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { + cutlass::reference::device::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + options.alpha, + {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, + cutlass::ComplexTransform::kConjugate, + {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, + cutlass::ComplexTransform::kNone, + options.beta, + {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, + {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} + ); + + ElementC epsilon = 0.1_hf; + ElementC nonzero_floor = 0.1_hf; + + result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( + tensor_D.get() + idx * batch_stride_D, + tensor_D_ref.get() + idx * batch_stride_D, + batch_stride_D, + epsilon, + nonzero_floor + ); + } + + if (result.passed) { + std::cout << "Reference check passed." << std::endl; + } + else { + std::cerr << "Error - reference check failed." << std::endl; + } + } + + std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. + // + // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 7) { + std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70." + << std::endl; + + // Returning zero so this test passes on older architectures even though its actions are no-op. + return 0; + } + else if (props.major == 7 && props.minor <= 2) { + // + // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits even though its actions are no-op. + return 0; + } + } + else if (props.major == 7 && props.minor >= 5) { + // + // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits even though its actions are no-op. + return 0; + } + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + TestbedPlanarComplex testbed(options); + + Result result = testbed.profile(options); + + return result.passed ? 0 : -1; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/11_planar_complex_array/CMakeLists.txt b/examples/11_planar_complex_array/CMakeLists.txt new file mode 100644 index 0000000000..2a3f5987e4 --- /dev/null +++ b/examples/11_planar_complex_array/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +# Planar Complex Array GEMM example +cutlass_example_add_executable( + 11_planar_complex_array + planar_complex_array.cu +) + + +# +# This example depends on the CUTLASS Library +# + +target_link_libraries( + 11_planar_complex_array + PRIVATE + cutlass_lib + cutlass_tools_util_includes +) + diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu new file mode 100644 index 0000000000..6a0270533e --- /dev/null +++ b/examples/11_planar_complex_array/planar_complex_array.cu @@ -0,0 +1,617 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Planar Complex Array Example + + This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which + execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays + in global memory. + + These kernels represent complex matrices by storing the real and imaginary parts of the matrix in + disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts + as either column-major or row-major layouts with a single leading dimension indicating the stride + between columns or rows. + + The CUTLASS Library collects multiple template instantiations in a data structure and offers + a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. + + CUTLASS decouples matrix layout from complex transformation, so four possible transformations + are possible on the A and B operands: + + n: column-major + c: column-major complex conjugate + t: row-major + h: row-major complex conjugate + + To build strictly the planar complex kernels needed for general application, execute the following + CMake command in an empty build directory. + + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex + + This builds all planar complex GEMM variants for Volta and Turing architectures. + + To build strictly the kernels needed for this example, an even narrower filter string may be + specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for + the 'CN' layout configuration (conjugate A operand with both A and B as column-major). + + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn + + $ make 11_planar_complex_array + + $ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10 +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" + +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "cutlass/util/reference/device/gemm_planar_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/library/handle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::complex alpha; + cutlass::complex beta; + + bool reference_check; + int iterations; + + Options(): + help(false), + problem_size({1024, 1024, 1024}), + batch_count(1), + reference_check(true), + iterations(20), + alpha(1), + beta() { } + + bool valid() { + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch", batch_count); + + cmd.get_cmd_line_argument("alpha", alpha.real()); + cmd.get_cmd_line_argument("alpha_i", alpha.imag()); + cmd.get_cmd_line_argument("beta", beta.real()); + cmd.get_cmd_line_argument("beta_i", beta.imag()); + + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "11_planar_complex_array example\n\n" + << " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --batch Number of GEMM operations executed in one batch\n" + << " --alpha Epilogue scalar alpha (real part)\n" + << " --alpha_i Epilogue scalar alpha (imaginary part)\n" + << " --beta Epilogue scalar beta (real part)\n\n" + << " --beta_i Epilogue scalar beta (imaginary part)\n\n" + << " --iterations Number of profiling iterations to perform.\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count * 4; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance test environment for planar complex +class TestbedPlanarComplex { +public: + + // Half-precision input and output + using Element = cutlass::half_t; + + // Configurations for layouts and internal computation + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementCompute = float; + using ElementAccumulator = float; + + // + // Data members + // + + cutlass::library::Handle handle; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::DeviceAllocation tensor_A; + cutlass::DeviceAllocation tensor_B; + cutlass::DeviceAllocation tensor_C; + cutlass::DeviceAllocation tensor_D; + cutlass::DeviceAllocation tensor_D_ref; + + cutlass::DeviceAllocation ptr_A_real; + cutlass::DeviceAllocation ptr_A_imag; + cutlass::DeviceAllocation ptr_B_real; + cutlass::DeviceAllocation ptr_B_imag; + cutlass::DeviceAllocation ptr_C_real; + cutlass::DeviceAllocation ptr_C_imag; + cutlass::DeviceAllocation ptr_D_real; + cutlass::DeviceAllocation ptr_D_imag; + + // + // Methods + // + + TestbedPlanarComplex( + Options const &options + ): + problem_size(options.problem_size), batch_count(options.batch_count) { + + // Allocate device memory for batched planar complex GEMM + tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); + tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); + tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + + ptr_A_real.reset(batch_count); + ptr_A_imag.reset(batch_count); + ptr_B_real.reset(batch_count); + ptr_B_imag.reset(batch_count); + ptr_C_real.reset(batch_count); + ptr_C_imag.reset(batch_count); + ptr_D_real.reset(batch_count); + ptr_D_imag.reset(batch_count); + + } + + void initialize() { + + uint64_t seed = 1073; + + // Use small integers to simplify correctness checking + int scope_max = 6; + int scope_min = -6; + + cutlass::reference::device::BlockFillRandomUniform( + tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0); + } + + Result profile(Options const &options) { + + Result result; + + initialize(); + + Element *ptr_A = tensor_A.get(); + Element *ptr_B = tensor_B.get(); + Element *ptr_C = tensor_C.get(); + Element *ptr_D = tensor_D.get(); + + int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; + int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; + int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; + int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; + + int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + + int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); + int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); + int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); + int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); + + // + // Configure pointers in global memory + // + + struct { + Element *base; + void **ptr_real; + void **ptr_imag; + int64_t batch_stride; + int64_t imag_stride; + } tensors[] = { + { tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A}, + { tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B}, + { tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C}, + { tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D} + }; + + for (auto const &tensor : tensors) { + for (int idx = 0; idx < batch_count; ++idx) { + + void *ptr_real = tensor.base + idx * tensor.batch_stride; + void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride; + + cudaError_t error = cudaMemcpy( + tensor.ptr_real + idx, + &ptr_real, + sizeof(void *), + cudaMemcpyHostToDevice); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to copy pointer to device memory"); + } + + error = cudaMemcpy( + tensor.ptr_imag + idx, + &ptr_imag, + sizeof(void *), + cudaMemcpyHostToDevice); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to copy pointer to device memory"); + } + } + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + + // + // Execute the planar complex array GEMM kernel via the CUTLASS Library's + // dispatch routines. + // + // Note, for planar complex array GEMM kernels, all numeric type arguments + // specify the data type of the base real types. These are understood to + // apply to planar complex representations of matrices in memory and to complex + // structures for scalars. + // + // See tools/library/include/cutlass/library/handle.h for more details. + // + + result.status = handle.gemm_planar_complex_array( + + problem_size.m(), // expected GEMM M dimension + problem_size.n(), // expected GEMM N dimension + problem_size.k(), // expected GEMM K dimension + batch_count, // Number of batched elements + + nullptr, + nullptr, + nullptr, + + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars + + &options.alpha, // Pointer to alpha scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix + cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand + + ptr_A_real.get(), // Pointer to array of pointers to real part of A matrix + ptr_A_imag.get(), // Pointer to array of pointers to imaginary part of A matrix + + lda, // Leading dimension of real part of A matrix + lda, // Leading dimension of imaginary part of A matrix + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix + cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand + + ptr_B_real.get(), // Pointer to array of pointers to real part of B matrix + ptr_B_imag.get(), // Pointer to array of pointers to imaginary part of B matrix + + ldb, // Leading dimension of real part of B matrix + ldb, // Leading dimension of imaginary part of B matrix + + &options.beta, // Pointer to beta scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices + + ptr_C_real.get(), // Pointer to array of pointers to real part of C matrix + ptr_C_imag.get(), // Pointer to array of pointers to imaginary part of C matrix + + ldc, // Leading dimension of real part of C matrix + ldc, // Leading dimension of imaginary part of C matrix + + ptr_D_real.get(), // Pointer to array of pointers to real part of D matrix + ptr_D_imag.get(), // Pointer to array of pointers to imaginary part of D matrix + + ldd, // Leading dimension of real part of D matrix + ldd // Leading dimension of imaginary part of D matrix + ); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; + return result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + if (handle.get_last_operation()) { + std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; + } + + // + // Compute reference in device code + // + + if (options.reference_check) { + + result.passed = true; + + for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { + cutlass::reference::device::GemmPlanarComplex< + Element, LayoutA, + Element, LayoutB, + Element, LayoutC, + ElementAccumulator + >( + problem_size, + options.alpha, + {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, + cutlass::ComplexTransform::kConjugate, + {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, + cutlass::ComplexTransform::kNone, + options.beta, + {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, + {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} + ); + + Element epsilon = 0.1_hf; + Element nonzero_floor = 0.1_hf; + + result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( + tensor_D.get() + idx * batch_stride_D, + tensor_D_ref.get() + idx * batch_stride_D, + batch_stride_D, + epsilon, + nonzero_floor + ); + } + + if (result.passed) { + std::cout << "Reference check passed." << std::endl; + } + else { + std::cerr << "Error - reference check failed." << std::endl; + } + } + + std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. + // + // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 7) { + std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70." + << std::endl; + + // Returning zero so this passes on older architectures. Its actions are no-op. + return 0; + } + else if (props.major == 7 && props.minor <= 2) { + // + // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + + // Returning zero so this passes on older Toolkits. Its actions are no-op. + return 0; + } + } + else if (props.major == 7 && props.minor >= 5) { + // + // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + + // Returning zero so this passes on older Toolkits. Its actions are no-op. + return 0; + } + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + TestbedPlanarComplex testbed(options); + + Result result = testbed.profile(options); + + return result.passed ? 0 : -1; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/12_gemm_bias_relu/CMakeLists.txt b/examples/12_gemm_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..fb78d77fa2 --- /dev/null +++ b/examples/12_gemm_bias_relu/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 12_gemm_bias_relu + gemm_bias_relu.cu + ) + diff --git a/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/examples/12_gemm_bias_relu/gemm_bias_relu.cu new file mode 100644 index 0000000000..7faaa98aa7 --- /dev/null +++ b/examples/12_gemm_bias_relu/gemm_bias_relu.cu @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Column Major for +// Matrix A, Row Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::ColumnMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm75; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + +// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to +// +// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij ) +// +using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This becomes + // the vector width of math instructions in + // epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 2; + +using Gemm = cutlass::gemm::device::Gemm; + +int run() { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major * 10 + props.minor >= 75)) { + std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." + << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + const int length_m = 5120; + const int length_n = 4096; + const int length_k = 4096; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_b( + problem_size.nk()); // <- Create matrix B with dimensions N x K + + cutlass::HostTensor tensor_c_bias( + {problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1 + + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(4), + ElementInputB(-4), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c_bias.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c_bias.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + + {tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM + // to project away the N dimension by setting the stride to zero. + + tensor_d.device_ref(), // <- reference to matrix D on device + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // + // Create instantiation for device reference gemm kernel + // + + cutlass::reference::device::Gemm + gemm_device_reference; + + // Launch device reference to compute strictly the product A * B + gemm_device_reference( + problem_size, + alpha, + tensor_a.device_ref(), + tensor_b.device_ref(), + 0, + tensor_c_bias.device_ref(), + tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Compute bias + relu in host code + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + tensor_ref_d.at({i, j}) = std::max( + ElementOutput(0), + ElementOutput(tensor_ref_d.at({i, j}) + beta * tensor_c_bias.at({i, 0})) + ); + } + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), + tensor_ref_d.host_view()) + ? "Passed" + : "Failed") + << std::endl; + + CUTLASS_CHECK(status); + return 0; +} + +int main() { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + else { + return run(); + } +} + diff --git a/examples/13_fused_two_gemms/CMakeLists.txt b/examples/13_fused_two_gemms/CMakeLists.txt new file mode 100644 index 0000000000..ba51537ca2 --- /dev/null +++ b/examples/13_fused_two_gemms/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 13_fused_two_gemms + fused_gemm.cu + ) + +target_include_directories( + 13_fused_two_gemms + PRIVATE + . + ) + diff --git a/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h b/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h new file mode 100644 index 0000000000..10a0d4bf94 --- /dev/null +++ b/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "b2b_gemm_run.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +void run_nonfused_gemm_f16() { + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(1); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Gemm0 = cutlass::gemm::device::Gemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape0, + WarpShape0, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + using Gemm1 = cutlass::gemm::device::Gemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape1, + WarpShape1, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + + B2bNonFusedGemmRun nonFusedGemm; + + std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; + bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; +} + +void run_fused_gemm_f16() { + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(1); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + InstructionShape::kM * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >; + + + + using B2bGemm = cutlass::gemm::device::B2bGemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + + B2bFusedGemmRun fusedGemm; + + std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n"; + bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + +} +//////////////////////////////////////////////////////////////////////////////// + +#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) diff --git a/examples/13_fused_two_gemms/b2b_gemm_run.h b/examples/13_fused_two_gemms/b2b_gemm_run.h new file mode 100644 index 0000000000..053064d751 --- /dev/null +++ b/examples/13_fused_two_gemms/b2b_gemm_run.h @@ -0,0 +1,608 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + B2bNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + // + // Run the GEMM + // + + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < 100; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < 100; i++) { + + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n"; + std::cout << "total time " << totalTime / 100.0 << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + tensor_C0.device_ref(), + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + tensor_C1.device_ref(), + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + + return passed; + } +}; + +template +struct B2bFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + B2bFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + +// cutlass::HostTensor< +// typename B2bGemm::ElementC, +// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < 100; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "time " << gemmTime / 100.0 << " ms\n"; + + //tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, + ElementAccumulator, typename B2bGemm::Operator> + reference_gemm_0, reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + tensor_C0.device_ref(), + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + tensor_C1.device_ref(), + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() +// << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h new file mode 100644 index 0000000000..1c3f15c2cf --- /dev/null +++ b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "b2b_interleaved_gemm_run.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +void run_nonfused_gemm_s8() { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(1); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; + + using Gemm0 = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape0, + WarpShape0, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + using Gemm1 = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape1, + WarpShape1, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + + B2bInterleavedNonFusedGemmRun nonFusedGemm; + + std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; + bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; +} + +void run_fused_gemm_s8() { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(1); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + InstructionShape::kM * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >; + + + + using B2bGemm = cutlass::gemm::device::B2bGemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + 2 + >; + + B2bInterleavedFusedGemmRun fusedGemm; + + std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n"; + bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) diff --git a/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h new file mode 100644 index 0000000000..906cabb409 --- /dev/null +++ b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h @@ -0,0 +1,633 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/reference/device/gemm.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template +struct B2bInterleavedNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + + //Reorder B0 and B1 + cutlass::reorder_column( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < 100; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + + for(int i = 0; i < 100; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n"; + std::cout << "total time " << totalTime / 100.0 << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + tensor_C0.device_ref(), + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + tensor_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + tensor_C1.device_ref(), + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + + return passed; + } +}; + +template +struct B2bInterleavedFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + +// cutlass::HostTensor< +// typename B2bGemm::ElementC, +// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + + //Reorder B0 + cutlass::reorder_column( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + //tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + 1, /*threadblock_swizzle_k_tile*/ + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < 100; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "time " << gemmTime / 100.0 << " ms\n"; + + //tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, + ElementAccumulator, typename B2bGemm::Operator> + reference_gemm_0, reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + tensor_C0.device_ref(), + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + tensor_C1.device_ref(), + reference_D1.device_ref() + ); + + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() +// << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_fused_two_gemms/device/b2b_gemm.h b/examples/13_fused_two_gemms/device/b2b_gemm.h new file mode 100644 index 0000000000..3f161435dd --- /dev/null +++ b/examples/13_fused_two_gemms/device/b2b_gemm.h @@ -0,0 +1,439 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "kernel/b2b_gemm.h" +#include "kernel/default_b2b_gemm.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Epilogue output operator + typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Whether Beta is zero or not + bool IsBetaZero = false> +class B2bGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape0 = ThreadblockShape0_; + using ThreadblockShape1 = ThreadblockShape1_; + using WarpShape0 = WarpShape0_; + using WarpShape1 = WarpShape1_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool const kIsBetaZero = IsBetaZero; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using B2bGemmKernel = typename kernel::DefaultB2bGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kIsBetaZero + >::B2bGemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size_0; + GemmCoord problem_size_1; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + int split_k_slices_ = 1 + ): + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + split_k_slices(split_k_slices_) { + + } + }; + +private: + + /// Kernel parameters object + typename B2bGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + B2bGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = B2bGemmKernel::can_implement( + args.problem_size_0, + args.problem_size_1, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.split_k_slices); +// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( +// args.problem_size_1, +// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, +// args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename B2bGemmKernel::Params{ + args.problem_size_0, + args.problem_size_1, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.epilogue0, + args.epilogue1, + static_cast(workspace), + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C.non_const_ref().data()); + params_.ref_B1.reset(args.ref_B.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(B2bGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_fused_two_gemms/fused_gemm.cu b/examples/13_fused_two_gemms/fused_gemm.cu new file mode 100644 index 0000000000..a7856abe5a --- /dev/null +++ b/examples/13_fused_two_gemms/fused_gemm.cu @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + +This example shows fusing two GEMM mainloops into one kernel. The first GEMM computes relu(alpha*A*B) and +the second GEMM computes relu(alpha*A*B+beta*C). The performance measuring environment compares against +two unfused GEMM operations, demonstrating a speedup of the fused kernel on the +NVIDIA Turing GPU architecture. + +Problem size: + + GEMM1 (M,N,K): 128*1600, 64, 576 + GEMM2 (M,N,K): 128*1600, 128, 64 + +Note that GEMM1_N = GEMM2_K + +The example requires the number of threadblocks be the same across 2 GEMMs and +thread_block_tile_N = problem_N so the data required by each layer is threadblock-resident. It +also requires warp_tile_N = thread_block_tile_N so the data required by each warp is +register-file-resident. + +Performance: + + - fp16 on Tesla T4 @ 1590MHz (non-fused vs. fused): 1.39011 ms vs. 1.26035 ms + - int8 on Tesla T4 @ 1590MHz (non-fused vs. fused): 0.751759 ms vs. 0.62971 ms + - fp16 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.721144 ms vs. 0.629864 ms + - int8 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.379049 ms vs. 0.324764 ms + +*/ + +#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h" +#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h" + +int run() { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major * 10 + props.minor >= 75)) { + std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." + << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + run_nonfused_gemm_f16(); + run_fused_gemm_f16(); + run_nonfused_gemm_s8(); + run_fused_gemm_s8(); +#endif + + return 0; +} + +int main() { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + else { + return run(); + } +} + diff --git a/examples/13_fused_two_gemms/kernel/b2b_gemm.h b/examples/13_fused_two_gemms/kernel/b2b_gemm.h new file mode 100644 index 0000000000..d106fa46af --- /dev/null +++ b/examples/13_fused_two_gemms/kernel/b2b_gemm.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct B2bGemm { + + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using OutputOp0 = typename B2bMma::OutputOp; + using OutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size_0; + cutlass::gemm::GemmCoord problem_size_1; + cutlass::gemm::GemmCoord grid_tiled_shape; + typename B2bMma::IteratorA0::Params params_A0; + typename B2bMma::IteratorA0::TensorRef ref_A0; + typename B2bMma::IteratorB0::Params params_B0; + typename B2bMma::IteratorB0::TensorRef ref_B0; + typename Epilogue::OutputTileIterator::Params params_C0; + typename Epilogue::OutputTileIterator::TensorRef ref_C0; + typename B2bMma::IteratorB1::Params params_B1; + typename B2bMma::IteratorB1::TensorRef ref_B1; + typename Epilogue::OutputTileIterator::Params params_C1; + typename Epilogue::OutputTileIterator::TensorRef ref_C1; + typename Epilogue::OutputTileIterator::Params params_D1; + typename Epilogue::OutputTileIterator::TensorRef ref_D1; + typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + int *semaphore; + int gemm_k_iterations_0; + int gemm_k_size_0; + int gemm_k_iterations_1; + int gemm_k_size_1; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), + gemm_k_iterations_1(0), gemm_k_size_1(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + int *workspace = nullptr + ): + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + grid_tiled_shape(grid_tiled_shape), + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + output_op_0(output_op_0), + output_op_1(output_op_1) { + + int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK; + int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1) { + + static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) || + (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) || + (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) || + (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) || + (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) || + (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * B2bMma::Shape0::kM, + threadblock_tile_offset.k() * params.gemm_k_size_0, + }; + + cutlass::MatrixCoord tb_offset_B0{ + threadblock_tile_offset.k() * params.gemm_k_size_0, + threadblock_tile_offset.n() * B2bMma::Shape0::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + threadblock_tile_offset.k() * params.gemm_k_size_1, + threadblock_tile_offset.n() * B2bMma::Shape1::kN + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k_0 = min( + params.problem_size_0.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k_1 = min( + params.problem_size_1.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + // Compute threadblock-scoped matrix multiply-add +// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.params_A0, + params.ref_A0.data(), + {params.problem_size_0.m(), problem_size_k_0}, + thread_idx, + tb_offset_A0); + + typename B2bMma::IteratorB0 iterator_B0( + params.params_B0, + params.ref_B0.data(), + {problem_size_k_0, params.problem_size_0.n()}, + thread_idx, + tb_offset_B0); + + typename B2bMma::IteratorB1 iterator_B1( + params.params_B1, + params.ref_B1.data(), + {problem_size_k_1, params.problem_size_1.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + OutputOp0 output_op_0(params.output_op_0); + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations_0 > 0) { + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0); + } + + // + // Epilogue + // + + OutputOp1 output_op_1(params.output_op_1); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * B2bMma::Shape1::kM, + threadblock_tile_offset.n() * B2bMma::Shape1::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + params.ref_C1.data(), + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D1( + params.params_D1, + params.ref_D1.data(), + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h new file mode 100644 index 0000000000..45b2d545ef --- /dev/null +++ b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h @@ -0,0 +1,296 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "kernel/b2b_gemm.h" +#include "threadblock/default_b2b_mma.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Beta is zero or not + bool IsBetaZero = false +> +struct DefaultB2bGemm; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator +> +struct DefaultB2bGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + SplitKSerial, + Operator +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + 2, + Operator, + EpilogueOutputOp0 + >::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + typename B2bMma::Operator1, + kPartitionsK1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + + +/// Partial specialization for Turing IMMA Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Is Beta zero or not + bool IsBetaZero> +struct DefaultB2bGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, + WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue for the 2nd Gemm + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK, + IsBetaZero>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h new file mode 100644 index 0000000000..01cca8b7a2 --- /dev/null +++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + using Shape1 = Shape1_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm0 = typename Policy0::Operator::Shape; + using WarpGemm1 = typename Policy1::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount0 = GemmShape; + using WarpCount1 = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations0 = + (WarpGemm0::kK / Operator0::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = + (WarpGemm1::kK / Operator1::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template< + typename Shape_, + typename Policy_ + > + class SharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Policy = Policy_; + using Operator = typename Policy::Operator; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + using SharedStorage0 = SharedStorage; + using SharedStorage1 = SharedStorage; + union B2bMmaSharedStorage { + SharedStorage0 sharedStorage0; + SharedStorage1 sharedStorage1; + }; + + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A0 operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A0_; + + /// Iterator to load a warp-scoped tile of B0 operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + + /// Iterator to load a warp-scoped tile of B0 operand from shared memory + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h new file mode 100644 index 0000000000..ca89cf0bdc --- /dev/null +++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h @@ -0,0 +1,509 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// +template +struct chk_val { + static_assert(a==0, "check value"); +}; + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bMmaPipelined : public B2bMmaBase { +public: + + ///< Base class + using Base = B2bMmaBase; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using Policy0 = Policy0_; ///< Policy describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy describing tuning details + + using SmemIteratorB1 = SmemIteratorB1_; + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelined( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) { + + + // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + //These should stay the same across different GEMM layers + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + //These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumualtor tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(tb_frag_A); + this->smem_iterator_B0_.store(tb_frag_B0); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + if (gemm_k_iterations_0 <= 1) { + iterator_A.clear_mask(); + iterator_B0.clear_mask(); + } + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + iterator_A.load(tb_frag_A); + + // + // Mainloop + // + + // Note: The main loop does not support Base::WarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(tb_frag_A); + + this->smem_iterator_B0_.store(tb_frag_B0); + + __syncthreads(); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + iterator_A.load(tb_frag_A); + + ++this->smem_iterator_B0_; + ++this->smem_iterator_A_; + + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + if (gemm_k_iterations_0 <= 2) { + iterator_A.clear_mask(); + iterator_B0.clear_mask(); + } + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + //2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + // + // Prologue + // + + FragmentB1 tb_frag_B1; + + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + this->smem_iterator_B1_.store(tb_frag_B1); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + //warp_tile_iterator_A1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Avoid reading out of bounds + if (gemm_k_iterations_1 <= 1) { + iterator_B1.clear_mask(); + } + + // + // Mainloop + // + + // Note: The main loop does not support Base::WarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + + this->smem_iterator_B1_.store(tb_frag_B1); + + __syncthreads(); + ++smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + ++iterator_B1; + + + // Avoid reading out of bounds if this was the last loop iteration + if (gemm_k_iterations_1 <= 2) { + iterator_B1.clear_mask(); + } + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h new file mode 100644 index 0000000000..cd1403c792 --- /dev/null +++ b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_pipelined.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultB2bMma; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + OperatorClass, 2, Operator>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp, true>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e434cd7fe3..99379fe45a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -22,16 +22,14 @@ set(CUTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common) -function(cutlass_example_add_executable) +function(cutlass_example_add_executable NAME) set(options) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cutlass_add_executable(${__UNPARSED_ARGUMENTS}) - - list(GET __UNPARSED_ARGUMENTS 0 NAME) + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) target_link_libraries( ${NAME} @@ -46,9 +44,18 @@ function(cutlass_example_add_executable) ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR} ) + add_custom_target( + test_${NAME} + COMMAND + ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ + DEPENDS + ${NAME} + ) + endfunction() add_custom_target(cutlass_examples) +add_custom_target(test_examples) foreach(EXAMPLE 00_basic_gemm @@ -59,9 +66,15 @@ foreach(EXAMPLE 05_batched_gemm 06_splitK_gemm 07_volta_tensorop_gemm - 08_turing_tensorop_gemm) + 08_turing_tensorop_gemm + 10_planar_complex + 11_planar_complex_array + 12_gemm_bias_relu + 13_fused_two_gemms +) add_subdirectory(${EXAMPLE}) add_dependencies(cutlass_examples ${EXAMPLE}) + add_dependencies(test_examples test_${EXAMPLE}) endforeach() diff --git a/include/cutlass/aligned_buffer.h b/include/cutlass/aligned_buffer.h index 3232ef87d3..8b3bb0713d 100644 --- a/include/cutlass/aligned_buffer.h +++ b/include/cutlass/aligned_buffer.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index b38a347a45..faf01cc656 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -52,6 +52,10 @@ struct Sm72 { struct Sm75 { static int const kMinComputeCapability = 75; }; +struct Sm80 { + static int const kMinComputeCapability = 80; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace arch diff --git a/include/cutlass/arch/cache_operation.h b/include/cutlass/arch/cache_operation.h new file mode 100644 index 0000000000..646b51ded3 --- /dev/null +++ b/include/cutlass/arch/cache_operation.h @@ -0,0 +1,60 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Directives related to cache operations +*/ +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Controls PTX cache operations +struct CacheOperation { + enum Kind { + /// Cache at all levels - accessed again + Always, + /// Cache at global level + Global, + /// Streaming - likely to be accessed once + Streaming, + /// Indicates the line will not be used again + LastUse, + /// Don't cache, and fetch again + Volatile, + /// Write back at all coherent levels + WriteBack, + /// Write through to system memory + WriteThrough + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index fc939053d4..48ef02cd0e 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,13 +28,271 @@ #pragma once +#include "cutlass/cutlass.h" + namespace cutlass { namespace arch { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Fragment type to store loaded data + typename AccessType, + /// The bytes of loading + int LoadBytes + > +struct global_load; ///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %9, 0;\n" + " mov.b32 %0, %10;\n" + " mov.b32 %1, %11;\n" + " mov.b32 %2, %12;\n" + " mov.b32 %3, %13;\n" + " mov.b32 %4, %14;\n" + " mov.b32 %5, %15;\n" + " mov.b32 %6, %16;\n" + " mov.b32 %7, %17;\n" + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), + "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), + "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); + } +}; + + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint2 &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" + " @p ld.global.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + unsigned &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" + " @p ld.global.u32 %0, [%1];\n" + "}\n" + : "=r"(data) + : "l"(ptr), "r"((int)pred_guard), "r"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint16_t &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" + " @p ld.global.u16 %0, [%1];\n" + "}\n" + : "=h"(data) + : "l"(ptr), "r"((int)pred_guard), "h"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + if (pred_guard) D = *(reinterpret_cast(ptr)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Fragment type to store loaded data + typename AccessType, + /// The bytes of loading + int LoadBytes + > +struct global_store; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint4 const *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" + "}\n" + : + : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), + "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), + "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint4 const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + "}\n" + : + : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint2 const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " @p st.global.v2.u32 [%0], {%1, %2};\n" + "}\n" + : + : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint32_t const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @p st.global.u32 [%0], %1;\n" + "}\n" + : + : "l"(ptr), "r"(data), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint16_t const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @p st.global.u16 [%0], %1;\n" + "}\n" + : + : "l"(ptr), "h"(data), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + if (pred_guard) *(reinterpret_cast(ptr)) = D; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace arch } // namespace cutlass @@ -42,4 +300,6 @@ namespace arch { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "memory_sm75.h" +#include "memory_sm80.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index c821ddaf91..3fd121b903 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -46,61 +46,99 @@ inline __device__ void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Specializations +// Determine the appropriate way to target PTX's "ldmatrix" instruction. // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if (__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ == 2) - #define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 1 -#else - #define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 0 -#endif +#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11) -#if ! defined(CUDA_NVVM_GET_SHARED_POINTER_ENABLED) - #define CUDA_NVVM_GET_SHARED_POINTER_ENABLED (CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) +#define CUDA_LDMATRIX_ACTIVATED 1 #endif -#if ! defined(CUDA_LDMATRIX_SUPPORTED) - #define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2)) +#define CUDA_LDMATRIX_SUPPORTED 1 #endif -#if ! defined(CUDA_LDMATRIX_ENABLED) - #define CUDA_LDMATRIX_ENABLED (CUDA_LDMATRIX_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/* +#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10) + #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif +#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) + #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1)) #endif -#if (CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) - #define CUDA_LDMATRIX_ACTIVATED 1 -#else - #define CUDA_LDMATRIX_ACTIVATED 0 +#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED) + #define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED #endif +*/ -#if defined(CUTLASS_GET_SMEM_POINTER) - // Use the existing implementation -#elif CUDA_NVVM_GET_SHARED_POINTER_ENABLED - #if ! defined(NVVM_GET_SMEM_POINTER) - #define NVVM_GET_SMEM_POINTER +#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) extern "C" { - // - // This NVVM intrinsic is subject to change in future versions of CUDA. - // Clients should not call it directly. Rather, they should use the - // cutlass::arch::ldsm<>() template. - // - __device__ uint32_t __nvvm_get_smem_pointer(void*); + // + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. Rather, they should use the + // cutlass::arch::ldsm<>() template. + // + __device__ uint32_t __nvvm_get_smem_pointer(void *); } - #endif - #define CUTLASS_GET_SMEM_POINTER(ptr) __nvvm_get_smem_pointer((void*)ptr) #endif ///////////////////////////////////////////////////////////////////////////////////////////////// +/// CUTLASS helper to get SMEM pointer +inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { + +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. + // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTLASS helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + return 0; +#endif +} + +/// CUTLASS helper to get SMEM pointer +inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { + return cutlass_get_smem_pointer(const_cast(ptr)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + template <> inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x; asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); @@ -120,9 +158,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y; asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); @@ -142,9 +180,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y, z, w; asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); @@ -167,9 +205,10 @@ template <> inline __device__ void ldsm( Array & D, void const* ptr) { + #if CUDA_LDMATRIX_ACTIVATED - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x; asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); @@ -189,9 +228,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y; asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); @@ -211,9 +250,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y, z, w; asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); @@ -227,5 +266,6 @@ inline __device__ void ldsm( } ///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h new file mode 100644 index 0000000000..04c568760e --- /dev/null +++ b/include/cutlass/arch/memory_sm80.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Architecture-specific operators on memory added for SM80 +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/cache_operation.h" + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + #define CUDA_CP_ASYNC_ACTIVATED 1 +#else + #define CUDA_CP_ASYNC_ACTIVATED 0 +#endif + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initiates an asynchronous copy from global memory to shared memory. +/// +/// LDGSTS +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always> +struct cp_async; + +/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate +/// the entire transfer, zeros are written to SMEM if the guard predicate is false. +/// +/// LDGSTS +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always> +struct cp_async_zfill; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async { + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + #endif + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async_zfill { + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + int src_in_bytes = (pred_guard ? SizeInBytes : 0); + + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async { + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + static_assert(SizeInBytes == 16, + "cp.async only supports CacheOperation::Global when access size is 16B."); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + #endif + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async_zfill { + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + static_assert(SizeInBytes == 16, + "cp.async only supports CacheOperation::Global when access size is 16B."); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + int src_in_bytes = (pred_guard ? SizeInBytes : 0); + + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTLASS_DEVICE +void cp_async_fence() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.commit_group;\n" ::); + #endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +CUTLASS_DEVICE void cp_async_wait() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + #endif +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +template <> +CUTLASS_DEVICE void cp_async_wait<0>() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.wait_all;\n" ::); + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 6898f51232..d6ea99886e 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,7 +30,9 @@ #include "cutlass/array.h" #include "cutlass/numeric_types.h" + #include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -49,6 +51,26 @@ struct OpMultiplyAddSaturate; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag indicating the input is converted to a narrower type (BF16) +struct OpMultiplyAddFastBF16; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input is converted to a narrower type (F16) +struct OpMultiplyAddFastF16; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the complex multiply-add operation +struct OpMultiplyAddComplex; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the gaussian complex multiply-add operation +struct OpMultiplyAddGaussianComplex; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Tag indicating the inner product is defined by (XOR, POPC) struct OpXorPopc; @@ -142,4 +164,5 @@ struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, El #include "cutlass/arch/mma_sm61.h" #include "cutlass/arch/mma_sm70.h" #include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h index 8698a8b3c6..fce521dcee 100644 --- a/include/cutlass/arch/mma_sm50.h +++ b/include/cutlass/arch/mma_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/mma_sm60.h b/include/cutlass/arch/mma_sm60.h index 6e513cedc5..ab0481ae44 100644 --- a/include/cutlass/arch/mma_sm60.h +++ b/include/cutlass/arch/mma_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/mma_sm61.h b/include/cutlass/arch/mma_sm61.h index 68a1b145f4..9ec8857e8c 100644 --- a/include/cutlass/arch/mma_sm61.h +++ b/include/cutlass/arch/mma_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index 90721f0dee..b03ce2c1de 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -27,7 +27,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "mma.h" #include "cutlass/layout/matrix.h" @@ -84,6 +88,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -139,6 +144,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -194,6 +200,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -249,6 +256,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -310,6 +318,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -385,6 +394,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -460,6 +470,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -535,6 +546,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index ee9599b089..ef65f20b97 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/arch/wmma.h" @@ -93,6 +97,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; CUTLASS_HOST_DEVICE void operator()( @@ -154,6 +159,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -215,6 +221,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -271,6 +278,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -327,6 +335,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -384,6 +393,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -446,6 +456,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -502,6 +513,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -558,6 +570,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -614,6 +627,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -676,6 +690,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -732,6 +747,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -788,6 +804,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -844,6 +861,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -906,6 +924,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -962,6 +981,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1018,6 +1038,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1074,6 +1095,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1136,6 +1158,7 @@ struct Mma< using FragmentC = Array; using Operator = OpXorPopc; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h new file mode 100644 index 0000000000..d75aa1336c --- /dev/null +++ b/include/cutlass/arch/mma_sm80.h @@ -0,0 +1,2091 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + +#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_MMA_SM80_ENABLED +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1684 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 4>, + 32, + tfloat32_t, + layout::RowMajor, + tfloat32_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 4>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 884 - F64 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<8,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + double const & A = reinterpret_cast(a); + double const & B = reinterpret_cast(b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=d"(D[0]), "=d"(D[1]) + : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + + +#else + assert(0); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16864 - S4 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpXorPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpXorPopc; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/simd.h b/include/cutlass/arch/simd.h index 75b38001f1..4520acc9b2 100644 --- a/include/cutlass/arch/simd.h +++ b/include/cutlass/arch/simd.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/simd_sm60.h b/include/cutlass/arch/simd_sm60.h index cd0babd54f..36030a3661 100644 --- a/include/cutlass/arch/simd_sm60.h +++ b/include/cutlass/arch/simd_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/simd_sm61.h b/include/cutlass/arch/simd_sm61.h index e8d5c8897c..94f1c617c3 100644 --- a/include/cutlass/arch/simd_sm61.h +++ b/include/cutlass/arch/simd_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index b2f8d1eb75..88968abdc5 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -68,24 +68,6 @@ namespace cutlass { namespace arch { -///////////////////////////////////////////////////////////////////////////////////////////////// -/// MemoryKind class (Shared vs. Global memory) -///////////////////////////////////////////////////////////////////////////////////////////////// -enum class MemoryKind { - kShared, // Data resides in shared memory - kGlobal // Data resides in global memory -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WarpParams holds architecture-specific constants -///////////////////////////////////////////////////////////////////////////////////////////////// -struct WarpParams { - static int const kThreadsPerWarp = 32; - static int const kQuadsPerWarp = 8; - static int const kThreadsPerQuad = 4; -}; - //////////////////////////////////////////////////////////////////////////////////////////////// /// Statically maps cutlass data types => nvcuda::wmma data types ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -196,7 +178,6 @@ template < struct Wmma; ///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/wmma_ptx.h b/include/cutlass/arch/wmma_ptx.h deleted file mode 100644 index 6361428669..0000000000 --- a/include/cutlass/arch/wmma_ptx.h +++ /dev/null @@ -1,105 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing warp matrix multiply-add (WMMA) operations -*/ -#pragma once - -#include "cutlass/arch/wmma.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// -/// WMMA structures to enclose * PTX * instruction string -/// -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA PTX string load for A, B, and C matrices -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadA; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadB; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadC; -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA Matrix multiply-add operation -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename ElementA_, ///< Data type of A elements - typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) - typename ElementB_, ///< Data type of B elements - typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) - typename ElementC_, ///< Element type of C matrix - typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) - typename Operator = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) -> -struct PtxWmma; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA store for matrix D -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaStoreD; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 63363ed710..94eeb93deb 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -68,6 +72,7 @@ struct Wmma< using ElementC = ElementC_; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm70; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index c5c15e9d6b..1b8cc1161e 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -65,6 +69,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; // check supported wmma shape for the given multiplicand data types static_assert( @@ -145,6 +150,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index a1bccbfcf8..f630712fc6 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -65,6 +69,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm75; // check supported wmma shape for the given multiplicand data types static_assert( @@ -115,8 +120,7 @@ struct Wmma< //////////////////////////////////////////////////////////////////////////////// // // WMMA template structure defines nvcuda::wmma::fragments and static assert for -// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1) -// (nvcuda::wmma targetting SASS instruction BMMA) +// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). // //////////////////////////////////////////////////////////////////////////////// template < @@ -143,6 +147,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpXorPopc; + using ArchTag = arch::Sm75; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/array.h b/include/cutlass/array.h index be14a879e8..0018b76f5a 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -167,7 +167,7 @@ class Array { class const_iterator { /// Pointer to object - T *ptr_; + const T *ptr_; public: diff --git a/include/cutlass/array_planar_complex.h b/include/cutlass/array_planar_complex.h new file mode 100644 index 0000000000..e2dbbc47cb --- /dev/null +++ b/include/cutlass/array_planar_complex.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Array holding planar complex elements +template +struct ArrayPlanarComplex { + + /// Underlying real element + using Element = Element_; + + /// Number of logical elements + static size_t const kElements = N; + + /// Underlying Fragment of real-valued elemenets + using ArrayReal = Array; + +public: + + /// Fragment of real-valued elements representing the real part + ArrayReal real; + + /// Fragment of real-valued elements representing the imaginary part + ArrayReal imag; + +public: + + /// Ctor + CUTLASS_HOST_DEVICE + ArrayPlanarComplex() { } + + /// Ctor + CUTLASS_HOST_DEVICE + ArrayPlanarComplex( + ArrayReal const &real_, + ArrayReal const &imag_ + ): + real(real_), imag(imag_) { } + + /// Sets the array to zero efficiently + CUTLASS_HOST_DEVICE + void clear() { + real.clear(); + imag.clear(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to deduce template arguments +template +CUTLASS_HOST_DEVICE +ArrayPlanarComplex +make_ArrayPlanarComplex(Array const &real, Array const &imag) { + return ArrayPlanarComplex(real, imag); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index b340c890fb..78081facc7 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h new file mode 100644 index 0000000000..c3bd1782bb --- /dev/null +++ b/include/cutlass/bfloat16.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing non-standard 16-bit floating point values with + 8 bits of exponent and 7 bit of mantissa. +*/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#include +#include +#endif + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. +struct alignas(2) bfloat16_t { + + // + // Data members + // + + /// Storage type + uint16_t storage; + + // + // Methods + // + + /// Constructs from an unsigned short + CUTLASS_HOST_DEVICE + static bfloat16_t bitcast(uint16_t x) { + bfloat16_t h; + h.storage = x; + return h; + } + + /// Default constructor + CUTLASS_HOST_DEVICE + bfloat16_t() { } + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(float x) { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + + asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); + + #else + uint32_t bits = reinterpret_cast(x); + + if ((bits & 0x7f800000) != 0x7f800000) { + + bool mantissa_bit = ((bits & (1 << 16)) != 0); + bool round_bit = ((bits & (1 << 15)) != 0); + bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { + bits += uint32_t(1 << 16); + } + } + else if (bits & ~0xff800000) { + bits = 0x7fffffff; + } + + storage = uint16_t((bits >> 16) & 0xffff); + #endif + } + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(double x): bfloat16_t(float(x)) { + + } + + /// Integer conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(int x) { + float flt = static_cast(x); + storage = uint16_t(reinterpret_cast(flt) >> 16); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + unsigned bits = (unsigned(storage) << 16); + return reinterpret_cast(bits); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + operator bool() const { + return (float(*this) != 0.0f); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint16_t raw() const { + return storage; + } + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x8000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 7) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7f); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::bfloat16_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { + return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t nan_bf16(const char*) { + // NVIDIA canonical NaN + return cutlass::bfloat16_t::bitcast(0x7fff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isnormal(cutlass::bfloat16_t const& h) { + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::bfloat16_t const& h) { + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { +#if defined(__CUDACC_RTC__) + return cutlass::bfloat16_t(sqrtf(float(h))); +#else + return cutlass::bfloat16_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { + + uint16_t a_mag = (reinterpret_cast(a) & 0x7fff); + uint16_t b_sign = (reinterpret_cast(b) & 0x8000); + uint16_t result = (a_mag | b_sign); + + return reinterpret_cast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace std { + +#if !defined(__CUDACC_RTC__) +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; +#endif + +} // namespace std + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) != float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) + float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs) { + return bfloat16_t(-float(lhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator++(bfloat16_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = bfloat16_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator--(bfloat16_t & lhs) { + float tmp(lhs); + --tmp; + lhs = bfloat16_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator++(bfloat16_t & lhs, int) { + bfloat16_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = bfloat16_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator--(bfloat16_t & lhs, int) { + bfloat16_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = bfloat16_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(long double x) { + return cutlass::bfloat16_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { + return cutlass::bfloat16_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index b479d31069..6f7d73bb91 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -25,12 +25,19 @@ #pragma once #include +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/half.h" #include "cutlass/real.h" +#include "cutlass/bfloat16.h" +#include "cutlass/tfloat32.h" + #if !defined(__CUDACC_RTC__) #include #endif @@ -351,11 +358,30 @@ CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) static_cast(imag(z)) * static_cast(imag(z)); } +/// Returns the complex conjugate +CUTLASS_HOST_DEVICE float conj(float const &z) { + return z; +} + +/// Returns the complex conjugate +CUTLASS_HOST_DEVICE double conj(double const &z) { + return z; +} + /// Returns the complex conjugate template CUTLASS_HOST_DEVICE complex conj(complex const &z) { return complex(real(z), -imag(z)); } +/// Indentity transform for non-complex types +template +CUTLASS_HOST_DEVICE T conj(T const &z) { + static_assert( !std::is_same::value && + !std::is_same::value && + !std::is_same>::value && + !std::is_same>::value, "May not be a complex data type"); + return z; +} /// Projects the complex number z onto the Riemann sphere template @@ -414,6 +440,11 @@ CUTLASS_HOST_DEVICE complex sin(complex const &z) { template struct RealType< complex > { using Type = T; + +CUTLASS_HOST_DEVICE + static complex from_real(double x) { + return complex(static_cast(x)); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -438,5 +469,18 @@ cutlass::complex from_real >(double r) { ////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct is_complex { + static bool const value = false; +}; + +template +struct is_complex> { + static bool const value = true; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass +////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 7f40ede39f..82613c2450 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,6 +28,12 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + #include "cutlass/cutlass.h" namespace cutlass { @@ -354,6 +360,29 @@ struct Coord { namespace cutlass { + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Index s, Coord coord) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Coord coord, Index s) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + /// Scalar division template CUTLASS_HOST_DEVICE @@ -413,3 +442,4 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) { //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass + diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index d9dc789055..a87ecfa707 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,9 +33,14 @@ #include "cutlass/coord.h" #include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/gemm/gemm.h" namespace cutlass { +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -47,8 +52,6 @@ std::ostream& operator<<(std::ostream& out, Coord const& coord) { return out; } -/////////////////////////////////////////////////////////////////////////////////////////////////// - inline std::istream & operator>>(std::istream &stream, half_t &x) { float tmp; @@ -62,6 +65,16 @@ std::ostream & operator<<(std::ostream &out, half_t const &x) { return out << float(x); } +inline +std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { + return out << float(x); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to enable formatted printing of CUTLASS scalar types to an ostream @@ -98,7 +111,54 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scal return out << unsigned(scalar.value); } + +/// Default printing to ostream for MatrixShape +template +inline +std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape const &matrix_shape) { + out << "cutlass::MatrixShape::(kRow, kColumn) {" + << cutlass::MatrixShape::kRow <<"," + << cutlass::MatrixShape::kColumn <<"}"; + return out; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::gemm namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace gemm { -} // namespace cutlass +/// Default printing to ostream for GemmShape +template +inline +std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape const &gemm_shape) { + out << "cutlass::GemmShape::(kM, kN, kK) {" + << cutlass::gemm::GemmShape::kM <<"," + << cutlass::gemm::GemmShape::kN <<"," + << cutlass::gemm::GemmShape::kK << "}"; + return out; +} + +} //namespace gemm +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::layout namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace layout { +/// Default printing to ostream for PitchLinearShape +template < int Contiguous, int Strided> +inline +std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape const &pitch_linear_shape) { + out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {" + << cutlass::layout::PitchLinearShape::kContiguous <<"," + << cutlass::layout::PitchLinearShape::kStrided <<"}"; + return out; +} + +} //namespace layout +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index d50b2511e9..860dc3e566 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -35,25 +35,41 @@ namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#endif + /// Status code returned by CUTLASS operations enum class Status { kSuccess, ///< Operation was successful. kErrorMisalignedOperand, ///< operands fail alignment requirements. + kErrorInvalidDataType, ///< DataType fails requirement. kErrorInvalidLayout, ///< Layout fails alignment requirement. kErrorInvalidProblem, ///< Specified problem size is not supported by operator. kErrorNotSupported, ///< Operation is not supported on current device. kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. kErrorInternal, ///< An error within CUTLASS occurred. + kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. + kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. kInvalid ///< Status is unspecified. }; /// Convert cutlass status to status strings -static inline char const* cutlassGetStatusString(cutlass::Status status) { +CUTLASS_HOST_DEVICE +static char const* cutlassGetStatusString(cutlass::Status status) { switch (status) { case cutlass::Status::kSuccess: return "Success"; case cutlass::Status::kErrorMisalignedOperand: return "Error Misaligned Operand"; + case cutlass::Status::kErrorInvalidDataType: + return "Error Invalid Data Type"; case cutlass::Status::kErrorInvalidLayout: return "Error Invalid Layout"; case cutlass::Status::kErrorInvalidProblem: @@ -64,6 +80,10 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) { return "Error Workspace Null"; case cutlass::Status::kErrorInternal: return "Error Internal"; + case cutlass::Status::kErrorInsufficientDriver: + return "Error Insufficient Driver"; + case cutlass::Status::kErrorArchMismatch: + return "Erroor Architecture Mismatch"; case cutlass::Status::kInvalid: break; } @@ -79,16 +99,6 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#elif defined(__CUDACC_RTC__) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#else -#define CUTLASS_HOST_DEVICE inline -#endif - #define CUTLASS_ASSERT(x) assert(x) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -115,6 +125,12 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Debug { + typename T::X x; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// static const int NUM_THREADS_PER_WARP = 32; static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; @@ -131,6 +147,14 @@ int LaneId() { return ret; } +/// Computes SM number the thread is running on +CUTLASS_DEVICE +int SmId() { + int ret; + asm ("mov.u32 %0, %%smid;" : "=r"(ret)); + return ret; +} + #endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 4a992bb3c3..f5166ab16a 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h new file mode 100644 index 0000000000..c0f42146e6 --- /dev/null +++ b/include/cutlass/epilogue/thread/activation.h @@ -0,0 +1,119 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This extends the contents of cutlass/functional.h with frequently used activation functions. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/complex.h" + +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ReLu operator - propagates NaNs +template +struct ReLu { + CUTLASS_HOST_DEVICE + T operator()(T const & threshold, T const &value) const { + if (value < threshold) { + value = threshold; + } + return value; + } +}; + +template +struct ReLu> { + CUTLASS_HOST_DEVICE + Array operator()(T const & threshold, Array const &frag) const { + Array result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + T value = frag[i]; + if (value < threshold) { + value = threshold; + } + result[i] = value; + } + return result; + } +}; + +// Sigmoid operator +template +struct Sigmoid { + CUTLASS_HOST_DEVICE + T operator()(T const &scalar) const { + return T(1) / (T(1) + exp(-scalar)); + } +}; + +template <> +struct Sigmoid { + CUTLASS_HOST_DEVICE + float operator()(float const &scalar) const { + return 1.0f / (1.0f + expf(-scalar)); + } +}; + +template +struct Sigmoid > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + Array y; + Sigmoid sigmoid_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(rhs.size()); ++i) { + y[i] = sigmoid_op(rhs[i]); + } + + return y; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/epilogue/thread/conversion_op.h b/include/cutlass/epilogue/thread/conversion_op.h index 32b885bc68..ad17d41490 100644 --- a/include/cutlass/epilogue/thread/conversion_op.h +++ b/include/cutlass/epilogue/thread/conversion_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -101,7 +101,7 @@ class Convert { CUTLASS_HOST_DEVICE FragmentOutput operator()( FragmentAccumulator const &accumulator, - FragmentOutput const &source, + FragmentOutput const &source = FragmentOutput(), ElementCompute uniform = ElementCompute(0)) const { // Convert to destination numeric type diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index dd8236b3c4..8b5f6ead1c 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -165,6 +165,28 @@ class LinearCombination { return destination_converter(intermediate); } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + multiplies mul_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index 75843b38e9..25611bd36c 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -167,10 +167,11 @@ class LinearCombinationClamp { intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X /// Clamping constant value - ElementCompute const kClamp = ElementCompute(1 << (sizeof_bits::value - 1)); - - intermediate = max_accumulator(intermediate, -kClamp); - intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); + ElementCompute const kClamp = + ElementCompute((1U << (sizeof_bits::value - 1)) - 1); + + intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1)); + intermediate = min_accumulator(intermediate, kClamp); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -178,12 +179,45 @@ class LinearCombinationClamp { return destination_converter(intermediate); } + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + + ComputeFragment intermediate; + + multiplies mul_accumulator; + + minimum min_accumulator; + maximum max_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + /// Clamping constant value + ElementCompute const kClamp = + ElementCompute((1U << (sizeof_bits::value - 1)) - 1); + + intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1)); + intermediate = min_accumulator(intermediate, kClamp); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// // Conditional guards to enable partial specialization for packed integers -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Applies a linear combination operator to an array of elements then clamps the output before /// converting to the output element type. @@ -278,7 +312,7 @@ class LinearCombinationClamp { beta_ = ElementCompute(1); } } - + /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( @@ -316,11 +350,212 @@ class LinearCombinationClamp { return destination_converter(scaled_accumulator); } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Compute linear scaling in floating point + ComputeFragment intermediate; + + multiplies mul_add_accumulator; + + // Float min-max + intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + scaled_accumulator[i] = static_cast(intermediate[i]); + } + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(scaled_accumulator); + } }; #endif // Conditional guards to enable partial specialization for packed integers -///////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements then clamps +/// the output before converting to the output element type. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Note: The below method only works for small k dimensions. The default +/// approach is above +/// TODO: Add logic to fallback to the default approach +template < + /// Data type used to load and store< tensors + typename ElementOutput_, + /// Number of elements computed per operation + int Count, + /// Rounding mode + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class FastLinearCombinationClamp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = int; + using ElementCompute = float; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + /// scales accumulators + ElementCompute alpha; + /// scales source tensor + ElementCompute beta; + /// pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *alpha_ptr; + /// pointer to source scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, ElementCompute beta) + : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) + : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + FastLinearCombinationClamp(Params const ¶ms) { + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return beta_ != ElementCompute(0); } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentOutput const &source, + ElementCompute uniform = ElementCompute(0)) const { + // Convert source to interal compute numeric type + FastNumericArrayConverter + source_converter; + FastNumericArrayConverter + accumulator_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Compute linear scaling in floating point + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + minimum min_accumulator; + maximum max_accumulator; + + // Float min-max + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, + intermediate); // D = alpha * Accum + X + + /// Clamping constant value + ElementCompute const kClamp = + ElementCompute(1 << (sizeof_bits::value - 1)); + + intermediate = max_accumulator(intermediate, -kClamp); + intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); + + // Convert to destination numeric type + FastNumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + FastNumericArrayConverter + accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Compute linear scaling in floating point + ComputeFragment intermediate; + + multiplies mul_accumulator; + + minimum min_accumulator; + maximum max_accumulator; + + // Float min-max + intermediate = mul_accumulator(alpha_, converted_accumulator); + + /// Clamping constant value + ElementCompute const kClamp = + ElementCompute(1 << (sizeof_bits::value - 1)); + + intermediate = max_accumulator(intermediate, -kClamp); + intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); + + // Convert to destination numeric type + FastNumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } +}; + +//////////////////////////////////////////////////////////////////////////////// } // namespace thread } // namespace epilogue diff --git a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h new file mode 100644 index 0000000000..3934af1041 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations on planar-complex arrays +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" +#include "cutlass/array_planar_complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to arrays of planar-complex elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Note, as with most CUTLASS components for planar complex, the template arguments describe +/// the underlying real data type. +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationPlanarComplex { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = ArrayPlanarComplex; + using FragmentAccumulator = ArrayPlanarComplex; + using ComputeFragment = ArrayPlanarComplex; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + complex alpha; ///< scales accumulators + complex beta; ///< scales source tensor + complex const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + complex const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + complex alpha, + complex beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + complex const *alpha_ptr, + complex const *beta_ptr + ): alpha(complex()), beta(complex()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + complex alpha_; + complex beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationPlanarComplex(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_source( + source_converter(source.real), + source_converter(source.imag)); + + ComputeFragment converted_accumulator( + accumulator_converter(accumulator.real), + accumulator_converter(accumulator.imag)); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies > mul_op; + multiply_add > mul_add_op; + + // complex multiply: I = beta * C + intermediate.real = mul_op(beta_.real(), converted_source.real); + intermediate.imag = mul_op(beta_.real(), converted_source.imag); + + intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); + intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); + + // complex multiply-add: I = alpha * AB + I + intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real); + intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag); + + intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); + intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return FragmentOutput( + destination_converter(intermediate.real), + destination_converter(intermediate.imag)); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator( + accumulator_converter(accumulator.real), + accumulator_converter(accumulator.imag)); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies > mul_op; + multiply_add > mul_add_op; + + // complex multiply-add: I = alpha * AB + I + intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real); + intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag); + + intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); + intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return FragmentOutput( + destination_converter(intermediate.real), + destination_converter(intermediate.imag)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index f0514d4e3b..7a2fa9e8af 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -23,8 +23,7 @@ * **************************************************************************************************/ /*! \file - \brief Functor performing linear combination operations used by epilogues. Values are clamped before - converting to the output element type. + \brief Functor performing linear combination with a maximum operation used by epilogues. */ #pragma once @@ -34,6 +33,7 @@ #include "cutlass/array.h" #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -43,8 +43,7 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Applies a linear combination operator to an array of elements then clamps the output before -/// converting to the output element type. +/// Applies a linear combination operator to an array of elements. /// /// D = alpha * accumulator + beta * source + uniform /// @@ -75,10 +74,10 @@ class LinearCombinationRelu { ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor - ElementCompute threshold; ///< Relu threshold + ElementCompute threshold; ///< minimum value that is output ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - + ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory // // Methods // @@ -87,16 +86,17 @@ class LinearCombinationRelu { Params(): alpha(ElementCompute(1)), beta(ElementCompute(0)), - threshold(ElementCompute(0)), + threshold(ElementCompute(0)), alpha_ptr(nullptr), - beta_ptr(nullptr) { } + beta_ptr(nullptr), + threshold_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha, ElementCompute beta, - ElementCompute threshold = ElementCompute(0) - ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { + ElementCompute threshold = ElementCompute(0) + ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) { } @@ -104,8 +104,8 @@ class LinearCombinationRelu { Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, - ElementCompute threshold = ElementCompute(0) - ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + ElementCompute const *threshold_ptr = nullptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) { } }; @@ -128,7 +128,7 @@ class LinearCombinationRelu { alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); - threshold_ = params.threshold; + threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold); } /// Returns true if source is needed @@ -144,13 +144,12 @@ class LinearCombinationRelu { beta_ = ElementCompute(1); } } - + /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( FragmentAccumulator const &accumulator, - FragmentOutput const &source, - ElementCompute uniform = ElementCompute(0)) const { + FragmentOutput const &source) const { // Convert source to interal compute numeric type NumericArrayConverter source_converter; @@ -160,18 +159,44 @@ class LinearCombinationRelu { ComputeFragment converted_accumulator = accumulator_converter(accumulator); // Perform binary operations - ComputeFragment intermediate; multiplies mul_add_source; multiply_add mul_add_accumulator; - - maximum max_accumulator; + ReLu relu; intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X - - intermediate = max_accumulator(intermediate, threshold_); + + // Compute threshold optionally + intermediate = relu(threshold_, intermediate); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_accumulator; + ReLu relu; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // Compute threshold optionally + intermediate = relu(threshold_, intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -183,17 +208,21 @@ class LinearCombinationRelu { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Applies a linear combination operator to an array of elements then clamps the output before -/// converting to the output element type. +// Conditional guards to enable partial specialization for packed integers +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) + +/// Applies a linear combination operator to an array of elements. /// /// D = alpha * accumulator + beta * source + uniform /// +/// Special handling for int types + template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation FloatRoundStyle Round > -class LinearCombinationRelu { +class LinearCombinationRelu { public: using ElementOutput = ElementOutput_; @@ -213,10 +242,10 @@ class LinearCombinationRelu { ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor - ElementCompute threshold; ///< Relu threshold + ElementCompute threshold; ///< minimum value that is output ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - + ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory // // Methods // @@ -225,16 +254,17 @@ class LinearCombinationRelu { Params(): alpha(ElementCompute(1)), beta(ElementCompute(0)), - threshold(ElementCompute(0)), + threshold(ElementCompute(0)), alpha_ptr(nullptr), - beta_ptr(nullptr) { } + beta_ptr(nullptr), + threshold_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha, ElementCompute beta, - ElementCompute threshold = ElementCompute(0) - ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { + ElementCompute threshold = ElementCompute(0) + ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) { } @@ -242,8 +272,8 @@ class LinearCombinationRelu { Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, - ElementCompute threshold = ElementCompute(0) - ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + ElementCompute const *threshold_ptr = nullptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) { } }; @@ -266,7 +296,7 @@ class LinearCombinationRelu { alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); - threshold_ = params.threshold; + threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold); } /// Returns true if source is needed @@ -282,13 +312,12 @@ class LinearCombinationRelu { beta_ = ElementCompute(1); } } - + /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( FragmentAccumulator const &accumulator, - FragmentOutput const &source, - ElementCompute uniform = ElementCompute(0)) const { + FragmentOutput const &source) const { // Convert source to interal compute numeric type NumericArrayConverter source_converter; @@ -298,21 +327,16 @@ class LinearCombinationRelu { ComputeFragment converted_accumulator = accumulator_converter(accumulator); // Perform binary operations - ComputeFragment intermediate; multiplies mul_add_source; multiply_add mul_add_accumulator; - - maximum max_accumulator; + ReLu relu; intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X - - // Clamp to theshold - intermediate = max_accumulator(intermediate, threshold_); - // Convert back to accumulator data type + // Convert floats back to INT FragmentAccumulator scaled_accumulator; CUTLASS_PRAGMA_UNROLL @@ -320,15 +344,58 @@ class LinearCombinationRelu { scaled_accumulator[i] = static_cast(intermediate[i]); } - // Convert to destination numeric type and pack - NumericArrayConverter destination_converter; + // Compute threshold optionally + scaled_accumulator = relu(threshold_, scaled_accumulator); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(scaled_accumulator); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_accumulator; + ReLu relu; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // Convert floats back to INT + FragmentAccumulator scaled_accumulator; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + scaled_accumulator[i] = static_cast(intermediate[i]); + } + + // Compute threshold optionally + scaled_accumulator = relu(threshold_, scaled_accumulator); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; return destination_converter(scaled_accumulator); } }; +#endif // Conditional guards to enable partial specialization for packed integers + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread } // namespace epilogue } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h new file mode 100644 index 0000000000..3a65c49acf --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationSigmoid { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationSigmoid(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + Sigmoid sigmoid; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + intermediate = sigmoid(intermediate); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + + ComputeFragment intermediate; + + multiplies mul_add_accumulator; + Sigmoid sigmoid; + + intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + intermediate = sigmoid(intermediate); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/thread/reduction_op.h b/include/cutlass/epilogue/thread/reduction_op.h index b33332e931..0331f0fad5 100644 --- a/include/cutlass/epilogue/thread/reduction_op.h +++ b/include/cutlass/epilogue/thread/reduction_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h index d11c623d96..67fccf05c2 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -45,6 +45,7 @@ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h" #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" @@ -58,9 +59,102 @@ namespace cutlass { namespace epilogue { namespace threadblock { -//////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization and defines sensible defaults for epilogues for complex*complex case +// 4 real-valued mma operations (Complex) +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Epilouge Shape + typename Shape_, + /// Warp-level mma operator + typename WarpMmaTensorOp_, + /// Number of k partitions + int PartitionsK, + /// Epilogue output operator + typename OutputOp_, + /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() + int ElementsPerAccess, + /// Multiply-add operator + /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_ = arch::OpMultiplyAddComplex> +struct DefaultEpilogueComplexTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + using Operator = Operator_; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; -/// Defines sensible defaults for epilogues for TensorOps. + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + ElementAccumulator, + LayoutC + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + typename OutputTileThreadMap::CompactedThreadMap, + ElementAccumulator + >; + + /// Hard-coded padding elements added + using Padding = cutlass::MatrixShape<0, 0>; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization and defines sensible defaults for epilogues for complex*complex case +// 3 real-valued mma operations (Gaussian Complex) +// A = (ar + j ai), B = (br +j bi), D = AB +// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) +// D = dr + j di = (P1 - P3) + j (P1 + P2) +///////////////////////////////////////////////////////////////////////////////////////////////// template < typename Shape_, typename WarpMmaTensorOp_, @@ -68,13 +162,16 @@ template < typename OutputOp_, int ElementsPerAccess > -struct DefaultEpilogueComplexTensorOp { +struct DefaultEpilogueComplexTensorOp { using Shape = Shape_; using WarpMmaTensorOp = WarpMmaTensorOp_; static int const kPartitionsK = PartitionsK; using OutputOp = OutputOp_; static int const kElementsPerAccess = ElementsPerAccess; + using Operator = arch::OpMultiplyAddGaussianComplex; using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaTensorOp::LayoutC; @@ -97,7 +194,7 @@ struct DefaultEpilogueComplexTensorOp { ElementOutput >; - using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp< typename WarpMmaTensorOp::Shape, typename WarpMmaTensorOp::Policy::Operator::Shape, typename WarpMmaTensorOp::Policy::Operator::ElementC, diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h b/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h new file mode 100644 index 0000000000..bb2fdb6b8c --- /dev/null +++ b/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Constructs a default epilogue for planar complex outputs. + + This template reuses components for real-valued epilogues and applies them to planar complex + output matrices. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/array_planar_complex.h" + +#include "cutlass/arch/arch.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMma_, + typename OpcodeClass_, + typename ArchTag_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassTensorOp, + arch::Sm70, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueVoltaTensorOp< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassTensorOp, + arch::Sm75, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueTensorOp< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassTensorOp, + arch::Sm80, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueTensorOp< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + typename ArchTag_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassSimt, + ArchTag_, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueSimt< + ThreadblockShape_, + WarpMmaOperator_, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index d39ad1d941..00bf26d35b 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,6 +39,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/epilogue/thread/conversion_op.h" #include "cutlass/epilogue/thread/reduction_op.h" diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 5afb1f22c1..51ebab37d9 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,16 +39,20 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/epilogue/thread/conversion_op.h" #include "cutlass/epilogue/thread/reduction_op.h" #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/interleaved_epilogue.h" @@ -61,6 +65,177 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template < + typename ElementOutput, + typename ElementAccumulator, + int ElementsPerAccess, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + ElementAccumulator, + layout::RowMajor + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + ThreadMap, + ElementAccumulator + >; +}; + +/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts. +template < + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + half_t, + float, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + float, + 32, + 16, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + float, + 32, + 16, + 8, + 8 + >; +}; + +/// Partial specialization for int8_t x 16 <= int32_t x 16 epilogues avoids shared memory bank conflicts. +template < + int K, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + int8_t, + int32_t, + 16, + gemm::GemmShape<128, 128, K>, + gemm::GemmShape<64, 64, K>, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + gemm::GemmShape<64, 64, K>, + InstructionShape, + int32_t, + 32, + 8, + 16, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 8, + 16, + 8 + >; +}; + +/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template < + int K, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + int8_t, + int32_t, + 8, + gemm::GemmShape<128, 64, K>, + gemm::GemmShape<64, 32, K>, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + gemm::GemmShape<64, 32, K>, + InstructionShape, + int32_t, + 32, + 8, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 8, + 8, + 8 + >; +}; + +/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template < + int K, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + int8_t, + int32_t, + 8, + gemm::GemmShape<64, 64, K>, + gemm::GemmShape<32, 32, K>, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + gemm::GemmShape<32, 32, K>, + InstructionShape, + int32_t, + 32, + 8, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 8, + 8, + 8 + >; +}; + +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + /// Defines sensible defaults for epilogues for TensorOps. template < typename Shape_, @@ -98,25 +273,33 @@ struct DefaultEpilogueTensorOp { ElementOutput >; - using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename WarpMmaTensorOp::Policy::Operator::ElementC, - typename WarpMmaTensorOp::Policy::Operator::FragmentC, - LayoutC - >; - - using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + /// Support several implementations depending on structure of epilogue + using DefaultIterators = detail::DefaultIteratorsTensorOp< + ElementOutput, + ElementAccumulator, + kElementsPerAccess, + Shape, typename WarpMmaTensorOp::Shape, typename WarpMmaTensorOp::Policy::Operator::Shape, - ElementAccumulator, - LayoutC + typename OutputTileThreadMap::CompactedThreadMap >; - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< - typename OutputTileThreadMap::CompactedThreadMap, - ElementAccumulator - >; + using WarpTileIterator = typename DefaultIterators::WarpTileIterator; + using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; /// Hard-coded padding elements added using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; @@ -184,6 +367,7 @@ struct DefaultInterleavedEpilogueTensorOp { }; //////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h index 8a08e03624..7fec5110f4 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,6 +39,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/epilogue/thread/conversion_op.h" #include "cutlass/epilogue/thread/reduction_op.h" diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h index 1fd9f7a510..58425c286c 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -23,7 +23,7 @@ * **************************************************************************************************/ /*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + \brief Epilogue for threadblock scoped GEMMs using WMMA. The epilogue rearranges the result of a matrix product through shared memory to match canonical tensor layouts in global memory. Epilogues support conversion and reduction operations. @@ -39,6 +39,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/epilogue/thread/conversion_op.h" #include "cutlass/epilogue/thread/reduction_op.h" diff --git a/include/cutlass/epilogue/threadblock/default_thread_map_simt.h b/include/cutlass/epilogue/threadblock/default_thread_map_simt.h index 788e07a7dd..8e8f4d339b 100644 --- a/include/cutlass/epilogue/threadblock/default_thread_map_simt.h +++ b/include/cutlass/epilogue/threadblock/default_thread_map_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h b/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h index 74910789db..736e552531 100644 --- a/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h b/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h index 4c4068a37d..45aba393c5 100644 --- a/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h b/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h index 376887c399..34ec750d27 100644 --- a/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/direct_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/direct_epilogue_tensor_op.h index f197112b6a..f14be1ff8e 100644 --- a/include/cutlass/epilogue/threadblock/direct_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/direct_epilogue_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index b8e1e0caf0..0786842019 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -32,7 +32,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -74,7 +78,7 @@ template < class Epilogue : public EpilogueBase< Shape_, - WarpMmaOperator_, + typename WarpMmaOperator_::Shape, PartitionsK, AccumulatorFragmentIterator_, WarpTileIterator_, @@ -84,7 +88,7 @@ class Epilogue : using Base = EpilogueBase< Shape_, - WarpMmaOperator_, + typename WarpMmaOperator_::Shape, PartitionsK, AccumulatorFragmentIterator_, WarpTileIterator_, @@ -172,13 +176,105 @@ class Epilogue : OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + else { + compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); + } + } +private: + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile - typename OutputTileIterator::Fragment source_fragment; + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Convert and store fragment + // + + __syncthreads(); + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + this->warp_tile_iterator_.store(accum_fragment); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]); + + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; - if (!output_op.is_source_needed()) { - source_iterator.clear_mask(); } + } + + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; source_fragment.clear(); @@ -260,8 +356,6 @@ class Epilogue : } } -private: - /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE void apply_output_operator_( @@ -289,6 +383,30 @@ class Epilogue : output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); } } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { + + OutputAccessType *output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i]); + } + } }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 55843e2657..a9b5a41404 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -32,7 +32,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" @@ -58,7 +62,7 @@ namespace threadblock { /// Base class for epilogues defining warp-level template < typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + typename WarpShape_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) int PartitionsK, ///< Number of partitions of the K dimension typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM @@ -68,7 +72,7 @@ class EpilogueBase { public: using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; + using WarpShape = WarpShape_; static int const kPartitionsK = PartitionsK; using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; using WarpTileIterator = WarpTileIterator_; @@ -83,11 +87,10 @@ class EpilogueBase { /// Accumulator element using ElementAccumulator = typename AccumulatorTile::Element; - /// Number of warps using WarpCount = gemm::GemmShape< - Shape::kM / WarpMmaOperator::Shape::kM, - Shape::kN / WarpMmaOperator::Shape::kN, + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, kPartitionsK >; @@ -144,24 +147,6 @@ class EpilogueBase { storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn})); } - - CUTLASS_DEVICE - void debug_print() { - if (threadIdx.x == 0) { - - #pragma unroll 1 - for (int r = 0; r < Shape::kRow; ++r) { - - #pragma unroll 1 - for (int c = 0; c < Shape::kColumn; ++c) { - - printf("%d ", int(storage.data()[r * StorageShape::kColumn + c])); - } - printf("\n"); - } - } - __syncthreads(); - } }; protected: diff --git a/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h b/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h new file mode 100644 index 0000000000..6cb9963615 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h @@ -0,0 +1,397 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/array_planar_complex.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator for planar-complex output representations. +/// +/// Note, as with most CUTLASS components for planar complex, the template arguments describe +/// the underlying real data type. +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator + typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) +> +class EpiloguePlanarComplex { +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = ArrayPlanarComplex< + typename WarpMmaOperator::FragmentC::Element, + WarpMmaOperator::FragmentC::kElements + >; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Shape of each warp-level operation + using WarpShape = typename WarpMmaOperator::Shape; + + /// Number of warps + using WarpCount = gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Shared memory allocation + struct SharedStorage { + + // + // Type definitions + // + + /// Element type of shared memory + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Layout of shared memory allocation + using Layout = typename WarpTileIterator::Layout; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = MatrixShape< + WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, + WarpCount::kN * WarpTileIterator::Shape::kColumn + >; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + Shape::kRow + Padding::kRow, + Shape::kColumn + Padding::kColumn + >; + + static int const kImaginaryStride = StorageShape::kCount; + + // + // Data members + // + + AlignedBuffer storage; + + // + // Methods + // + + /// Returns a pointer to the shared memory buffer + CUTLASS_DEVICE + Element *data() { + return storage.data(); + } + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference() { + return TensorRef( + storage.data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + +private: + + // + // Data members + // + + SharedStorage &shared_storage_; + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpiloguePlanarComplex( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_storage_(shared_storage), + shared_load_iterator_(shared_storage.reference(), thread_idx), + warp_tile_iterator_(shared_storage.reference(), lane_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: + // + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator_real, ///< Tile iterator for destination + OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment_real; + typename OutputTileIterator::Fragment source_fragment_imag; + + if (!output_op.is_source_needed()) { + source_iterator_real.clear_mask(); + source_iterator_imag.clear_mask(); + } + + source_fragment_real.clear(); + source_fragment_imag.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real); + AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag); + + // + // Iterate over accumulator tile + // + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + source_iterator_real.load(source_fragment_real); + source_iterator_imag.load(source_fragment_imag); + + ++source_iterator_real; + ++source_iterator_imag; + + // + // Convert and store fragment + // + + __syncthreads(); + + typename AccumulatorFragmentIterator::Fragment accum_fragment_real; + typename AccumulatorFragmentIterator::Fragment accum_fragment_imag; + + accum_fragment_iterator_real.load(accum_fragment_real); + accum_fragment_iterator_imag.load(accum_fragment_imag); + + ++accum_fragment_iterator_real; + ++accum_fragment_iterator_imag; + + this->warp_tile_iterator_.store(accum_fragment_real); + this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment_real[0]); + shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time"); + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment_real; + typename OutputTileIterator::Fragment output_fragment_imag; + + apply_output_operator_( + output_fragment_real, + output_fragment_imag, + output_op, + aligned_accum_fragment_real[0], + aligned_accum_fragment_imag[0], + source_fragment_real, + source_fragment_imag); + + // + // Store the final result + // + + destination_iterator_real.store(output_fragment_real); + destination_iterator_imag.store(output_fragment_imag); + + ++destination_iterator_real; + ++destination_iterator_imag; + } + } + +private: + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &output_fragment_real, + typename OutputTileIterator::Fragment &output_fragment_imag, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag, + typename OutputTileIterator::Fragment const &source_fragment_real, + typename OutputTileIterator::Fragment const &source_fragment_imag) { + + OutputAccessType *output_frag_real_ptr = + reinterpret_cast(&output_fragment_real); + + OutputAccessType *output_frag_imag_ptr = + reinterpret_cast(&output_fragment_imag); + + AccumulatorAccessType const *compute_frag_real_ptr = + reinterpret_cast(&aligned_accum_fragment_real); + + AccumulatorAccessType const *compute_frag_imag_ptr = + reinterpret_cast(&aligned_accum_fragment_imag); + + OutputAccessType const *source_frag_real_ptr = + reinterpret_cast(&source_fragment_real); + + OutputAccessType const *source_frag_imag_ptr = + reinterpret_cast(&source_fragment_imag); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operator + auto result_fragment = output_op( + make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]), + make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i]) + ); + + output_frag_real_ptr[i] = result_fragment.real; + output_frag_imag_ptr[i] = result_fragment.imag; + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_workspace.h b/include/cutlass/epilogue/threadblock/epilogue_workspace.h index 72eb8d2e4d..36d196a37f 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_workspace.h +++ b/include/cutlass/epilogue/threadblock/epilogue_workspace.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index ba97f9cf3f..b616545b9f 100644 --- a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -85,9 +85,6 @@ class InterleavedEpilogue { using OutputTileIterator = OutputTileIterator_; using OutputOp = OutputOp_; - /// Output layout is always row-major - using Layout = layout::ColumnMajorInterleaved; - /// The complete warp-level accumulator tile using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index e4929cfe96..4eb5e3784b 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -437,11 +437,10 @@ struct OutputTileOptimalThreadMap { /// - minimal address arithmetic /// - minimal predicate calculations /// -template struct InterleavedOutputTileThreadMap { using WarpCount = WarpCount_; - using MmaCount = MmaCount_; static int const kWarpSize = 32; static int const kThreads = Threads; @@ -460,7 +459,7 @@ struct InterleavedOutputTileThreadMap { // Output // - using Iterations = MmaCount; + using Iterations = Iterations_; using Delta = layout::PitchLinearShape; diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 9b8941704a..f3c88300ba 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -36,12 +36,12 @@ #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" #include "cutlass/matrix_shape.h" #include "cutlass/tensor_ref.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/epilogue/threadblock/output_tile_thread_map.h" - +#include "cutlass/arch/memory.h" //////////////////////////////////////////////////////////////////////////////// @@ -107,16 +107,16 @@ class PredicatedTileIterator { // Data members // - Index stride; ///< stride in bytes between rows + LongIndex stride; ///< stride in bytes between rows - Index increment_row; ///< increment quantity (in bytes) to advance when moving between rows - Index increment_group; ///< increment quantity (in bytes) to advance when moving to the next group - Index increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster + LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows + LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group + LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster - Index advance_row; ///< amount to add to move to the next 'row' position - Index advance_group; ///< amount to add to move to the next 'group' position - Index advance_cluster; ///< amount to add to move to the next 'cluster' position - Index advance_tile; ///< amount to add to move to the next 'tile' + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_group; ///< amount to add to move to the next 'group' position + LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position + LongIndex advance_tile; ///< amount to add to move to the next 'tile' // // Methods @@ -125,7 +125,7 @@ class PredicatedTileIterator { CUTLASS_HOST_DEVICE Status initialize(Index stride_) { - stride = stride_; + stride = LongIndex(stride_); increment_row = stride * ThreadMap::Delta::kRow; @@ -261,8 +261,8 @@ class PredicatedTileIterator { // Initialize pointer byte_pointer_ = reinterpret_cast(pointer) + - thread_offset.row() * params_.stride + - thread_offset.column() * sizeof(AccessType) / kElementsPerAccess; + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; @@ -276,7 +276,7 @@ class PredicatedTileIterator { /// Loads a fragment from memory CUTLASS_DEVICE - void load(Fragment &frag) { + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { uint8_t *byte_pointer = byte_pointer_; AccessType *frag_ptr = reinterpret_cast(&frag); @@ -299,17 +299,22 @@ class PredicatedTileIterator { bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - if (guard) { - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess]; - } + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); } if (row + 1 < ThreadMap::Iterations::kRow) { @@ -328,9 +333,15 @@ class PredicatedTileIterator { } } + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_byte_offset(frag, 0); + } + /// Stores a fragment to memory CUTLASS_DEVICE - void store(Fragment const &frag) { + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { uint8_t *byte_pointer = byte_pointer_; AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -352,18 +363,19 @@ class PredicatedTileIterator { bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - if (guard) { - - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); } if (row + 1 < ThreadMap::Iterations::kRow) { @@ -382,6 +394,12 @@ class PredicatedTileIterator { } } + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_byte_offset(frag, 0); + } + /// Advances to the next position to load or store CUTLASS_HOST_DEVICE PredicatedTileIterator &operator++() { @@ -440,6 +458,7 @@ class PredicatedTileIterator { }; //////////////////////////////////////////////////////////////////////////////// + /// Tile iterator used to load output tile from shared memory in epilogue. /// /// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator @@ -447,7 +466,7 @@ class PredicatedTileIterator { template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) typename Element_, ///< Element data type - int InterleavedK ///< Number of Interleaved K + int InterleavedN ///< Number of Interleaved N > class InterleavedPredicatedTileIterator { public: @@ -455,7 +474,7 @@ class InterleavedPredicatedTileIterator { using Element = Element_; - using Layout = layout::ColumnMajorInterleaved; + using Layout = layout::ColumnMajorInterleaved; using TensorRef = TensorRef; using ConstTensorRef = typename TensorRef::ConstTensorRef; @@ -483,10 +502,10 @@ class InterleavedPredicatedTileIterator { // Data members // - Index stride; ///< stride in bytes between columns + LongIndex stride; ///< stride in bytes between columns - Index advance_row; ///< amount to add to move to the next 'row' position - Index advance_column; ///< amount to add to move to the next 'column' position + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_column; ///< amount to add to move to the next 'column' position // // Methods @@ -494,14 +513,16 @@ class InterleavedPredicatedTileIterator { CUTLASS_HOST_DEVICE Status initialize(Index stride_) { - stride = stride_; + + stride = LongIndex(stride_); advance_row = ThreadMap::Delta::kContiguous * sizeof_bits::value / 8; - advance_column = - stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess * - sizeof_bits::value * ThreadMap::kWarpSize / 8; + advance_column = LongIndex(stride_) - ThreadMap::Iterations::kContiguous * + kElementsPerAccess * + sizeof_bits::value * + ThreadMap::kWarpSize / 8; return Status::kSuccess; } @@ -602,10 +623,10 @@ class InterleavedPredicatedTileIterator { ): params_(params) { TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + - TensorCoord(threadblock_offset.contiguous() * InterleavedK, - threadblock_offset.strided() / InterleavedK); + TensorCoord(threadblock_offset.contiguous() * InterleavedN, + threadblock_offset.strided() / InterleavedN); - extent_col_ = extent.strided() / InterleavedK; + extent_col_ = extent.strided() / InterleavedN; thread_start_col_ = thread_offset.strided(); // Initialize predicates @@ -613,13 +634,13 @@ class InterleavedPredicatedTileIterator { for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { mask_.predicates[c] = ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) < - (extent.contiguous() * InterleavedK)); + (extent.contiguous() * InterleavedN)); } // Initialize pointer byte_pointer_ = reinterpret_cast(pointer) + - thread_offset.strided() * params_.stride + - thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess; + LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + + LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; // Initialize internal state counter iteration_contiguous_ = iteration_strided_ = 0; @@ -634,6 +655,7 @@ class InterleavedPredicatedTileIterator { /// Loads a fragment from memory CUTLASS_DEVICE void load(Fragment &frag) { + uint8_t *byte_pointer = byte_pointer_; AccessType *frag_ptr = reinterpret_cast(&frag); AccessType *memory_pointer = reinterpret_cast(byte_pointer); @@ -644,9 +666,13 @@ class InterleavedPredicatedTileIterator { bool guard = col_guard && mask_.predicates[iteration_contiguous_]; - if (guard) { - *frag_ptr = *memory_pointer; - } + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + *frag_ptr, + (void *)memory_pointer, + guard); } /// Stores a fragment to memory @@ -662,9 +688,8 @@ class InterleavedPredicatedTileIterator { bool guard = col_guard && mask_.predicates[iteration_contiguous_]; - if (guard) { - *memory_pointer = *frag_ptr; - } + cutlass::arch::global_store( + *frag_ptr, (void *)memory_pointer, guard); } /// Overrides the internal iteration index @@ -716,6 +741,7 @@ class InterleavedPredicatedTileIterator { } }; +/////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/include/cutlass/epilogue/threadblock/shared_load_iterator.h index 5e4a64b1be..0aa3dbb19d 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -96,6 +96,15 @@ class SharedLoadIterator { ThreadMap::kElementsPerAccess, kAlignment>; + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment) + >; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + private: // @@ -149,7 +158,6 @@ class SharedLoadIterator { CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - AccessType *frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -169,15 +177,19 @@ class SharedLoadIterator { int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - AccessType const *memory_pointer = reinterpret_cast(byte_pointer); + LoadType *frag_ptr = reinterpret_cast(&frag); + LoadType const *memory_pointer = reinterpret_cast(byte_pointer); CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - frag_ptr[frag_idx] = - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess]; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + frag_ptr[frag_idx * kLoadsPerAccess + v] = + memory_pointer[(column * ThreadMap::Delta::kColumn / kElementsPerAccess) * kLoadsPerAccess + v]; + } } } } diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h new file mode 100644 index 0000000000..d37b07d562 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -0,0 +1,559 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops optimized for mixed-precision. + + This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. + + When the fragment is loaded into registers, it matches the row-major thread map assumed by + the predicated tile iterator writing to global memory. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Accumulator data type + int ElementSizeBits_, ///< Size of accumulator in bits + int OutputSizeBits_, ///< Size of output element in bits + int ElementsPerAccess, ///< Vector length of output vector + int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements + /// in the global memory tensor +> +class SharedLoadIteratorMixed; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Accumulator data type +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + ThreadMap::kElementsPerAccess, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment) + >; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * sizeof(LoadType) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_ += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * stride_ + offset.column() / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const *memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 16 => int8_t x 16 +template < + typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap) +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = 16; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + 16, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + 4, + 16 + >; + + static int const kLoadsPerAccess = 4; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; + + int lane_col_idx = thread_offset.column() / 16; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + int lane_offset = (lane_col_idx % 2) * 4 | ((lane_col_idx / 2) * 8) | ((lane_col_idx / 2) ^ i); + + pointers_[i] = base_ptr + lane_offset; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * stride_ + offset.column() / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + LoadType const *memory_pointer = pointers_[v]; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 8 => int8_t x 8 +template < + typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap) +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + 8, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + 4, + 16 + >; + + static int const kLoadsPerAccess = 2; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; + + int lane_col_idx = thread_offset.column() / 8; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + int lane_offset = (lane_col_idx % 8) * 2 | ((lane_col_idx / 4) ^ i); + + pointers_[i] = base_ptr + lane_offset; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * stride_ + offset.column() / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + LoadType const *memory_pointer = pointers_[v]; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h index d369a835d8..1bab9104c7 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h new file mode 100644 index 0000000000..4c95649244 --- /dev/null +++ b/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h @@ -0,0 +1,188 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FragmentIteratorGaussianComplexTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type + typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) +> +class FragmentIteratorGaussianComplexTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + complex, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// Size of one part of accumulator of 3-part accumulator in units of number of OperatorElementC + static int const kElementsAccumulatorPerPart = + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; + + /// Offset into the accumulator fragment part 1 + static int const kPart1Index = kElementsAccumulatorPerPart * 0; + + /// Offset into the accumulator fragment part 2 + static int const kPart2Index = kElementsAccumulatorPerPart * 1; + + /// Offset into the accumulator fragment part 3 + static int const kPart3Index = kElementsAccumulatorPerPart * 2; + + /// This is the complete warp-level accumulator tile holding part1, part2, and part3 + using AccumulatorTile = Array; + + /// This is the complete warp-level accumulator tile holding final output of complex type + using OutputAccumulatorTile = Array, kElementsAccumulatorPerPart>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + + using FragmentAccessType = Array, Policy::kElementsPerAccess>; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + auto const & part1_accum_array = accumulators_[accumulator_access_offset + kPart1Index]; + auto const & part2_accum_array = accumulators_[accumulator_access_offset + kPart2Index / Policy::kElementsPerAccess]; + auto const & part3_accum_array = accumulators_[accumulator_access_offset + kPart3Index / Policy::kElementsPerAccess]; + + // Pack parts 1, 2, and 3 into a structure. This is likely to result in MOVs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Policy::kElementsPerAccess; ++i) { + + frag_ptr[n][i].real() = part1_accum_array[i] - part3_accum_array[i]; + frag_ptr[n][i].imag() = part1_accum_array[i] + part2_accum_array[i]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/fragment_iterator_simt.h b/include/cutlass/epilogue/warp/fragment_iterator_simt.h index 160844203d..6d75e5697b 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_simt.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h index e19f12b930..f620e4bddf 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h index 15c095ffc1..1abbbdc03c 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h index b96b4c5bc2..79106b111e 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/simt_policy.h b/include/cutlass/epilogue/warp/simt_policy.h index 1d010c6844..3e096978da 100644 --- a/include/cutlass/epilogue/warp/simt_policy.h +++ b/include/cutlass/epilogue/warp/simt_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/tensor_op_policy.h b/include/cutlass/epilogue/warp/tensor_op_policy.h index c02656a521..82e685b843 100644 --- a/include/cutlass/epilogue/warp/tensor_op_policy.h +++ b/include/cutlass/epilogue/warp/tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/tile_iterator_simt.h b/include/cutlass/epilogue/warp/tile_iterator_simt.h index 2164a1349b..a9d03db1c3 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_simt.h +++ b/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -100,12 +100,28 @@ class TileIteratorSimt; + 4 * Policy::kElementsPerAccess +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + + 1 +#endif + >; private: +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + 1 + >; + +#else /// Storage type for accessing memory - using AccessType = AlignedArray; + using AccessType = AlignedArray< + Element, + Policy::kElementsPerAccess + >; +#endif // // Data members @@ -130,18 +146,21 @@ class TileIteratorSimt(ref.data())), - layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + layout_(ref.stride()[0] / AccessType::kElements) { auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); MatrixCoord lane_offset = lane_layout.inverse(lane_id); - pointer_ += layout_(lane_offset); + pointer_ += layout_({ + lane_offset.row(), + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); } /// Adds a pointer offset CUTLASS_HOST_DEVICE TileIteratorSimt & add_pointer_offset(Index pointer_offset) { - pointer_ += pointer_offset / Policy::kElementsPerAccess; + pointer_ += pointer_offset / AccessType::kElements; return *this; } @@ -151,7 +170,7 @@ class TileIteratorSimt; ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); - ScalarAccessType *scalarPointer = reinterpret_cast(pointer_); + ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { @@ -187,7 +206,7 @@ class TileIteratorSimt(&frag); CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { - pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn] = frag_ptr[n]; + pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; } #endif } @@ -206,7 +225,7 @@ class TileIteratorSimt +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kOutputElementCount = OutputElementCount; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = + (OutputElementCount * sizeof_bits::value) / (const_min(128, OutputElementCount * sizeof_bits::value)); + + static_assert(kPointerCount <= 4, "Can only accommodate four pointers at present."); + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerAccess>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount]; + + /// Stride in units of AccessType + int stride_; + + /// Logical column in which warp tile is aligned + int warp_column_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed() { + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] = nullptr; + } + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / Policy::kElementsPerAccess), + warp_column_(0) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = (lane_in_quad % 2) + (((lane_in_quad / 2) + i) % Detail::kPointerCount) * 2; + + ptr += column_idx; + + if (i == 0) { + pointers_[0 % Detail::kPointerCount] = ptr; + } + else if (i == 1) { + pointers_[1 % Detail::kPointerCount] = ptr; + } + else if (i == 2) { + pointers_[2 % Detail::kPointerCount] = ptr; + } + else if (i == 3) { + pointers_[3 % Detail::kPointerCount] = ptr; + } + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / Policy::kElementsPerAccess; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess; + } + + warp_column_ += tile_offset.column() * Shape::kColumn; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; + int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0 % Detail::kPointerCount]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1 % Detail::kPointerCount]; + } + else if (ptr_idx == 2) { + ptr = pointers_[2 % Detail::kPointerCount]; + } + else if (ptr_idx == 3) { + ptr = pointers_[3 % Detail::kPointerCount]; + } + + int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; +#if 0 + // Using inline PTX to avoid generic memory + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType); + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; + int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; + + AccessType const *smem_ptr = pointers_[ptr_idx]; + frag_ptr[n] = smem_ptr[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 16 => int8_t x 16 +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_ ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = int32_t; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 16; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + /// Offsets added + static int const kOffsetCount = 4; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount]; + + /// Stride in units of AccessType + int stride_; + + /// Uniform offset in bytes added to warp tile iterator + int uniform_offset_[Detail::kOffsetCount]; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed() { + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] = nullptr; + } + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ 0) * 4 * sizeof(AccessType); + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ tile_offset.column()) * 4 * sizeof(AccessType); + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + int offset_idx = (n % 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType) + uniform_offset_[offset_idx]; + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 8 => int8_t x 8 +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_ ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = int32_t; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 8; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount]; + + /// Stride in units of AccessType + int stride_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed() { + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] = nullptr; + } + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + if (tile_offset.column() % 2) { + auto tmp = pointers_[0]; + pointers_[0] = pointers_[1]; + pointers_[1] = tmp; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements + (n % 4) * 4; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType); + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index 3984680fec..8ffb5ec128 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -187,7 +187,8 @@ class TileIteratorVoltaTensorOp, half_t, int access = access_idx % 2; int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess + - access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + access; + access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + + access + pointer_offset / Policy::kElementsPerAccess; int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; @@ -219,7 +220,9 @@ class TileIteratorVoltaTensorOp, half_t, int access_quad = access_idx / 2; int access = access_idx % 2; - int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + access; + int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + + access + pointer_offset / Policy::kElementsPerAccess; + int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; frag_ptr[frag_idx] = pointer_[ptr_offset]; @@ -382,7 +385,7 @@ class TileIteratorVoltaTensorOp, float, l int ptr_row_offset = row_idx * 2; - int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}); + int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess; pointer_[ptr_offset] = frag_ptr[frag_idx]; } diff --git a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h index e8299f9d2a..6017b5c7ed 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/volta_tensor_op_policy.h b/include/cutlass/epilogue/warp/volta_tensor_op_policy.h index 631d423e5b..b0ecc5eb6f 100644 --- a/include/cutlass/epilogue/warp/volta_tensor_op_policy.h +++ b/include/cutlass/epilogue/warp/volta_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h b/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h index fc312c7a61..7b938d3712 100644 --- a/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h +++ b/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index a341922b40..036b08e23b 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -25,7 +25,12 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif + #include "cutlass/cutlass.h" /** @@ -201,6 +206,18 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, rem = src - (quo * div); } +/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b +CUTLASS_HOST_DEVICE +int round_up(int a, int b) { + return ((a + b - 1) / b) * b; +} + +/// Returns the ceiling of (a / b) +CUTLASS_HOST_DEVICE +int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + /****************************************************************************** * Min/Max ******************************************************************************/ diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 00eeff2dbd..13ee7f542b 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -1,5 +1,5 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + /*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -96,6 +96,16 @@ struct multiply_add { } }; +/// Fused multiply-add +template +struct and_add { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b, T const &c) const { + return ((a & b) + c); + } +}; + + /// Fused multiply-add template struct xor_add { @@ -1207,4 +1217,214 @@ struct multiply_add, Array, Array> { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + bfloat16_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned a_packed = static_cast(a.raw()); + a_packed = (a_packed | (a_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index fff34dc4d7..c65b3f0062 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -422,6 +422,342 @@ struct DefaultGemmConfiguration< using Operator = arch::OpMultiplyAddSaturate; }; +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint1b_t, + uint1b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 512>; + using WarpShape = GemmShape<64, 64, 512>; + using InstructionShape = GemmShape<8, 8, 128>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpXorPopc; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 16>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = typename platform::conditional< + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; +}; + +//////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 16>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + + +template <> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + complex, + complex, + complex, + complex + > { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<64, 64, 16>; + using WarpShape = GemmShape<32, 32, 16>; + using InstructionShape = GemmShape<8, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + complex, 1, complex, + complex>; + + using Operator = arch::OpMultiplyAddComplex; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint1b_t, + uint1b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 512>; + using WarpShape = GemmShape<64, 64, 512>; + using InstructionShape = GemmShape<16, 8, 256>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////// } // namespace device } // namespace gemm diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index 55deea6e5e..70383e15ef 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -192,7 +192,8 @@ template < OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, ElementAccumulator_>::EpilogueOutputOp, /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle, + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, /// Number of stages used in the pipelined mainloop int Stages = DefaultGemmConfiguration 1) { - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { - return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); } - return 0; + return bytes; } /// Initializes GEMM state from arguments. @@ -426,6 +431,7 @@ class Gemm { params_.ref_B.reset(args.ref_B.non_const_ref().data()); params_.ref_C.reset(args.ref_C.non_const_ref().data()); params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; params_.semaphore = static_cast(workspace); return Status::kSuccess; @@ -560,6 +566,8 @@ class Gemm gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmArray { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + using Operator = Operator_; + + /// Define the kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + false, + Operator, + false + >::GemmKernel; + + using GemmKernel = kernel::GemmArray; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (args.layout_A.stride(0) % kAlignmentA) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_B.stride(0) % kAlignmentB) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_C.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_D.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + args.batch_count, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + args.batch_count, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + typename Operator_ +> +class GemmArray< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + layout::ColumnMajor, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + AlignmentA, + AlignmentB, + Operator_ +> { +public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = false; + + // + using UnderlyingOperator = GemmArray< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + + GemmCoord problem_size{ + args.problem_size.n(), + args.problem_size.m(), + args.problem_size.k() + }; + + return UnderlyingArguments( + problem_size, + args.ptr_B, + args.layout_B.stride(), + args.ptr_A, + args.layout_A.stride(), + args.ptr_C, + args.layout_C.stride(), + args.ptr_D, + args.layout_D.stride(), + args.epilogue, + args.batch_count + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index d2090e96d1..052bd90093 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -370,8 +370,8 @@ class GemmBatched { cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( args.problem_size, - args.batch_count, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); // Initialize the Params structure params_ = typename GemmKernel::Params{ diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 5b0dea3b14..8ad1036bb1 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -192,7 +192,7 @@ template < OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, ElementAccumulator_>::EpilogueOutputOp, /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle, + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, /// Number of stages used in the pipelined mainloop int Stages = DefaultGemmConfiguration @@ -228,7 +231,11 @@ class GemmComplex { static int const kStages = Stages; static ComplexTransform const kTransformA = TransformA; static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; static bool const kSplitKSerial = SplitKSerial; + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + static int const kAlignmentC = EpilogueOutputOp::kCount; /// Define the kernel using GemmKernel = typename kernel::DefaultGemmComplex< @@ -249,6 +256,7 @@ class GemmComplex { kStages, kTransformA, kTransformB, + Operator, kSplitKSerial >::GemmKernel; @@ -498,6 +506,9 @@ template < ComplexTransform TransformA, /// Complex elementwise transformation on B operand ComplexTransform TransformB, + /// Multiply-add operator + // (selects complex or gaussian complex) + typename Operator_, /// If true, kernel supports split-K as a serial reduction bool SplitKSerial > @@ -519,6 +530,7 @@ class GemmComplex< Stages, TransformA, TransformB, + Operator_, SplitKSerial > { public: @@ -542,6 +554,7 @@ class GemmComplex< using EpilogueOutputOp = EpilogueOutputOp_; using ThreadblockSwizzle = ThreadblockSwizzle_; static int const kStages = Stages; + using Operator = Operator_; static bool const kSplitKSerial = SplitKSerial; using UnderlyingOperator = GemmComplex< @@ -560,10 +573,17 @@ class GemmComplex< EpilogueOutputOp, ThreadblockSwizzle, Stages, - TransformA, TransformB, + TransformA, + Operator, SplitKSerial >; + + static int const kAlignmentA = UnderlyingOperator::kAlignmentB; + static int const kAlignmentB = UnderlyingOperator::kAlignmentA; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB; + static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA; using UnderlyingArguments = typename UnderlyingOperator::Arguments; using GemmKernel = typename UnderlyingOperator::GemmKernel; diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index df11ba5b55..73f1c240b0 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/device/gemm_universal.h b/include/cutlass/gemm/device/gemm_universal.h new file mode 100644 index 0000000000..0912909014 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversal : + GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB> +class GemmUniversal { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversal< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversal() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h new file mode 100644 index 0000000000..12a8a6d7f3 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + + template < + typename ElementA_, + typename LayoutA_, + ComplexTransform TransformA, + int AlignmentA, + typename ElementB_, + typename LayoutB_, + ComplexTransform TransformB, + int AlignmentB, + typename LayoutC_, + bool Transpose + > + struct MapArguments { + using ElementA = ElementA_; + using LayoutA = LayoutA_; + static ComplexTransform const kTransformA = TransformA; + static int const kAlignmentA = AlignmentA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + static ComplexTransform const kTransformB = TransformB; + static int const kAlignmentB = AlignmentB; + using LayoutC = LayoutC_; + }; + + template < + typename ElementA_, + typename LayoutA_, + ComplexTransform TransformA, + int AlignmentA, + typename ElementB_, + typename LayoutB_, + ComplexTransform TransformB, + int AlignmentB, + typename LayoutC_ + > + struct MapArguments< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + LayoutC_, + true + > { + using ElementA = ElementB_; + using LayoutA = typename layout::LayoutTranspose::type; + static ComplexTransform const kTransformA = TransformB; + static int const kAlignmentA = AlignmentB; + using ElementB = ElementA_; + using LayoutB = typename layout::LayoutTranspose::type; + static ComplexTransform const kTransformB = TransformA; + static int const kAlignmentB = AlignmentA; + using LayoutC = typename layout::LayoutTranspose::type; + }; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter { +public: + + using GemmKernel = GemmKernel_; + + static bool const kInternalTranspose = + std::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + using OperatorClass = typename GemmKernel::OperatorClass; + using ArchTag = typename GemmKernel::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = detail::MapArguments< + typename GemmKernel::ElementA, + typename GemmKernel::LayoutA, + GemmKernel::kTransformA, + GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, + typename GemmKernel::LayoutB, + GemmKernel::kTransformB, + GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, + kInternalTranspose + >; + + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = GemmKernel::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = GemmKernel::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } + else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h new file mode 100644 index 0000000000..18ccb3469a --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template +class GemmUniversalBase { +public: + + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + + /// Constructs the GEMM. + GemmUniversalBase() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = + sizeof(ElementC) * + size_t(args.batch_stride_D) * + size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + return threadblock_swizzle.get_grid_shape(grid_tiled_shape); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size <= (48 << 10)) { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + smem_size); + + if (result == cudaSuccess) { + return max_active_blocks; + } + } + else { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + 0); + + if (result != cudaSuccess) { + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + return std::min(max_active_blocks, smem_capacity / smem_size); + } + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params( + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + ); + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 3a18a2b68c..78d0a6da6f 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -394,6 +394,16 @@ struct BatchedGemmCoord : public Coord<4, int> { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class GemmUniversalMode { + kGemm, + kGemmSplitKParallel, + kBatched, + kArray, + kInvalid +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace gemm diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index f3f6a1495c..0aba2d3a72 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -49,6 +49,7 @@ #include "cutlass/gemm/kernel/gemm_pipelined.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/threadblock_swizzle.h" @@ -116,6 +117,68 @@ template < struct DefaultGemm; //////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::Gemm; +}; +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Turing Architecture template < /// Element type for A matrix operand @@ -201,6 +264,75 @@ struct DefaultGemm< }; //////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Is Beta zero or not + bool IsBetaZero> +struct DefaultGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, + layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + SplitKSerial, Operator, IsBetaZero> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, + true>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + 64 / sizeof_bits::value, InterleavedK, + IsBetaZero>::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::Gemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Turing Integer Matrix Multiply Interleaved layout template < /// Element type for A matrix operand @@ -439,6 +571,80 @@ struct DefaultGemm< //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Ampere +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages + int Stages, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultGemm, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80, + ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, + Operator>::ThreadblockMma; + + static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; + static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + typename Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess + >::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::Gemm; +}; + //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for SIMT DP4A @@ -516,7 +722,6 @@ struct DefaultGemm; }; - #if defined(CUTLASS_ARCH_WMMA_ENABLED) //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Wmma Gemm Kernel diff --git a/include/cutlass/gemm/kernel/default_gemm_complex.h b/include/cutlass/gemm/kernel/default_gemm_complex.h new file mode 100644 index 0000000000..15b1430c79 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_complex.h @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial +> +struct DefaultGemmComplex; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial + > +struct DefaultGemmComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementC, + layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, + layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, + WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< + ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator>::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::Gemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h b/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h new file mode 100644 index 0000000000..870084834a --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_planar_complex.h" +#include "cutlass/gemm/kernel/gemm_planar_complex_array.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" +#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Math operation performed by GEMM (e.g. arch::OpMultiplyAdd) + typename Operator, + /// Conditional enabling to switch between stages + typename Enable = void + > +struct DefaultGemmPlanarComplexUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for pipelined mainloop +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator + > +struct DefaultGemmPlanarComplexUniversal< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if<(Stages <= 2)>::type +> { + + /// Define planar complex valued variants instead + using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + TransformA, + TransformB, + Operator + >::ThreadblockMma; + + /// Planar complex epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< + ThreadblockShape, + typename Mma::Policy::Operator, + OperatorClass, + ArchTag, + ThreadblockShape::kK / WarpShape::kK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmPlanarComplex< + Mma, + Epilogue, + ThreadblockSwizzle + >; + + // Array variant + using GemmArrayKernel = kernel::GemmPlanarComplexArray< + Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiple pipeline stages. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator + > +struct DefaultGemmPlanarComplexUniversal< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if<(Stages > 2)>::type +> { + + /// Define planar complex valued variants instead + using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexMultistage< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + TransformA, + TransformB, + Operator + >::ThreadblockMma; + + /// Planar complex epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< + ThreadblockShape, + typename Mma::Policy::Operator, + OperatorClass, + ArchTag, + ThreadblockShape::kK / WarpShape::kK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmPlanarComplex< + Mma, + Epilogue, + ThreadblockSwizzle + >; + + // Array variant + using GemmArrayKernel = kernel::GemmPlanarComplexArray< + Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h b/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h index f50ead0468..e23965d336 100644 --- a/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h +++ b/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h new file mode 100644 index 0000000000..579005cb41 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// + typename Enable = void + > +struct DefaultGemmUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultGemmUniversal< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if< ! cutlass::is_complex::value>::type +> { + + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + true, + Operator, + false + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Complex-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator + > +struct DefaultGemmUniversal< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if::value>::type +> { + + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + TransformA, + TransformB, + Operator, + false + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemv.h b/include/cutlass/gemm/kernel/default_gemv.h old mode 100644 new mode 100755 index 08a3079032..36ae339c4e --- a/include/cutlass/gemm/kernel/default_gemv.h +++ b/include/cutlass/gemm/kernel/default_gemv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index 2220465edc..6700659a1f 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -83,7 +83,7 @@ struct Gemm { // CUTLASS_HOST_DEVICE - Params() { } + Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } CUTLASS_HOST_DEVICE Params( @@ -94,7 +94,7 @@ struct Gemm { typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op = typename OutputOp::Params(), - int *semaphore = nullptr + int *workspace = nullptr ): problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), @@ -106,13 +106,14 @@ struct Gemm { ref_C(ref_C), params_D(ref_D.layout()), ref_D(ref_D), - output_op(output_op), - semaphore(semaphore) { + output_op(output_op) { int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; } }; @@ -220,7 +221,9 @@ struct Gemm { thread_idx, tb_offset_B); - int warp_idx = threadIdx.x / 32; + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h new file mode 100644 index 0000000000..f63571b023 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmArray { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::Element const * const * ptr_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::Element const * const * ptr_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Element const * const * ptr_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::Element * const * ptr_D; + int64_t stride_D; + typename OutputOp::Params epilogue; + int batch_count; + int gemm_k_iterations; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size_, + cutlass::gemm::GemmCoord const & grid_tiled_shape_, + typename Mma::IteratorA::Element const * const * ptr_A_, + typename Mma::IteratorA::Layout layout_A, + typename Mma::IteratorB::Element const * const * ptr_B_, + typename Mma::IteratorB::Layout layout_B, + typename Epilogue::OutputTileIterator::Element const * const * ptr_C_, + typename Epilogue::OutputTileIterator::Layout layout_C, + typename Epilogue::OutputTileIterator::Element * const * ptr_D_, + typename Epilogue::OutputTileIterator::Layout layout_D, + typename OutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + grid_tiled_shape(grid_tiled_shape_), + params_A(layout_A), + ptr_A(ptr_A_), + params_B(layout_B), + ptr_B(ptr_B_), + params_C(layout_C), + ptr_C(ptr_C_), + params_D(layout_D), + ptr_D(ptr_D_), + epilogue(epilogue_), + batch_count(batch_count_), + gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) { + + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmArray() { } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + + // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension + for (int batch_idx = threadblock_swizzle.get_batch_idx(); + batch_idx < params.batch_count; + batch_idx += gridDim.z) { + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + 0 + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + const_cast(params.ptr_A[batch_idx]), + params.problem_size.mk(), + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + const_cast(params.ptr_B[batch_idx]), + params.problem_size.kn(), + thread_idx, + tb_offset_B); + + // + // Main loop + // + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + OutputOp output_op(params.epilogue); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + const_cast(params.ptr_C[batch_idx]), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ptr_D[batch_idx], + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index 68a5587f78..eb638375c0 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -193,8 +193,10 @@ struct GemmBatched { // Main loop // - // Construct thread-scoped matrix multiply - int warp_idx = threadIdx.x / 32; + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); diff --git a/include/cutlass/gemm/kernel/gemm_pipelined.h b/include/cutlass/gemm/kernel/gemm_pipelined.h index 293592e74e..6caa0eae31 100644 --- a/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h new file mode 100644 index 0000000000..e05112569b --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -0,0 +1,700 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmPlanarComplex { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using Operator = typename Mma::Operator; + using ArchTag = typename Mma::ArchTag; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value); + + // + // Additional types needed for reflection + // + + using ElementAccumulator = typename Mma::Policy::Operator::ElementC; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::Shape; + + static int const kStages = Mma::kStages; + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + // + // Arguments structure + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A_real; + void const * ptr_A_imag; + + void const * ptr_B_real; + void const * ptr_B_imag; + + void const * ptr_C_real; + void const * ptr_C_imag; + + void * ptr_D_real; + void * ptr_D_imag; + + int lda_real; + int lda_imag; + int ldb_real; + int ldb_imag; + int ldc_real; + int ldc_imag; + int ldd_real; + int ldd_imag; + + int64_t batch_stride_A; + int64_t batch_stride_A_imag; + int64_t batch_stride_B; + int64_t batch_stride_B_imag; + int64_t batch_stride_C; + int64_t batch_stride_C_imag; + int64_t batch_stride_D; + int64_t batch_stride_D_imag; + + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr) + { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A_real, + void const * ptr_A_imag, + void const * ptr_B_real, + void const * ptr_B_imag, + void const * ptr_C_real, + void const * ptr_C_imag, + void * ptr_D_real, + void * ptr_D_imag, + int lda_real, + int lda_imag, + int ldb_real, + int ldb_imag, + int ldc_real, + int ldc_imag, + int ldd_real, + int ldd_imag, + int64_t batch_stride_A = 0, + int64_t batch_stride_A_imag = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_B_imag = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_C_imag = 0, + int64_t batch_stride_D = 0, + int64_t batch_stride_D_imag = 0 + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A_real(ptr_A_real), + ptr_A_imag(ptr_A_imag), + ptr_B_real(ptr_B_real), + ptr_B_imag(ptr_B_imag), + ptr_C_real(ptr_C_real), + ptr_C_imag(ptr_C_imag), + ptr_D_real(ptr_D_real), + ptr_D_imag(ptr_D_imag), + lda_real(lda_real), + lda_imag(lda_imag), + ldb_real(ldb_real), + ldb_imag(ldb_imag), + ldc_real(ldc_real), + ldc_imag(ldc_imag), + ldd_real(ldd_real), + ldd_imag(ldd_imag), + batch_stride_A(batch_stride_A), + batch_stride_A_imag(batch_stride_A_imag), + batch_stride_B(batch_stride_B), + batch_stride_B_imag(batch_stride_B_imag), + batch_stride_C(batch_stride_C), + batch_stride_C_imag(batch_stride_C_imag), + batch_stride_D(batch_stride_D), + batch_stride_D_imag(batch_stride_D_imag) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A_real, args.ptr_B_real); + std::swap(args.ptr_A_imag, args.ptr_B_imag); + std::swap(args.lda_real, args.ldb_real); + std::swap(args.lda_imag, args.ldb_imag); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A_real; + typename Mma::IteratorA::Params params_A_imag; + typename Mma::IteratorB::Params params_B_real; + typename Mma::IteratorB::Params params_B_imag; + typename Epilogue::OutputTileIterator::Params params_C_real; + typename Epilogue::OutputTileIterator::Params params_C_imag; + typename Epilogue::OutputTileIterator::Params params_D_real; + typename Epilogue::OutputTileIterator::Params params_D_imag; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A_real; + void * ptr_A_imag; + void * ptr_B_real; + void * ptr_B_imag; + void * ptr_C_real; + void * ptr_C_imag; + void * ptr_D_real; + void * ptr_D_imag; + + int64_t batch_stride_A; + int64_t batch_stride_A_imag; + int64_t batch_stride_B; + int64_t batch_stride_B_imag; + int64_t batch_stride_C; + int64_t batch_stride_C_imag; + int64_t batch_stride_D; + int64_t batch_stride_D_imag; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr), + batch_stride_A(0), + batch_stride_A_imag(0), + batch_stride_B(0), + batch_stride_B_imag(0), + batch_stride_C(0), + batch_stride_C_imag(0), + batch_stride_D(0), + batch_stride_D_imag(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + params_A_real(args.lda_real), + params_A_imag(args.lda_imag), + params_B_real(args.ldb_real), + params_B_imag(args.ldb_imag), + params_C_real(args.ldc_real), + params_C_imag(args.ldc_imag), + params_D_real(args.ldd_real), + params_D_imag(args.ldd_imag), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A_real(const_cast(args.ptr_A_real)), + ptr_A_imag(const_cast(args.ptr_A_imag)), + ptr_B_real(const_cast(args.ptr_B_real)), + ptr_B_imag(const_cast(args.ptr_B_imag)), + ptr_C_real(const_cast(args.ptr_C_real)), + ptr_C_imag(const_cast(args.ptr_C_imag)), + ptr_D_real(args.ptr_D_real), + ptr_D_imag(args.ptr_D_imag), + batch_stride_A(args.batch_stride_A), + batch_stride_A_imag(args.batch_stride_A_imag), + batch_stride_B(args.batch_stride_B), + batch_stride_B_imag(args.batch_stride_B_imag), + batch_stride_C(args.batch_stride_C), + batch_stride_C_imag(args.batch_stride_C_imag), + batch_stride_D(args.batch_stride_D), + batch_stride_D_imag(args.batch_stride_D_imag), + semaphore(static_cast(workspace)) { + + } + + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A_real = const_cast(args.ptr_A_real); + ptr_A_imag = const_cast(args.ptr_A_imag); + + ptr_B_real = const_cast(args.ptr_B_real); + ptr_B_imag = const_cast(args.ptr_B_imag); + + ptr_C_real = const_cast(args.ptr_C_real); + ptr_C_imag = const_cast(args.ptr_C_imag); + + ptr_D_real = const_cast(args.ptr_D_real); + ptr_D_imag = const_cast(args.ptr_D_imag); + + batch_stride_A = args.batch_stride_A; + batch_stride_A_imag = args.batch_stride_A_imag; + batch_stride_B = args.batch_stride_B; + batch_stride_B_imag = args.batch_stride_B_imag; + batch_stride_C = args.batch_stride_C; + batch_stride_C_imag = args.batch_stride_C_imag; + batch_stride_D = args.batch_stride_D; + batch_stride_D_imag = args.batch_stride_D_imag; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmPlanarComplex() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const &args) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A_real = static_cast(params.ptr_A_real); + ElementA *ptr_A_imag = static_cast(params.ptr_A_imag); + + ElementB *ptr_B_real = static_cast(params.ptr_B_real); + ElementB *ptr_B_imag = static_cast(params.ptr_B_imag); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A; + ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag; + ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B; + ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A_real = static_cast(params.ptr_A_real)[threadblock_tile_offset.k()]; + ptr_A_imag = static_cast(params.ptr_A_imag)[threadblock_tile_offset.k()]; + ptr_B_real = static_cast(params.ptr_B_real)[threadblock_tile_offset.k()]; + ptr_B_imag = static_cast(params.ptr_B_imag)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A_real( + params.params_A_real, + ptr_A_real, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorA iterator_A_imag( + params.params_A_imag, + ptr_A_imag, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B_real( + params.params_B_real, + ptr_B_real, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorB iterator_B_imag( + params.params_B_imag, + ptr_B_imag, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A_real, + iterator_A_imag, + iterator_B_real, + iterator_B_imag, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C_real = static_cast(params.ptr_C_real); + ElementC *ptr_C_imag = static_cast(params.ptr_C_imag); + ElementC *ptr_D_real = static_cast(params.ptr_D_real); + ElementC *ptr_D_imag = static_cast(params.ptr_D_imag); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C; + ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag; + ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D; + ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C_real = static_cast(params.ptr_C_real)[threadblock_tile_offset.k()]; + ptr_C_imag = static_cast(params.ptr_C_imag)[threadblock_tile_offset.k()]; + ptr_D_real = static_cast(params.ptr_D_real)[threadblock_tile_offset.k()]; + ptr_D_imag = static_cast(params.ptr_D_imag)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C_real( + params.params_C_real, + ptr_C_real, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params.params_C_imag, + ptr_C_imag, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D_real( + params.params_D_real, + ptr_D_real, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params.params_D_imag, + ptr_D_imag, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // + // Construct epilogue + // + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C_real = iterator_D_real; + iterator_C_imag = iterator_D_imag; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h new file mode 100644 index 0000000000..00841d4692 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -0,0 +1,591 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmPlanarComplexArray { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using Operator = typename Mma::Operator; + using ArchTag = typename Mma::ArchTag; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value); + + // + // Additional types needed for reflection + // + + using ElementAccumulator = typename Mma::Policy::Operator::ElementC; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::Shape; + + static int const kStages = Mma::kStages; + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + // + // Arguments structure + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + int const *ptr_M; + int const *ptr_N; + int const *ptr_K; + + void const * const * ptr_A_real; + void const * const * ptr_A_imag; + + void const * const * ptr_B_real; + void const * const * ptr_B_imag; + + void const * const * ptr_C_real; + void const * const * ptr_C_imag; + + void * const * ptr_D_real; + void * const * ptr_D_imag; + + int lda_real; + int lda_imag; + int ldb_real; + int ldb_imag; + int ldc_real; + int ldc_imag; + int ldd_real; + int ldd_imag; + + int64_t batch_stride_D; // unused + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kArray), + batch_count(1), + ptr_M(nullptr), + ptr_N(nullptr), + ptr_K(nullptr), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr), + batch_stride_D(0) + { } + + /// constructs an arguments structure + Arguments( + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + int const *ptr_M, + int const *ptr_N, + int const *ptr_K, + void const * const * ptr_A_real, + void const * const * ptr_A_imag, + void const * const * ptr_B_real, + void const * const * ptr_B_imag, + void const * const * ptr_C_real, + void const * const * ptr_C_imag, + void * const * ptr_D_real, + void * const * ptr_D_imag, + int lda_real, + int lda_imag, + int ldb_real, + int ldb_imag, + int ldc_real, + int ldc_imag, + int ldd_real, + int ldd_imag + ): + mode(GemmUniversalMode::kArray), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_M(ptr_M), + ptr_N(ptr_N), + ptr_K(ptr_K), + ptr_A_real(ptr_A_real), + ptr_A_imag(ptr_A_imag), + ptr_B_real(ptr_B_real), + ptr_B_imag(ptr_B_imag), + ptr_C_real(ptr_C_real), + ptr_C_imag(ptr_C_imag), + ptr_D_real(ptr_D_real), + ptr_D_imag(ptr_D_imag), + lda_real(lda_real), + lda_imag(lda_imag), + ldb_real(ldb_real), + ldb_imag(ldb_imag), + ldc_real(ldc_real), + ldc_imag(ldc_imag), + ldd_real(ldd_real), + ldd_imag(ldd_imag), + batch_stride_D(0) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_M, args.ptr_N); + std::swap(args.ptr_A_real, args.ptr_B_real); + std::swap(args.ptr_A_imag, args.ptr_B_imag); + std::swap(args.lda_real, args.ldb_real); + std::swap(args.lda_imag, args.ldb_imag); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A_real; + typename Mma::IteratorA::Params params_A_imag; + typename Mma::IteratorB::Params params_B_real; + typename Mma::IteratorB::Params params_B_imag; + typename Epilogue::OutputTileIterator::Params params_C_real; + typename Epilogue::OutputTileIterator::Params params_C_imag; + typename Epilogue::OutputTileIterator::Params params_D_real; + typename Epilogue::OutputTileIterator::Params params_D_imag; + + typename EpilogueOutputOp::Params output_op; + + int batch_count; + + int const *ptr_M; + int const *ptr_N; + int const *ptr_K; + + void const * const * ptr_A_real; + void const * const * ptr_A_imag; + void const * const * ptr_B_real; + void const * const * ptr_B_imag; + void const * const * ptr_C_real; + void const * const * ptr_C_imag; + void * const * ptr_D_real; + void * const * ptr_D_imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + batch_count(0), + ptr_M(nullptr), + ptr_N(nullptr), + ptr_K(nullptr), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size = 0, // ignored + void *workspace = nullptr // ignored + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + ptr_M(args.ptr_M), + ptr_N(args.ptr_N), + ptr_K(args.ptr_K), + params_A_real(args.lda_real), + params_A_imag(args.lda_imag), + params_B_real(args.ldb_real), + params_B_imag(args.ldb_imag), + params_C_real(args.ldc_real), + params_C_imag(args.ldc_imag), + params_D_real(args.ldd_real), + params_D_imag(args.ldd_imag), + output_op(args.epilogue), + batch_count(args.batch_count), + ptr_A_real(args.ptr_A_real), + ptr_A_imag(args.ptr_A_imag), + ptr_B_real(args.ptr_B_real), + ptr_B_imag(args.ptr_B_imag), + ptr_C_real(args.ptr_C_real), + ptr_C_imag(args.ptr_C_imag), + ptr_D_real(args.ptr_D_real), + ptr_D_imag(args.ptr_D_imag) { + + } + + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_M = args.ptr_M; + ptr_N = args.ptr_N; + ptr_K = args.ptr_K; + + ptr_A_real = args.ptr_A_real; + ptr_A_imag = args.ptr_A_imag; + + ptr_B_real = args.ptr_B_real; + ptr_B_imag = args.ptr_B_imag; + + ptr_C_real = args.ptr_C_real; + ptr_C_imag = args.ptr_C_imag; + + ptr_D_real = args.ptr_D_real; + ptr_D_imag = args.ptr_D_imag; + + output_op = args.epilogue; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmPlanarComplexArray() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const &args) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int batch_idx = threadblock_tile_offset.k(); + + int problem_size_m = params.problem_size.m(); + int problem_size_n = params.problem_size.n(); + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A_real = static_cast(const_cast(params.ptr_A_real[batch_idx])); + ElementA *ptr_A_imag = static_cast(const_cast(params.ptr_A_imag[batch_idx])); + + ElementB *ptr_B_real = static_cast(const_cast(params.ptr_B_real[batch_idx])); + ElementB *ptr_B_imag = static_cast(const_cast(params.ptr_B_imag[batch_idx])); + + // + // If pointers for problem sizes are specified, these are loaded from global memory + // + + if (params.ptr_M) { + problem_size_m = params.ptr_M[batch_idx]; + } + + if (params.ptr_N) { + problem_size_n = params.ptr_N[batch_idx]; + } + + if (params.ptr_K) { + problem_size_k = params.ptr_K[batch_idx]; + } + + int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM; + int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN; + + int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // + // Each threadblock loops over the logical problem size which the kernel may have discovered + // after the grid is launched. + // + + CUTLASS_PRAGMA_NO_UNROLL + for (int block_m = threadblock_tile_offset.m(); + block_m < kBlockCountM; + block_m += params.grid_tiled_shape.m()) { + + CUTLASS_PRAGMA_NO_UNROLL + for (int block_n = threadblock_tile_offset.n(); + block_n < kBlockCountN; + block_n += params.grid_tiled_shape.n()) { + + // + // Compute indices within threadblock and warp. + // + int thread_idx = threadIdx.x; + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Proceed with regular GEMM logic. + // + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0}; + cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN }; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A_real( + params.params_A_real, + ptr_A_real, + {problem_size_m, problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorA iterator_A_imag( + params.params_A_imag, + ptr_A_imag, + {problem_size_m, problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B_real( + params.params_B_real, + ptr_B_real, + {problem_size_k, problem_size_n}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorB iterator_B_imag( + params.params_B_imag, + ptr_B_imag, + {problem_size_k, problem_size_n}, + thread_idx, + tb_offset_B); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma( + kGemmKIterations, + accumulators, + iterator_A_real, + iterator_A_imag, + iterator_B_real, + iterator_B_imag, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + //assume identity swizzle + MatrixCoord threadblock_offset( + block_m * Mma::Shape::kM, + block_n * Mma::Shape::kN + ); + + ElementC *ptr_C_real = static_cast(const_cast(params.ptr_C_real[batch_idx])); + ElementC *ptr_C_imag = static_cast(const_cast(params.ptr_C_imag[batch_idx])); + ElementC *ptr_D_real = static_cast(params.ptr_D_real[batch_idx]); + ElementC *ptr_D_imag = static_cast(params.ptr_D_imag[batch_idx]); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C_real( + params.params_C_real, + ptr_C_real, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params.params_C_imag, + ptr_C_imag, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D_real( + params.params_D_real, + ptr_D_real, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params.params_D_imag, + ptr_D_imag, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + // + // Construct epilogue + // + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + + + } // for block_n + } // for block_m + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h index 2c5978aa8a..973897521f 100644 --- a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h new file mode 100644 index 0000000000..6efd50a7fd --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmUniversal { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int lda; + int ldb; + int ldc; + int ldd; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int lda, + int ldb, + int ldc, + int ldd + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + semaphore(static_cast(workspace)) { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmUniversal() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemv_batched_strided.h b/include/cutlass/gemm/kernel/gemv_batched_strided.h old mode 100644 new mode 100755 index 852edde29a..ea8d9bdf85 --- a/include/cutlass/gemm/kernel/gemv_batched_strided.h +++ b/include/cutlass/gemm/kernel/gemv_batched_strided.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/thread/mma.h b/include/cutlass/gemm/thread/mma.h index 41ea8b49cd..15dfe4338e 100644 --- a/include/cutlass/gemm/thread/mma.h +++ b/include/cutlass/gemm/thread/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index 78c77bef27..04658f7bc0 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 66fed7e17a..16d0d61c24 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/thread/mma_sm61.h b/include/cutlass/gemm/thread/mma_sm61.h index 13bbb54299..83e31b2377 100644 --- a/include/cutlass/gemm/thread/mma_sm61.h +++ b/include/cutlass/gemm/thread/mma_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_gemv_core.h b/include/cutlass/gemm/threadblock/default_gemv_core.h old mode 100644 new mode 100755 index de234b851c..9d692d6db5 --- a/include/cutlass/gemm/threadblock/default_gemv_core.h +++ b/include/cutlass/gemm/threadblock/default_gemv_core.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 25be077028..3ebe14e6b8 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,10 +33,13 @@ #include "cutlass/arch/arch.h" #include "cutlass/arch/wmma.h" +#include "cutlass/layout/matrix.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include "cutlass/gemm/threadblock/default_mma_core_wmma.h" #endif //CUTLASS_ARCH_WMMA_ENABLED @@ -143,8 +146,9 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass Simt) +/// Specialization for row-major output (OperatorClass TensorOp) template < /// Element type for A matrix operand typename ElementA, @@ -199,6 +203,58 @@ struct DefaultMma; }; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator + > +struct DefaultMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, + LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, + arch::OpMultiplyAddFastF16>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, float, + layout::RowMajor, typename MmaCore::MmaPolicy>; +}; + //////////////////////////////////////////////////////////////////////////////// /// Specialization for column-major-interleaved output @@ -268,7 +324,217 @@ struct DefaultMma +struct DefaultMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, Operator>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + //////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator + > +struct DefaultMma { + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Number of Interleaved K + int InterleavedK> +struct DefaultMma, OperatorClass, + ArchTag, ThreadblockShape, WarpShape, InstructionShape, + Stages, Operator, true> { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Specialization for SIMT IDP4A Kernels template < /// Layout type for A matrix operand @@ -326,6 +592,8 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// + #if defined(CUTLASS_ARCH_WMMA_ENABLED) /// Specialization for Wmma TensorOp operator with 2 staged pipeline template < @@ -384,6 +652,8 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// + /// Specialization for Wmma TensorOp operator with 1 staged pipeline template < ///< Element type for A matrix operand @@ -440,6 +710,7 @@ struct DefaultMma; }; + //////////////////////////////////////////////////////////////////////////////// #endif //CUTLASS_ARCH_WMMA_ENABLED diff --git a/include/cutlass/gemm/threadblock/default_mma_core.h b/include/cutlass/gemm/threadblock/default_mma_core.h index f346709e60..a7ac7c44b2 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core.h +++ b/include/cutlass/gemm/threadblock/default_mma_core.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -40,6 +40,8 @@ #include "cutlass/gemm/warp/mma.h" #include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/threadblock/mma_singlestage.h" +#include "cutlass/arch/cache_operation.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -86,6 +88,17 @@ template < /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor = false + /// Cache operation of operand A + , cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) > struct DefaultMmaCore; diff --git a/include/cutlass/gemm/threadblock/default_mma_core_simt.h b/include/cutlass/gemm/threadblock/default_mma_core_simt.h index 9eaa6a7a5f..be50149372 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_simt.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm50.h b/include/cutlass/gemm/threadblock/default_mma_core_sm50.h index 37aee47620..782cd7aea8 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm50.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm70.h b/include/cutlass/gemm/threadblock/default_mma_core_sm70.h index a9ec80fd12..30b3b3c0aa 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm70.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h index 490b479e7b..e7a2adcb14 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -598,6 +598,525 @@ struct DefaultMmaCore +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = float; + using LayoutA = layout::ColumnMajor; + using ElementB = float; + using LayoutB = layout::RowMajor; + using ElementC = float; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + Shape::kK / WarpShape::kK + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 256; + + /// Default Operator + using Operator = arch::OpMultiplyAdd; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(half_t))>; + + // Shared memory layout + using SmemLayoutB = + layout::RowMajorTensorOpMultiplicandCongruous::value, + int(128 / sizeof(half_t))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value + >; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutA, + 1, + IteratorThreadMapA + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value + >; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutB, + 0, + IteratorThreadMapB + >; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaTensorOp, + MatrixShape<0, 0>, + MatrixShape<0, 0>, + WarpCount::kK + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = float; + using LayoutA = layout::RowMajor; + using ElementB = float; + using LayoutB = layout::ColumnMajor; + using ElementC = float; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + Shape::kK / WarpShape::kK + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 256; + + /// Default Operator + using Operator = arch::OpMultiplyAdd; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = + layout::RowMajorTensorOpMultiplicandCrosswise::value, + Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutA, + 0, + IteratorThreadMapA + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutB, + 1, + IteratorThreadMapB + >; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaTensorOp, + MatrixShape<0, 0>, + MatrixShape<0, 0>, + WarpCount::kK + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = float; + using LayoutA = layout::RowMajor; + using ElementB = float; + using LayoutB = layout::RowMajor; + using ElementC = float; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + Shape::kK / WarpShape::kK + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 256; + + /// Default Operator + using Operator = arch::OpMultiplyAdd; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(half_t))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutA, + 0, + IteratorThreadMapA + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value + >; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileIterator< + MatrixShape, + half_t, + SmemLayoutB, + 0, + IteratorThreadMapB + >; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaTensorOp, + MatrixShape<0, 0>, + MatrixShape<0, 0>, + WarpCount::kK + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = float; + using LayoutA = layout::ColumnMajor; + using ElementB = float; + using LayoutB = layout::ColumnMajor; + using ElementC = float; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 256; + + /// Default Operator + using Operator = arch::OpMultiplyAdd; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(half_t))>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileIterator< + MatrixShape, half_t, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileIterator< + MatrixShape, half_t, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, MatrixShape<0, 0>, + WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization: /// /// A: column-major-interleave32 diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h new file mode 100644 index 0000000000..d9b3d9a0c5 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -0,0 +1,2130 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming + expectations about data layout of the global memory fragments, data types, + and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting TensorOp + instructions. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_mma_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for double-precision +/// +/// A: column-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::ColumnMajor; + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 64; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +/// Partial specialization for double-precision +/// +/// A: column-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::ColumnMajor; + using ElementB = double; + using LayoutB = layout::RowMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 64; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for double-precision +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::RowMajor; + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 64; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Partial specialization for double-precision +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::RowMajor; + using ElementB = double; + using LayoutB = layout::RowMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 64; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; + + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for float-precision +/// +/// ElementA: complex +/// ElementB: complex +/// ElementC: complex +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout for A operand + typename LayoutA_, + /// Layout for B operand + typename LayoutB_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA_, + /// per-element transformation for elements of B + ComplexTransform TransformB_ + > +struct DefaultMmaCore< + Shape_, WarpShape_, GemmShape<16, 8, 8>, + complex, LayoutA_, + complex, LayoutB_, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + Operator_, + false, + CacheOpA, + CacheOpB, + TransformA_, TransformB_, true> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<16, 8, 8>; + using ElementA = complex; + using LayoutA = LayoutA_; + using ElementB = complex; + using LayoutB = LayoutB_; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + static const ComplexTransform TransformA = TransformA_; + static const ComplexTransform TransformB = TransformB_; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + static_assert( + platform::is_same::value || + platform::is_same::value, + "The operator tag must indicate complex multiplication."); + + // + // Underlying template + // + + using MmaComplexCore = DefaultMultistageMmaComplexCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + arch::OpClassTensorOp, + kStages, + TransformA, + TransformB, + Operator, + kCacheOpA, + kCacheOpB + >; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; + + // Shared memory layout + using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; + + /// ThreadMap of iterator B + using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; + + /// Policy used to define MmaPipelined + using MmaPolicy = typename MmaComplexCore::MmaPolicy; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for double-precision +/// +/// ElementA: complex +/// ElementB: complex +/// ElementC: complex +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout for A operand + typename LayoutA_, + /// Layout for B operand + typename LayoutB_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA_, + /// per-element transformation for elements of B + ComplexTransform TransformB_ + > +struct DefaultMmaCore< + Shape_, WarpShape_, GemmShape<8, 8, 4>, + complex, LayoutA_, + complex, LayoutB_, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + Operator_, + false, + CacheOpA, + CacheOpB, + TransformA_, TransformB_, true> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<8, 8, 4>; + using ElementA = complex; + using LayoutA = LayoutA_; + using ElementB = complex; + using LayoutB = LayoutB_; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + static const ComplexTransform TransformA = TransformA_; + static const ComplexTransform TransformB = TransformB_; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 64; + + /// Default Operator + using Operator = Operator_; + + static_assert( + platform::is_same::value || + platform::is_same::value, + "The operator tag must indicate complex multiplication."); + + // + // Underlying template + // + + using MmaComplexCore = DefaultMultistageMmaComplexCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + arch::OpClassTensorOp, + kStages, + TransformA, + TransformB, + Operator, + kCacheOpA, + kCacheOpB + >; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; + + // Shared memory layout + using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; + + /// ThreadMap of iterator B + using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; + + /// Policy used to define MmaPipelined + using MmaPolicy = typename MmaComplexCore::MmaPolicy; +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementA))>; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementB))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementA))>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementB))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major-interleaved +/// B: row-major-interleaved +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Number of interleaved K + int InterleavedK> +struct DefaultMmaCore, ElementB_, + layout::RowMajorInterleaved, ElementC_, + LayoutC_, arch::OpClassTensorOp, Stages, Operator_, + AccumulatorsInRowMajor, CacheOpA, CacheOpB> { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajorInterleaved; + using ElementB = ElementB_; + using LayoutB = layout::RowMajorInterleaved; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static int const kInterleavedK = InterleavedK; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kElementsPerAccess = + kAccessSizeInBits / sizeof_bits::value; + + static int const kWarpThreadArrangementContiguous = + kInterleavedK / kElementsPerAccess; + + static int const kWarpThreadArrangementStrided = + kWarpSize / kWarpThreadArrangementContiguous; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, kInterleavedK>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, kInterleavedK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapA = transform::TransposePitchLinearThreadMap< + IteratorThreadMapA, + layout::PitchLinearShape>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + SmemThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapB = transform::TransposePitchLinearThreadMap< + IteratorThreadMapB, + layout::PitchLinearShape>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + SmemThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + + // Shared memory layout + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory + // + + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// Policy of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + SmemThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = 4; // TODO need to extract these from template data + static const int WarpNumThreadsN = 8; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = 128 / sizeof_bits::value; + static const int numElementsB = 128 / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + >; /// Used for partial specialization + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaWarpSimt, + MatrixShape<0, 0>, + MatrixShape<0, Shape::kK / 32>, + WarpCount::kK>; +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + + // Shared memory layout + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory + // + + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// Policy of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = 4; // TODO need to extract these from template data + static const int WarpNumThreadsN = 8; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = 128 / sizeof_bits::value; + static const int numElementsB = 128 / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + >; /// Used for partial specialization + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaWarpSimt, + MatrixShape<0, 0>, + MatrixShape<0, 0>, + WarpCount::kK>; +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + + // Shared memory layout + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory + // + + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + SmemThreadMapA>; + + /// Policy of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + SmemThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = 4; // TODO need to extract these from template data + static const int WarpNumThreadsN = 8; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = 128 / sizeof_bits::value; + static const int numElementsB = 128 / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + >; /// Used for partial specialization + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaWarpSimt, + MatrixShape, + MatrixShape<0, Shape::kK / 32>, + WarpCount::kK>; +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + + // Shared memory layout + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Transpose the ThreadMap of iterator A + using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + SmemThreadMapA>; + + /// Policy of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = 4; // TODO need to extract these from template data + static const int WarpNumThreadsN = 8; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = 128 / sizeof_bits::value; + static const int numElementsB = 128 / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + >; /// Used for partial specialization + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy< + MmaWarpSimt, + MatrixShape, + MatrixShape<0, 0>, + WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/threadblock/default_mma_core_wmma.h b/include/cutlass/gemm/threadblock/default_mma_core_wmma.h index ef51be23a7..8214494321 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_wmma.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h b/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h new file mode 100644 index 0000000000..2f4a079619 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h @@ -0,0 +1,130 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/mma_planar_complex_multistage.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Math operator tag (e.g. arch::OpMultiplyAdd) + typename Operator = arch::OpMultiplyAdd +> +struct DefaultMmaPlanarComplexMultistage { + + // Construct a planar complex variant from the real-valued variant + using RealMmaMultistage = typename DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator_, + LayoutC_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + Stages, + Operator + >::ThreadblockMma; + + using ThreadblockMma = MmaPlanarComplexMultistage< + ThreadblockShape_, + typename RealMmaMultistage::IteratorA, + typename RealMmaMultistage::SmemIteratorA, + cutlass::arch::CacheOperation::Global, + typename RealMmaMultistage::IteratorB, + typename RealMmaMultistage::SmemIteratorB, + cutlass::arch::CacheOperation::Global, + ElementAccumulator_, + LayoutC_, + typename RealMmaMultistage::Policy, + Stages, + TransformA, + TransformB + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h b/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h new file mode 100644 index 0000000000..04a856e9a4 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/gemm/warp/mma_planar_complex.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Math operator tag (e.g. arch::OpMultiplyAdd) + typename Operator = arch::OpMultiplyAdd +> +struct DefaultMmaPlanarComplexPipelined { + + // Construct a planar complex variant from the real-valued variant + using RealMma = typename DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator_, + LayoutC_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + Stages, + Operator + >::ThreadblockMma; + + using ThreadblockMma = MmaPlanarComplexPipelined< + ThreadblockShape_, + typename RealMma::IteratorA, + typename RealMma::SmemIteratorA, + typename RealMma::IteratorB, + typename RealMma::SmemIteratorB, + ElementAccumulator_, + LayoutC_, + typename RealMma::Policy, + Stages, + TransformA, + TransformB + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h b/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h new file mode 100644 index 0000000000..7f3d534a1f --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator = arch::OpMultiplyAddComplex, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultMultistageMmaComplex; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator> +struct DefaultMultistageMmaComplex { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, TransformA, TransformB, Operator>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h new file mode 100644 index 0000000000..613c88e3ea --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming + expectations about data layout of the global memory fragments, data types, + and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting TensorOp + instructions. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_mma_core.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from +/// threadblock tile size, global memory data layout, and target math +/// instruction. +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator = arch::OpMultiplyAddComplex, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global> +struct DefaultMultistageMmaComplexCore; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h new file mode 100644 index 0000000000..230e8d7681 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h @@ -0,0 +1,1113 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming + expectations about data layout of the global memory fragments, data types, + and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting TensorOp + instructions. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex double-precision +/// +/// A: column-major +/// B: row-major +/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<8, 8, 4>, + complex, layout::ColumnMajor, + complex, layout::RowMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<8, 8, 4>; + using ElementA = complex; + using LayoutA = layout::ColumnMajor; + using ElementB = complex; + using LayoutB = layout::RowMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped 128 + static int const kAccessSizeInBits = 128; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; + + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + + +/// Partial specialization for complex double-precision +/// +/// A: column-major +/// B: row-major +/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<8, 8, 4>, + complex, layout::ColumnMajor, + complex, layout::ColumnMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<8, 8, 4>; + using ElementA = complex; + using LayoutA = layout::ColumnMajor; + using ElementB = complex; + using LayoutB = layout::ColumnMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + using Operator = Operator_; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped 128 + static int const kAccessSizeInBits = 128; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex double-precision +/// +/// A: row-major +/// B: column-major +/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<8, 8, 4>, + complex, layout::RowMajor, + complex, layout::ColumnMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<8, 8, 4>; + using ElementA = complex; + using LayoutA = layout::RowMajor; + using ElementB = complex; + using LayoutB = layout::ColumnMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped 128 + static int const kAccessSizeInBits = 128; + + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + + +/// Partial specialization for complex double-precision +/// +/// A: row-major +/// B: row-major +/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<8, 8, 4>, + complex, layout::RowMajor, + complex, layout::RowMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<8, 8, 4>; + using ElementA = complex; + using LayoutA = layout::RowMajor; + using ElementB = complex; + using LayoutB = layout::RowMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped 128 + static int const kAccessSizeInBits = 128; + + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for complex floating-point +/// +/// A: column-major +/// B: column-major +/// Operator: arch::OpMultiplyAddComplex +/// Math Instruction: MMA.1688.F32.TF32 +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<16, 8, 8>, + complex, layout::ColumnMajor, + complex, layout::ColumnMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<16, 8, 8>; + using ElementA = complex; + using LayoutA = layout::ColumnMajor; + using ElementB = complex; + using LayoutB = layout::ColumnMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped + static int const kAccessSizeInBits = 64; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + + +/// Partial specialization for complex floating-point +/// +/// A: column-major +/// B: row-major +/// Operator: arch::OpMultiplyAddComplex +/// Math Instruction: MMA.1688.F32.TF32 +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<16, 8, 8>, + complex, layout::ColumnMajor, + complex, layout::RowMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<16, 8, 8>; + using ElementA = complex; + using LayoutA = layout::ColumnMajor; + using ElementB = complex; + using LayoutB = layout::RowMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped + static int const kAccessSizeInBits = 64; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex floating-point +/// +/// A: row-major +/// B: column-major +/// Operator: arch::OpMultiplyAddComplex +/// Math Instruction: MMA.1688.F32.TF32 +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<16, 8, 8>, + complex, layout::RowMajor, + complex, layout::ColumnMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<16, 8, 8>; + using ElementA = complex; + using LayoutA = layout::RowMajor; + using ElementB = complex; + using LayoutB = layout::ColumnMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped + static int const kAccessSizeInBits = 64; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex floating-point +/// +/// A: row-major +/// B: row-major +/// Operator: arch::OpMultiplyAddComplex +/// Math Instruction: MMA.1688.F32.TF32 +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Complex transformation on operand A + ComplexTransform TransformA, + /// Complex transformation on operand B + ComplexTransform TransformB, + /// Multiply-add operator (arch::OpMultiplyAddComplex) + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMultistageMmaComplexCore< + Shape_, WarpShape_, GemmShape<16, 8, 8>, + complex, layout::RowMajor, + complex, layout::RowMajor, + complex, LayoutC_, + arch::OpClassTensorOp, + Stages, + TransformA, TransformB, + Operator_, + CacheOpA, CacheOpB> { + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<16, 8, 8>; + using ElementA = complex; + using LayoutA = layout::RowMajor; + using ElementB = complex; + using LayoutB = layout::RowMajor; + using ElementC = complex; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + static_assert(WarpCount::kCount > 1, + "This specialization requires at least two warps."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped + static int const kAccessSizeInBits = 64; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; + + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<16, 2>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + WarpShape, InstructionShape, + ElementA, SmemLayoutA, + ElementB, SmemLayoutB, + ElementC, LayoutC, + kTransformA, kTransformB, + Operator>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/gemv.h b/include/cutlass/gemm/threadblock/gemv.h deleted file mode 100644 index 54da93a984..0000000000 --- a/include/cutlass/gemm/threadblock/gemv.h +++ /dev/null @@ -1,140 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a threadblock-scoped GEMV kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix-vector product using SIMT math instructions. -template < - class Core_ //< GemvCore -> -class Gemv { -public: - using Shape = typename Core_::Shape; - - /// The MMA operator that computes GEMV - using Operator = typename Core_::Operator; - - /// Iterates over A in global memory - using IteratorA = typename Core_::IteratorA; - - /// Iterates over B in global memory - using IteratorB = typename Core_::IteratorB; - - /// Fragment of operand C loaded from global memory - using IteratorC = typename Core_::IteratorC; - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand accumulator loaded/stored to global memory - using FragmentC = typename Operator::FragmentC; - - /// Shape of the per-thread GEMV operation - using ThreadShape = typename Core_::ThreadShape; - -public: - CUTLASS_DEVICE - Gemv() { } - - CUTLASS_DEVICE - void operator()( - GemmCoord const &problem_size, ///< problem size of batched GEMV - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum) { ///< source accumualtor tile - - // - // Prologue - // - - FragmentA frag_A; - FragmentB frag_B; - frag_A.clear(); - frag_B.clear(); - - iterator_A.load(frag_A); - iterator_B.load(frag_B); - ++iterator_A; - ++iterator_B; - - // - // Mainloop - // - Operator thread_mma; - int gemm_k = problem_size.k(); - - if (gemm_k < Shape::kK) - { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } - - // iterate over K to accumulate result - CUTLASS_GEMM_LOOP - for (; gemm_k > 0; gemm_k -= Shape::kK) { - thread_mma(accum, frag_A, frag_B, accum); - - iterator_A.load(frag_A); - iterator_B.load(frag_B); - ++iterator_A; - ++iterator_B; - - if (gemm_k < Shape::kK) - { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/include/cutlass/gemm/threadblock/mma_base.h b/include/cutlass/gemm/threadblock/mma_base.h index 7e6d4fe64b..dbf3d31f56 100644 --- a/include/cutlass/gemm/threadblock/mma_base.h +++ b/include/cutlass/gemm/threadblock/mma_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h new file mode 100644 index 0000000000..0431c3060f --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistage : + public MmaBase { +public: + ///< Base class + using Base = MmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_pipelined.h b/include/cutlass/gemm/threadblock/mma_pipelined.h index 1e707404b2..80954f6c4f 100644 --- a/include/cutlass/gemm/threadblock/mma_pipelined.h +++ b/include/cutlass/gemm/threadblock/mma_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -75,7 +75,7 @@ template < typename IteratorA_::Element, IteratorA_::Fragment::kElements>, /// - /// Transformation applied to A operand + /// Transformation applied to B operand typename TransformB_ = NumericArrayConverter< typename SmemIteratorB_::Element, typename IteratorB_::Element, @@ -118,6 +118,15 @@ class MmaPipelined : public MmaBase { /// Warp-level Mma using Operator = typename Policy::Operator; + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); @@ -256,8 +265,8 @@ class MmaPipelined : public MmaBase { __syncthreads(); - ++this->smem_iterator_B_; ++this->smem_iterator_A_; + ++this->smem_iterator_B_; // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory if (smem_write_stage_idx == 1) { @@ -299,7 +308,8 @@ class MmaPipelined : public MmaBase { } } - warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); } } @@ -311,3 +321,5 @@ class MmaPipelined : public MmaBase { } // namespace threadblock } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h new file mode 100644 index 0000000000..b37b418462 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaPlanarComplexBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Stride to the imaginary part of the A operand + static int const kImaginaryStrideA = ShapeA::kCount; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + /// Stride to the imaginary part of the A operand + static int const kImaginaryStrideB = ShapeB::kCount; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPlanarComplexBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h new file mode 100644 index 0000000000..18e63b5805 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h @@ -0,0 +1,642 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/array_planar_complex.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Transformation applied to A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Transformation applied to B + ComplexTransform TransformB = ComplexTransform::kNone +> +class MmaPlanarComplexMultistage : + public MmaPlanarComplexBase { +public: + ///< Base class + using Base = MmaPlanarComplexBase; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Archtecture tag + using ArchTag = arch::Sm80; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Transformation applied to A + static ComplexTransform const kTransformA = TransformA; + + /// Transformation applied to B + static ComplexTransform const kTransformB = TransformB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = ArrayPlanarComplex< + typename Policy::Operator::FragmentC::Element, + Policy::Operator::FragmentC::kElements + >; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of LDGSTS instructions to load one stage of operand A + static int const TBLDGSTSIterationsA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of LDGSTS instructions to load one stage of operand B + static int const TBLDGSTSIterationsB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of LDGSTS instructions to load on group of operand A + static int const kAccessesPerGroupA = + (TBLDGSTSIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of LDGSTS instructions to load on group of operand B + static int const kAccessesPerGroupB = + (TBLDGSTSIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPlanarComplexMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + +private: + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA &iterator_A_real, + IteratorA &iterator_A_imag, + + IteratorB &iterator_B_real, + IteratorB &iterator_B_imag, + + int group_start_A = 0, + int group_start_B = 0) { + + iterator_A_real.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + iterator_A_imag.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // LDGSTS for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + + auto gmem_ptr_real = iterator_A_real.get(); + auto gmem_ptr_imag = iterator_A_imag.get(); + + bool pred_guard = iterator_A_real.valid(); + cutlass::arch::cp_async( + dst_ptr + v, + gmem_ptr_real, + pred_guard); + cutlass::arch::cp_async( + dst_ptr + v + (Base::SharedStorage::kImaginaryStrideA / IteratorA::ThreadMap::kElementsPerAccess), + reinterpret_cast(gmem_ptr_imag), + pred_guard); + + ++iterator_A_real; + ++iterator_A_imag; + } + + ++this->smem_iterator_A_; + } + + iterator_B_real.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + iterator_B_imag.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr_real = iterator_B_real.get(); + auto gmem_ptr_imag = iterator_B_imag.get(); + + bool pred_guard = iterator_B_real.valid(); + cutlass::arch::cp_async( + dst_ptr + v, + gmem_ptr_real, + pred_guard); + cutlass::arch::cp_async( + dst_ptr + v + (Base::SharedStorage::kImaginaryStrideB / IteratorB::ThreadMap::kElementsPerAccess), + reinterpret_cast(gmem_ptr_imag), + pred_guard); + + ++iterator_B_real; + ++iterator_B_imag; + } + ++this->smem_iterator_B_; + } + } + + CUTLASS_DEVICE + void warp_mma_planar_complex( + Operator & warp_mma, + FragmentC &accum, + WarpFragmentA const & real_A, + WarpFragmentA const & imag_A, + WarpFragmentB const & real_B, + WarpFragmentB const & imag_B) { + + cutlass::negate> neg_op_B; + + WarpFragmentB neg_real_B = neg_op_B(real_B); + WarpFragmentB neg_imag_B = neg_op_B(imag_B); + + warp_mma(accum.real, real_A, real_B, accum.real); + + if (kTransformB == ComplexTransform::kNone) { + warp_mma(accum.imag, real_A, imag_B, accum.imag); + } + else { + warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone) { + warp_mma(accum.imag, imag_A, real_B, accum.imag); + } + else { + warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { + warp_mma(accum.real, imag_A, imag_B, accum.real); + } + else { + warp_mma(accum.real, imag_A, neg_imag_B, accum.real); + } + } + +public: + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A_real, + ///< iterator over A operand in global memory + IteratorA iterator_A_imag, + ///< iterator over B operand in global memory + IteratorB iterator_B_real, + ///< iterator over B operand in global memory + IteratorB iterator_B_imag, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + if (gemm_k_iterations == 0) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + + iterator_A_real.set_iteration_index(0); + iterator_A_imag.set_iteration_index(0); + + this->smem_iterator_A_.set_iteration_index(0); + + // LDGSTS for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + bool pred_guard = iterator_A_real.valid(); + + auto src_ptr_real = iterator_A_real.get(); + auto src_ptr_imag = iterator_A_imag.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, src_ptr_real, pred_guard); + + cutlass::arch::cp_async_zfill( + dst_ptr + v + + Base::SharedStorage::kImaginaryStrideA / + IteratorA::ThreadMap::kElementsPerAccess, + reinterpret_cast(src_ptr_imag), + pred_guard); + + ++iterator_A_real; + ++iterator_A_imag; + } + + ++this->smem_iterator_A_; + } + + iterator_B_real.set_iteration_index(0); + iterator_B_imag.set_iteration_index(0); + + this->smem_iterator_B_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + bool pred_guard = iterator_B_real.valid(); + + auto src_ptr_real = iterator_B_real.get(); + auto src_ptr_imag = iterator_B_imag.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, src_ptr_real, pred_guard); + + cutlass::arch::cp_async_zfill( + dst_ptr + v + + Base::SharedStorage::kImaginaryStrideB / + IteratorB::ThreadMap::kElementsPerAccess, + reinterpret_cast(src_ptr_imag), + pred_guard); + + ++iterator_B_real; + ++iterator_B_imag; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A_real.add_tile_offset({0, 1}); + iterator_A_imag.add_tile_offset({0, 1}); + + iterator_B_real.add_tile_offset({1, 0}); + iterator_B_imag.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a memory fence between stages of cp.async instructions + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Blocks until all but kStages-2 cp.async stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + + WarpFragmentA warp_frag_real_A[2]; + WarpFragmentA warp_frag_imag_A[2]; + + WarpFragmentB warp_frag_real_B[2]; + WarpFragmentB warp_frag_imag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (gemm_k_iterations == 0) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag); + + Operator warp_mma; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } + else { + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance( + iterator_A_real, + iterator_A_imag, + iterator_B_real, + iterator_B_imag, + group_start_iteration_A, + group_start_iteration_B); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a memory fence between stages of cp.async instructions + cutlass::arch::cp_async_fence(); + + // Blocks until all but kStages-2 cp.async stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A_real.add_tile_offset({0, 1}); + iterator_A_imag.add_tile_offset({0, 1}); + + iterator_B_real.add_tile_offset({1, 0}); + iterator_B_imag.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + if (gemm_k_iterations == 0) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + } + + warp_mma_planar_complex( + warp_mma, + accum, + warp_frag_real_A[warp_mma_k % 2], + warp_frag_imag_A[warp_mma_k % 2], + warp_frag_real_B[warp_mma_k % 2], + warp_frag_imag_B[warp_mma_k % 2]); + } + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h new file mode 100644 index 0000000000..ecf722d92a --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Transformation applied to A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Transformation applied to B + ComplexTransform TransformB = ComplexTransform::kNone +> +class MmaPlanarComplexPipelined : + public MmaPlanarComplexBase { +public: + ///< Base class + using Base = MmaPlanarComplexBase; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Policy describing tuning details + using Policy = Policy_; + + using ArchTag = typename Policy::Operator::ArchTag; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + /// Transformation applied to A + static ComplexTransform const kTransformA = TransformA; + + /// Transformation applied to B + static ComplexTransform const kTransformB = TransformB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = ArrayPlanarComplex< + typename Policy::Operator::FragmentC::Element, + Policy::Operator::FragmentC::kElements + >; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + private: + + using FragmentA = typename IteratorA::Fragment; + using FragmentB = typename IteratorB::Fragment; + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPlanarComplexPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + +private: + + CUTLASS_DEVICE + void warp_mma_planar_complex( + Operator & warp_mma, + FragmentC &accum, + WarpFragmentA const & real_A, + WarpFragmentA const & imag_A, + WarpFragmentB const & real_B, + WarpFragmentB const & imag_B) { + + cutlass::negate> neg_op_B; + + WarpFragmentB neg_real_B = neg_op_B(real_B); + WarpFragmentB neg_imag_B = neg_op_B(imag_B); + + warp_mma(accum.real, real_A, real_B, accum.real); + + if (kTransformB == ComplexTransform::kNone) { + warp_mma(accum.imag, real_A, imag_B, accum.imag); + } + else { + warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone) { + warp_mma(accum.imag, imag_A, real_B, accum.imag); + } + else { + warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { + warp_mma(accum.real, imag_A, imag_B, accum.real); + } + else { + warp_mma(accum.real, imag_A, neg_imag_B, accum.real); + } + } + +public: + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A_real, + ///< iterator over A operand in global memory + IteratorA iterator_A_imag, + ///< iterator over B operand in global memory + IteratorB iterator_B_real, + ///< iterator over B operand in global memory + IteratorB iterator_B_imag, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A_real; + FragmentA tb_frag_A_imag; + + FragmentB tb_frag_B_real; + FragmentB tb_frag_B_imag; + + tb_frag_A_real.clear(); + tb_frag_A_imag.clear(); + + tb_frag_B_real.clear(); + tb_frag_B_imag.clear(); + + // The last kblock is loaded in the prolog + iterator_A_real.load(tb_frag_A_real); + iterator_A_imag.load(tb_frag_A_imag); + + iterator_B_real.load(tb_frag_B_real); + iterator_B_imag.load(tb_frag_B_imag); + + ++iterator_A_real; + ++iterator_A_imag; + + ++iterator_B_real; + ++iterator_B_imag; + + this->smem_iterator_A_.store(tb_frag_A_real); + this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); + + this->smem_iterator_B_.store(tb_frag_B_real); + this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_real_A[2]; + WarpFragmentA warp_frag_imag_A[2]; + + WarpFragmentB warp_frag_real_B[2]; + WarpFragmentB warp_frag_imag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); + + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + if (gemm_k_iterations <= 1) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(tb_frag_A_real); + this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); + + this->smem_iterator_B_.store(tb_frag_B_real); + this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); + + __syncthreads(); + + ++this->smem_iterator_B_; + ++this->smem_iterator_A_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A_real.load(tb_frag_A_real); + iterator_A_imag.load(tb_frag_A_imag); + + iterator_B_real.load(tb_frag_B_real); + iterator_B_imag.load(tb_frag_B_imag); + + ++iterator_A_real; + ++iterator_A_imag; + ++iterator_B_real; + ++iterator_B_imag; + + // Avoid reading out of bounds if this was the last loop iteration + if (gemm_k_iterations <= 2) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + } + + warp_mma_planar_complex( + warp_mma, + accum, + warp_frag_real_A[warp_mma_k % 2], + warp_frag_imag_A[warp_mma_k % 2], + warp_frag_real_B[warp_mma_k % 2], + warp_frag_imag_B[warp_mma_k % 2]); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_singlestage.h b/include/cutlass/gemm/threadblock/mma_singlestage.h index 99ec9d64c9..32d4d4ee60 100644 --- a/include/cutlass/gemm/threadblock/mma_singlestage.h +++ b/include/cutlass/gemm/threadblock/mma_singlestage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -168,7 +168,6 @@ class MmaSingleStage : public MmaBase { // Perform accumulation in the 'd' output operand accum = src_accum; - FragmentA tb_frag_A; FragmentB tb_frag_B; @@ -183,8 +182,9 @@ class MmaSingleStage : public MmaBase { ++iterator_B; // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + WarpFragmentA warp_frag_A; + WarpFragmentB warp_frag_B; + Operator warp_mma; // Avoid reading out of bounds @@ -193,7 +193,6 @@ class MmaSingleStage : public MmaBase { iterator_B.clear_mask(); } - // // Mainloop // @@ -203,7 +202,6 @@ class MmaSingleStage : public MmaBase { this->smem_iterator_A_.store(tb_frag_A); this->smem_iterator_B_.store(tb_frag_B); - __syncthreads(); // @@ -216,16 +214,16 @@ class MmaSingleStage : public MmaBase { // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group // as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k) % 2]); + this->warp_tile_iterator_A_.load(warp_frag_A); + this->warp_tile_iterator_B_.load(warp_frag_B); ++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_B_; - warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + warp_mma(accum, warp_frag_A, warp_frag_B, accum); } // Add negative offsets to return smem load iterators to the 'start' of the shared memory diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index cd386b4710..03d71d3197 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,7 +30,8 @@ #pragma once #include "cutlass/cutlass.h" - +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" #include "cutlass/gemm/gemm.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,12 +99,13 @@ int RematerializeBlockDimZ() { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Threadblock swizzling function for GEMMs +template struct GemmIdentityThreadblockSwizzle { CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle() { } - int const kTile = 1; + int const kTile = N; /// Returns the shape of the problem in units of logical tiles CUTLASS_HOST_DEVICE @@ -186,8 +188,8 @@ struct GemmBatchedIdentityThreadblockSwizzle { CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape( GemmCoord problem_size, - int batch_count, - GemmCoord tile_size) const { + GemmCoord tile_size, + int batch_count) const { return GemmCoord( (problem_size.m() + tile_size.m() - 1) / tile_size.m(), @@ -207,7 +209,7 @@ struct GemmBatchedIdentityThreadblockSwizzle { return GemmCoord{ RematerializeBlockIdxX(), RematerializeBlockIdxY(), - 0 + RematerializeBlockIdxZ() }; } @@ -221,8 +223,11 @@ struct GemmBatchedIdentityThreadblockSwizzle { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Threadblock swizzling function for split-K GEMMs +template struct GemmSplitKIdentityThreadblockSwizzle { + int const kTile = N; + /// Returns the shape of the problem in units of logical tiles CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape( @@ -239,16 +244,20 @@ struct GemmSplitKIdentityThreadblockSwizzle { /// Computes CUDA grid dimensions given a size in units of logical tiles CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const { - return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); + return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE GemmCoord get_tile_offset() const { + + int block_idx_x = RematerializeBlockIdxX(); + int block_idx_y = RematerializeBlockIdxY(); + return GemmCoord{ - RematerializeBlockIdxX(), - RematerializeBlockIdxY(), + (block_idx_x / kTile), + (block_idx_y * kTile) + (block_idx_x % kTile), RematerializeBlockIdxZ() }; } diff --git a/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h new file mode 100644 index 0000000000..3c6772aff7 --- /dev/null +++ b/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h @@ -0,0 +1,401 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_complex_tensor_op.h" +#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Complex transform on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transform on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator_ = arch::OpMultiplyAddComplex> +struct DefaultMmaComplexTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex*complex case +// 4 real-valued mma operations +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Real-valued underlying type of complex-valued A operand + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Real-valued underlying type of complex-valued B operand + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Real-valued underlying type of complex-valued C operand + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + InstructionShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddComplex> { + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + RealElementA, + cutlass::layout::RowMajor, + RealElementB, + cutlass::layout::ColumnMajor, + RealElementC, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex*complex case using GaussianComplex operation +// 3 real-valued mma operations +// A = (ar + j ai), B = (br +j bi), D = AB +// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) +// D = dr + j di = (P1 - P3) + j (P1 + P2) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Real-valued underlying type of complex-valued A operand + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Real-valued underlying type of complex-valued B operand + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Real-valued underlying type of complex-valued C operand + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + InstructionShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddGaussianComplex> { + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + RealElementA, + cutlass::layout::RowMajor, + RealElementB, + cutlass::layout::ColumnMajor, + RealElementC, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB>; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization - input and output types are complex*complex +// Use TF32 tensor operation internally +// 4 real-valued MMA.1688.F32.TF32 operations on TF32 +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + InstructionShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddComplex> { + + // Complex floating point tensor operation use MMA.1688.F32.TF32 mma instruction + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + tfloat32_t, + cutlass::layout::RowMajor, + tfloat32_t, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization - input and output types are complex*complex +// Use BF16 tensor operation internally +// 4 real-valued MMA.1688.F32.BF16 operations on BF16 +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + InstructionShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddFastBF16> { + + // Complex floating point tensor operation use MMA.1688.F32.BF16 mma instruction + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + bfloat16_t, + cutlass::layout::RowMajor, + bfloat16_t, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization - input and output types are complex*complex +// Use F16 tensor operation internally +// 4 real-valued MMA.1688.F32.F16 operations on F16 +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + InstructionShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddFastF16> { + + // Complex floating point tensor operation use MMA.1688.F32.F16 mma instruction + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + half_t, + cutlass::layout::RowMajor, + half_t, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB>; +}; + +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op.h b/include/cutlass/gemm/warp/default_mma_tensor_op.h index 5bf1b74a56..ea9ab5c931 100644 --- a/include/cutlass/gemm/warp/default_mma_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -60,10 +60,7 @@ template < int PartitionsK = 1, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Number of partitions along N dimension per warp - int PartitionsN = 1 -> + bool AccumulatorsInRowMajor = false> struct DefaultMmaTensorOp; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -92,9 +89,7 @@ template < int PartitionsK, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Number of partitions along N dimension per warp - int PartitionsN> + bool AccumulatorsInRowMajor> struct DefaultMmaTensorOp { using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma; + Policy, PartitionsK, AccumulatorsInRowMajor>; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -114,3 +109,9 @@ struct DefaultMmaTensorOp { } // namespace warp } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "default_mma_tensor_op_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h new file mode 100644 index 0000000000..06d3afa59f --- /dev/null +++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization - inputs and output types are float - uses BF16 internally +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp< + WarpShape_, + GemmShape<16, 8, 8>, + float, LayoutA, + float, LayoutB, + float, LayoutC, + arch::OpMultiplyAddFastBF16, + PartitionsK, AccumulatorsInRowMajor> { + + // Uses BF16 internally + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + GemmShape<16, 8, 8>, + 32, + bfloat16_t, cutlass::layout::RowMajor, + bfloat16_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOp< + WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization - inputs and output types are float - uses F16 internally +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp< + WarpShape_, + GemmShape<16, 8, 8>, + float, LayoutA, + float, LayoutB, + float, LayoutC, + arch::OpMultiplyAddFastF16, + PartitionsK, AccumulatorsInRowMajor> { + + // Uses F16 internally + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + GemmShape<16, 8, 8>, + 32, + half_t, cutlass::layout::RowMajor, + half_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOp< + WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization - inputs and output types are float - uses TF32 internally +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of target matrix multiply instruction (concept: GemmShape) + typename InstructionShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp< + WarpShape_, + InstructionShape_, + float, LayoutA, + float, LayoutB, + float, LayoutC, + arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { + + // Uses TF32 internally + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + tfloat32_t, cutlass::layout::RowMajor, + tfloat32_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOp< + WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h b/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h index 11964944f7..582fb472e1 100644 --- a/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -61,9 +61,7 @@ template < /// Operator describing the tensor operation typename Operator_ = arch::OpMultiplyAdd, /// Number of partitions along K dimension - int PartitionsK = 1, - /// Number of partitions along N dimension per warp - int PartitionsN = 1 + int PartitionsK = 1 > struct DefaultMmaTensorOpWmma; @@ -90,9 +88,7 @@ template < /// Operator describing the tensor operation typename Operator_, /// Number of partitions along K dimension - int PartitionsK, - /// Number of partitions along N dimension per warp - int PartitionsN> + int PartitionsK> struct DefaultMmaTensorOpWmma { using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Wmma< @@ -116,8 +112,7 @@ struct DefaultMmaTensorOpWmma { ElementC, LayoutC, Policy, - PartitionsK, - PartitionsN>; + PartitionsK>; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -127,4 +122,3 @@ struct DefaultMmaTensorOpWmma { } // namespace cutlass #endif - diff --git a/include/cutlass/gemm/warp/mma.h b/include/cutlass/gemm/warp/mma.h index 5fb96d9f27..16c736e2b7 100644 --- a/include/cutlass/gemm/warp/mma.h +++ b/include/cutlass/gemm/warp/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index 073b131c5a..2dc72fd333 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -35,9 +35,12 @@ #include "cutlass/complex.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" +#include "cutlass/functional.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" @@ -45,6 +48,9 @@ #include "cutlass/gemm/warp/mma_tensor_op.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -53,6 +59,171 @@ namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template < + /// Data type of real & imag members of complex numbers in the SourceFragment + typename RealElement, + /// Destination fragment required by the mma operation + typename DestinationFragment, + /// Source fragment holding complex elements + typename SourceFragment, + /// Number of mma operations performed + typename MmaIterations, + /// Shape of operand elements + typename MmaOperandShape, + /// Complex transform on A operand + ComplexTransform Transform_, + /// Operand A or Operand B + Operand Operand_, + /// Floating-point rounding style + FloatRoundStyle Round_> +struct UnpackComplexConvertAndPackForMma; + +// Partial specialization for OperandA and Congruous smem layout +template < + typename RealElement, + typename DestinationFragment, + typename SourceFragment, + typename MmaIterations, + typename MmaOperandShape, + ComplexTransform Transform_, + FloatRoundStyle Round_> +struct UnpackComplexConvertAndPackForMma < + RealElement, + DestinationFragment, + SourceFragment, + MmaIterations, + MmaOperandShape, + Transform_, + Operand::kA, + Round_> { + + // + // Type definitions + // + static Operand const kOperand = Operand::kA; + static ComplexTransform const kTransform = Transform_; + static FloatRoundStyle const kRound = Round_; + + // Data type of elements in the destination fragment + using MmaElement = typename DestinationFragment::Element; + + // Numeric convertor MmaElement <= RealElement + using Converter = NumericConverter; + + // Operand layout parameters + using SourceFragmentLayout = layout::ColumnMajor; + static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; + + /// Ctor + CUTLASS_DEVICE + UnpackComplexConvertAndPackForMma() {} + + CUTLASS_DEVICE + void operator()(DestinationFragment *dest, SourceFragment const &source) { + + Converter convert_op; + SourceFragmentLayout layout(kLdm); + + CUTLASS_PRAGMA_UNROLL + for(int i=0; i and apply rounding on real and imag parts + MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); + MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); + + // Unpack rounded complex and pack into DestinationFragment for mma operation + dest[i][pos] = a; + dest[i+MmaIterations::kRow][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); + + } + } + } + } +}; + +// Partial specialization for OperandB and Congruous smem layout +template < + typename RealElement, + typename DestinationFragment, + typename SourceFragment, + typename MmaIterations, + typename MmaOperandShape, + ComplexTransform Transform_, + FloatRoundStyle Round_> +struct UnpackComplexConvertAndPackForMma < + RealElement, + DestinationFragment, + SourceFragment, + MmaIterations, + MmaOperandShape, + Transform_, + Operand::kB, + Round_> { + + // + // Type definitions + // + static Operand const kOperand = Operand::kB; + static ComplexTransform const kTransform = Transform_; + static FloatRoundStyle const kRound = Round_; + + // Data type of elements in the destination fragment + using MmaElement = typename DestinationFragment::Element; + + // Numeric convertor MmaElement <= RealElement + using Converter = NumericConverter; + + // Operand layout parameters + using SourceFragmentLayout = layout::RowMajor; + static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; + + /// Ctor + CUTLASS_DEVICE + UnpackComplexConvertAndPackForMma() {} + + CUTLASS_HOST_DEVICE + void operator()(DestinationFragment *dest, SourceFragment const &source) { + + Converter convert_op; + SourceFragmentLayout layout(kLdm); + + CUTLASS_PRAGMA_UNROLL + for(int i=0; i apply rounding on real and imag parts + MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); + MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); + + // Unpack rounded complex and pack into DestinationFragment for mma operation + dest[i][pos] = a; + dest[i+MmaIterations::kColumn][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); + } + } + } + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, @@ -140,9 +311,12 @@ class MmaComplexTensorOp< /// Layout of accumulator matrix C using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) using Policy = Policy_; + /// Shape of underlying instruction + using InstructionShape = typename Policy::Operator::Shape; + /// Complex transform on A operand static ComplexTransform const kTransformA = TransformA; @@ -172,6 +346,9 @@ class MmaComplexTensorOp< /// Storage for A tile using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< MatrixShape, @@ -187,6 +364,8 @@ class MmaComplexTensorOp< /// Storage for B tile using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; static_assert( !(Shape::kM % Policy::Operator::Shape::kM) && @@ -242,7 +421,8 @@ class MmaComplexTensorOp< FragmentC &D, FragmentA const &A, FragmentB const &B, - FragmentC const &C) const { + FragmentC const &C + ) const { // Alias types for underlying real-valued matrix multiply operator using MmaOperandA = typename Policy::Operator::FragmentA; @@ -254,7 +434,7 @@ class MmaComplexTensorOp< "We can geneneralize later."); static_assert(MmaOperandB::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the A operand." + "This implementation only supports math instructions in which exactly one element is needed for the B operand." "We can geneneralize later."); D = C; @@ -310,7 +490,7 @@ class MmaComplexTensorOp< operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag()); operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); - // Complex-valued accumulator part + // Real-valued accumulator part MmaOperandC *accum = reinterpret_cast(&D) + (m + n * MmaIterations::kRow); @@ -328,7 +508,7 @@ class MmaComplexTensorOp< operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag()); operand_B[0] = B[n].real(); - // Real-valued accumulator part + // Complex-valued accumulator part MmaOperandC *accum = reinterpret_cast(&D) + (m + n * MmaIterations::kRow) + MmaIterations::kCount; @@ -336,6 +516,318 @@ class MmaComplexTensorOp< } } } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + //TODO: Implement this + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex*complex+complex => complex: +// Operands data type: complex +// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) +// Math instruction: MMA.1688.F32.TF32 +// Output data type: complex +// +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB, + /// Used for partial specialization + typename Enable +> +class MmaComplexTensorOp< + Shape_, + complex, + LayoutA_, + complex, + LayoutB_, + complex, + LayoutC_, + Policy_, + TransformA, + TransformB, + Enable> { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of members of complex multiplicand A + using RealElementA = float; + + /// Data type of multiplicand A + using ElementA = complex; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of members of complex multiplicand B + using RealElementB = float; + + /// Data type of multiplicand B + using ElementB = complex; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of members of complex accumulator matrix C + using RealElementC = float; + + /// Data type of accumulator matrix C + using ElementC = complex; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Shape of underlying instruction + using InstructionShape = typename Policy::Operator::Shape; + + /// Underlying arch tag + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + 32, + 1 + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kColumn, + 32, + 1 + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + static_assert( + !(Shape::kM % Policy::Operator::Shape::kM) && + !(Shape::kN % Policy::Operator::Shape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + /// Number of complex products operations performed (one complex product needs four mma instructions) + using MmaIterations = MatrixShape< + Shape::kM / Policy::Operator::Shape::kM, + Shape::kN / Policy::Operator::Shape::kN + >; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, + ElementC, + LayoutC, + typename Policy::Operator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this + /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued + /// parts are stored consecutively followed by all imaginary parts. This matches the structure + /// of Tensor Cores which are always real-valued matrix multiplies. + using FragmentC = typename IteratorC::Fragment; + +private: + + // + // Data members + // + + /// Underlying real-valued matrix multiply operator (concept: arch::Mma) + typename Policy::Operator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaComplexTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + // Alias types for underlying real-valued matrix multiply operator + using InstMmaOperandA = typename Policy::Operator::FragmentA; + using InstMmaOperandB = typename Policy::Operator::FragmentB; + using MmaOperandC = typename Policy::Operator::FragmentC; + + static_assert(platform::is_same, typename Policy::Operator::Shape>::value, + "This implementation only supports MMA.1688 math instructions."); + + static_assert(InstMmaOperandA::kElements == 4, + "This implementation only supports math instructions in which exactly four element is needed for the A operand." + "We can geneneralize later."); + + static_assert(InstMmaOperandB::kElements == 2, + "This implementation only supports math instructions in which exactly two element is needed for the B operand." + "We can geneneralize later."); + + // Instruction Operands A & B holding real part followed by imaginary part for mma operations + InstMmaOperandA const *operand_A = reinterpret_cast(&A); + InstMmaOperandB const *operand_B = reinterpret_cast(&B); + + // + // Accumulate in place + // + D = C; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + // mma(accum.real(), a.real(), b.real(), accum.real()); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Real-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_A[m], operand_B[n], *accum); + } + + // mma(accum.imag(), a.real(), b.imag(), accum.imag()); + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Complex-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); + } + + // mma(accum.real(), a.imag(), -b.imag(), accum.real()) + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // negate OperandB to accumulate -(a.imag()*b.imag()) + // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements + negate negate_op; + + // Real-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); + } + + // mma(accum.imag(), a.imag(), b.real(), accum.imag()) + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Complex-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); + } + } + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + // Alias types for underlying real-valued matrix multiply operator + using InstMmaOperandA = typename Policy::Operator::FragmentA; + using InstMmaOperandB = typename Policy::Operator::FragmentB; + + // + // Define conversions from source type to instruction operands' type + // + + FloatRoundStyle const kRoundA = FloatRoundStyle::round_half_ulp_trunc_dntz; + FloatRoundStyle const kRoundB = FloatRoundStyle::round_half_ulp_trunc_dntz; + + detail::UnpackComplexConvertAndPackForMma < + RealElementA, + InstMmaOperandA, + FragmentA, + MmaIterations, + MatrixShape<2, 2>, + kTransformA, + Operand::kA, + kRoundA> convert_A; + + detail::UnpackComplexConvertAndPackForMma < + RealElementB, + InstMmaOperandB, + FragmentB, + MmaIterations, + MatrixShape<2, 1>, + kTransformB, + Operand::kB, + kRoundB> convert_B; + + // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element + convert_A(reinterpret_cast(&dst_A), A); + convert_B(reinterpret_cast(&dst_B), B); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h new file mode 100644 index 0000000000..b95af0df15 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h @@ -0,0 +1,2448 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 128b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::TensorOpMultiplicandCongruous128b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 8) && !(Shape::kStrided % 4), "Divisibility."); + + static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicandCongruous128b; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 1; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<8, 4>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / Delta::kContiguous, + InstructionShape::kStrided / Delta::kStrided + >; + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { + + int quad_pair = lane_id / 8; + int quad = lane_id / 4; + int lane = lane_id % 4; + + int row = (quad & 1) * 4 + (lane ^ quad_pair); + + byte_offset_ = (row + quad_pair * stride_) * sizeof(AccessType); + + pointer_= reinterpret_cast(ref.data()); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + pointer_ += offset; + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + int offset = + (tile_offset.contiguous() * Shape::kContiguous) + + (tile_offset.strided() * InstructionShape::kStrided * stride_); + + add_pointer_offset(offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + pointer_ += stride_ * InstructionShape::kStrided; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + int access_idx = c + s * Policy::Iterations::kContiguous; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c + + Policy::Delta::kStrided * s * stride_; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = + tile_offset.contiguous() * Shape::kContiguous + + tile_offset.strided() * InstructionShape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCongruous128b, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.strided(), tile_offset.contiguous()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCongruous128b, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// +/// Partial specialization for complex +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of underlying field of reals. + typename RealElement, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_> +class MmaTensorOpAccumulatorTileIterator< + Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { + public: + + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand::kC; + + /// Element type + using Element = complex; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + using OpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + static_assert(platform::is_same::value, + "Layouts must be defined for logical MatrixCoord coordinate space."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape; + }; + +private: + + // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire + // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements + // of that row. The accumulators within one row are assumed to be consecutive. + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators + /// are stored in a planar complex arrangement with the real parts as entirely contiguous + /// followed by the imaginary parts. + using Fragment = Array; + + static int const kRealIndex = 0; + static int const kImaginaryIndex = Shape::kCount / kThreads; + +private: + + /// Reference to output tensor + TensorRef ref_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator( + TensorRef const &ref, + int lane_id + ): + ref_(ref) { + + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + + MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); + + ref_.add_coord_offset(lane_offset); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator & operator++() { + // deliberate no-op + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator & operator--() { + // deliberate no-op + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + Fragment &frag, ///< fragment to load from the tensor + Index pointer_offset) const { ///< loads a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + + Element z = offset_ref.at({accum_m, accum_n}); + + frag[mma_accum_start + row * kElementsPerAccess + col + kRealIndex] = z.real(); + frag[mma_accum_start + row * kElementsPerAccess + col + kImaginaryIndex] = z.imag(); + } + } + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + Fragment &frag, ///< fragment to load from the tensor + Index byte_offset) const { ///< loads a tile with a linear offset + + load_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles + + load(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles + Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset + + load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } + + /// Stores a fragment to memory + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) const { + store_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index pointer_offset) const { ///< store a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + int idx = mma_accum_start + row * kElementsPerAccess + col; + + Element z(frag[kRealIndex + idx], frag[kImaginaryIndex + idx]); + + offset_ref.at({accum_m, accum_n}) = z; + } + } + } + } + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void store_with_byte_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index byte_offset) const { ///< store a tile with a linear offset + + store_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Stores a fragment to memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + Fragment &frag, ///< fragment to store to the tensor + TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles + + store(frag, tile_offset, 0); + } + + /// Stores a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + /// fragment to store to the tensor + Fragment const &frag, + /// stores a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// stores a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 128b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::TensorOpMultiplicandCrosswise128x4, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 8), "Divisibility."); + + static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicandCrosswise128x4; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 1; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<4, 8>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + InstructionShape::kContiguous / Delta::kContiguous, + Shape::kStrided / Delta::kStrided + >; + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { + + int quad = lane_id / 4; + int liq = lane_id % 4; + + int c = liq + (quad & 1) * 4; + int s = (quad / 2); + + byte_offset_ = (c + s * stride_) * sizeof(AccessType); + + pointer_= reinterpret_cast(ref.data()); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + pointer_ += offset; + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + // Compute the offset in units of elements. Note, the external coordinate system is + // approximately transposed with respect to the tiled internal structure + int offset = + (tile_offset.contiguous() * InstructionShape::kContiguous) * stride_ + + (tile_offset.strided() * Shape::kStrided); + + add_pointer_offset(offset); + + byte_offset_ ^= (tile_offset.contiguous() & 1) * 4 * sizeof(AccessType); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + pointer_ += stride_ * InstructionShape::kContiguous; + + byte_offset_ ^= 4 * sizeof(AccessType); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + int access_idx = s + c * Policy::Iterations::kStrided; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c * stride_ + + Policy::Delta::kStrided * s; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = + tile_offset.contiguous() * InstructionShape::kContiguous * stride_ + + tile_offset.strided() * Shape::kStrided; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCrosswise128x4, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.strided(), tile_offset.contiguous()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCrosswise128x4, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Congruous shared memory layout +// Warp-level iterators for complex*complex + complex => complex +// The underlying iterators are similar to that for MMA f64*f64 + f64 = f64 +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 64b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, cutlass::complex, + cutlass::layout::TensorOpMultiplicandCongruous64b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 8), "Divisibility."); + + /// Element type + using Element = cutlass::complex; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 2; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<8, 4>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, + InstructionShape::kStrided / Delta::kStrided + >; + + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + + /// Internal counter used to jump to next K partition + int k_group_idx_; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), + k_group_idx_(0) { + + int access_strided = lane_id / Policy::Delta::kContiguous; + int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; + + pointer_= reinterpret_cast(ref.data()) + + access_contiguous + access_strided * stride_; + + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + byte_offset_ += offset * sizeof(Element); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + int offset = + (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + + tile_offset.contiguous() * Shape::kContiguous; + + add_pointer_offset(offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + add_tile_offset({0, 1}); + + return *this; + } + + /// Advances the iterator along the opposite of the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + add_tile_offset({0, -1}); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + int access_idx = c + s * Policy::Iterations::kContiguous; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c + + Policy::Delta::kStrided * s * stride_; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + + Index pointer_offset = + tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + + tile_offset.strided() * InstructionShape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Crosswise shared memory layout +// Warp-level iterators for complex*complex + complex => complex +// The underlying iterators are similar to that for f64*f64 + f64 = f64 +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 64b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, complex, + cutlass::layout::TensorOpMultiplicand64bCrosswise, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); + + static_assert(sizeof_bits>::value == 64, "This is specialized for 64b accesses."); + + /// Element type + using Element = complex; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 2; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<4, 16>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + InstructionShape::kContiguous / Delta::kContiguous, + Shape::kStrided / Delta::kStrided + >; + + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + + /// Internal counter for tracking K-group + Index k_group_idx_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), + k_group_idx_(0) { + + int access_strided = lane_id / 8; + int access_contiguous = (lane_id % 8); + + byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); + + pointer_= reinterpret_cast(ref.data()); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + pointer_ += offset / kElementsPerAccess; + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * + stride_ * kElementsPerAccess + + tile_offset.strided() * Shape::kStrided; + + add_pointer_offset(offset); + + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + pointer_ += stride_ * InstructionShape::kContiguous; + + // xor ptr + byte_offset_ ^= 0x40; + + ++k_group_idx_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + int access_idx = c * Policy::Iterations::kStrided + s; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c * stride_ + + Policy::Delta::kStrided * s / kElementsPerAccess; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + + Element *exchange_ptr = reinterpret_cast(&frag); + + // exchange on 64b granularity only for fragments held in k=8/2 to k=8 + CUTLASS_PRAGMA_UNROLL + for (int i = Fragment::kElements/2; i < Fragment::kElements; i += 2) { + Element tmp = exchange_ptr[i]; + exchange_ptr[i] = exchange_ptr[i + 1]; + exchange_ptr[i + 1] = tmp; + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = tile_offset.contiguous() * + InstructionShape::kContiguous / + Layout::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + k_group_idx_ = k_group; + } +}; + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h new file mode 100644 index 0000000000..bf3d98dfbe --- /dev/null +++ b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Complex transform on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transform on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool +> +class MmaGaussianComplexTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB, + /// Used for partial specialization + typename Enable +> +class MmaGaussianComplexTensorOp< + Shape_, + complex, + LayoutA_, + complex, + LayoutB_, + complex, + LayoutC_, + Policy_, + TransformA, + TransformB, + Enable> { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = complex; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = complex; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = complex; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Shape of underlying instruction + using InstructionShape = typename Policy::Operator::Shape; + + /// Underlying architecture tag + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + 32, + 1 + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kColumn, + 32, + 1 + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + + static_assert( + !(Shape::kM % Policy::Operator::Shape::kM) && + !(Shape::kN % Policy::Operator::Shape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Shape::kM / Policy::Operator::Shape::kM, + Shape::kN / Policy::Operator::Shape::kN + >; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< + MatrixShape, + ElementC, + LayoutC, + typename Policy::Operator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this + /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is + /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively + /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. + using FragmentC = typename IteratorC::Fragment; + + static_assert( + FragmentC::kElements == 3 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements, + "Unexpected gaussian complex fragment length."); + +private: + + // + // Data members + // + + /// Underlying real-valued matrix multiply operator (concept: arch::Mma) + typename Policy::Operator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaGaussianComplexTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C + ) const { + + // Alias types for underlying real-valued matrix multiply operator + using MmaOperandA = typename Policy::Operator::FragmentA; + using MmaOperandB = typename Policy::Operator::FragmentB; + using MmaOperandC = typename Policy::Operator::FragmentC; + + static_assert(MmaOperandA::kElements == 1, + "This implementation only supports math instructions in which exactly one element is needed for the A operand." + "We can geneneralize later."); + + static_assert(MmaOperandB::kElements == 1, + "This implementation only supports math instructions in which exactly one element is needed for the B operand." + "We can geneneralize later."); + + D = C; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Asum; + MmaOperandB operand_Br; + + operand_Asum[0] = A[m].real() + ((kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag()); + operand_Br[0] = B[n].real(); + + // accumulator part1 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_Asum, operand_Br, *accum); + } + + // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Ar; + MmaOperandB operand_Bdiff; + + operand_Ar[0] = -A[m].real(); + operand_Bdiff[0] = B[n].real() - ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); + + // accumulator part2 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_Ar, operand_Bdiff, *accum); + } + + // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Ai; + MmaOperandB operand_Bsum; + + operand_Ai[0] = (kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag(); + operand_Bsum[0] = B[n].real() + ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); + + // accumulator part3 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; + + mma(*accum, operand_Ai, operand_Bsum, *accum); + } + } + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + //TODO: Implement this + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// TODO - partial specializations of real*complex and complex*real + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h new file mode 100644 index 0000000000..8d9417b0fb --- /dev/null +++ b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_> +class MmaTensorOpGaussianComplexAccumulatorTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// +/// Partial specialization for complex +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of underlying field of reals. + typename RealElement, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_> +class MmaTensorOpGaussianComplexAccumulatorTileIterator< + Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { + public: + + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand::kC; + + /// Element type + using Element = complex; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + using OpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + static_assert(platform::is_same::value, + "Layouts must be defined for logical MatrixCoord coordinate space."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape; + }; + +private: + + // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire + // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements + // of that row. The accumulators within one row are assumed to be consecutive. + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators + /// are stored in a gaussian complex arrangement with parts 1, 2, and 3 as entirely contiguous + /// arranged as [part1, part2, part3] + using Fragment = Array; + + static int const kPart1Index = (Shape::kCount / kThreads) * 0; + static int const kPart2Index = (Shape::kCount / kThreads) * 1; + static int const kPart3Index = (Shape::kCount / kThreads) * 2; + +private: + + /// Reference to output tensor + TensorRef ref_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator( + TensorRef const &ref, + int lane_id + ): + ref_(ref) { + + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + + MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); + + ref_.add_coord_offset(lane_offset); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator & operator++() { + // deliberate no-op + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator & operator--() { + // deliberate no-op + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpGaussianComplexAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + Fragment &frag, ///< fragment to load from the tensor + Index pointer_offset) const { ///< loads a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + + Element z = offset_ref.at({accum_m, accum_n}); + + frag[mma_accum_start + row * kElementsPerAccess + col + kPart1Index] = z.real() + z.imag(); + frag[mma_accum_start + row * kElementsPerAccess + col + kPart2Index] = -z.real(); + frag[mma_accum_start + row * kElementsPerAccess + col + kPart3Index] = z.imag(); + } + } + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + Fragment &frag, ///< fragment to load from the tensor + Index byte_offset) const { ///< loads a tile with a linear offset + + load_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles + + load(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles + Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset + + load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } + + /// Stores a fragment to memory + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) const { + store_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index pointer_offset) const { ///< store a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + int idx = mma_accum_start + row * kElementsPerAccess + col; + + Element z(frag[kPart1Index + idx] - frag[kPart3Index + idx], + frag[kPart1Index + idx] + frag[kPart2Index + idx]); + + offset_ref.at({accum_m, accum_n}) = z; + } + } + } + } + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void store_with_byte_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index byte_offset) const { ///< store a tile with a linear offset + + store_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Stores a fragment to memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + Fragment &frag, ///< fragment to store to the tensor + TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles + + store(frag, tile_offset, 0); + } + + /// Stores a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + /// fragment to store to the tensor + Fragment const &frag, + /// stores a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// stores a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_planar_complex.h b/include/cutlass/gemm/warp/mma_planar_complex.h new file mode 100644 index 0000000000..c579044065 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_planar_complex.h @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/array_planar_complex.h" +#include "cutlass/gemm/warp/tile_iterator_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Underlying real-valued warp-level matrix multiply + typename Operator_, + /// Transformation applied to A operand (typically folded into math instruction) + ComplexTransform TransformA = ComplexTransform::kNone, + /// Transformation applied to B operand (typically folded into math instruction) + ComplexTransform TransformB = ComplexTransform::kNone +> +class MmaPlanarComplex { +public: + + /// Underlying real-valued warp-level matrix multiply + using Operator = Operator_; + + /// Shape of warp-level matrix multipy + using Shape = typename Operator::Shape; + + /// Transformation applied to A operand (typically folded into math instruction) + static ComplexTransform const kTransformA = TransformA; + + /// Transformation applied to B operand (typically folded into math instruction) + static ComplexTransform const kTransformB = TransformB; + + /// Fragment of elements + using FragmentA = ArrayPlanarComplex; + + /// Iterator into planar complex + using IteratorA = TileIteratorPlanarComplex; + + /// Layout in memory of the A operand + using LayoutA = typename Operator::LayoutA; + + using FragmentB = ArrayPlanarComplex; + + /// Iterator into planar complex + using IteratorB = TileIteratorPlanarComplex; + + /// Layout in memory of the B operand + using LayoutB = typename Operator::LayoutB; + + /// Tile iterator for accumulator + using IteratorC = TileIteratorPlanarComplex; + + /// Accumulator fragment + using FragmentC = ArrayPlanarComplex; + + /// Layout of accumulator fragment in memory + using LayoutC = typename Operator::LayoutC; + +private: + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Operator::Shape::kM / Operator::Policy::Operator::Shape::kM, + Operator::Shape::kN / Operator::Policy::Operator::Shape::kN + >; + +public: + /// Ctor + CUTLASS_DEVICE + MmaPlanarComplex() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A_in, + FragmentB const &B_in, + FragmentC const &C) const { + + D.real = C.real; + D.imag = C.imag; + + // + // Transform fragments based on conjugate operations. + // + + negate neg_A; + + FragmentA frag_A; + frag_A.real = A_in.real; + + if (kTransformA == ComplexTransform::kConjugate) { + frag_A.imag = neg_A(frag_A.imag); + } + else { + frag_A.imag = frag_A.imag; + } + + FragmentB frag_B; + frag_B.real = B_in.real; + + if (kTransformB == ComplexTransform::kConjugate) { + negate neg; + frag_B.imag = neg(frag_B.imag); + } + else { + frag_B.imag = frag_B.imag; + } + + // + // Accumulated real-valued matrix multiplies + // + + Operator real_mma; + + // D.i += A.i * B.r + real_mma(D.imag, frag_A.imag, frag_B.real, D.imag); + + // D.r += A.r * B.r + real_mma(D.real, frag_A.real, frag_B.real, D.real); + + // D.i += A.r * B.i + real_mma(D.imag, frag_A.real, frag_B.imag, D.imag); + + // D.r += -A.i * B.i + frag_A.imag = neg_A(frag_A.imag); + real_mma(D.real, frag_A.imag, frag_B.imag, D.real); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_simt.h b/include/cutlass/gemm/warp/mma_simt.h index eecb6aaeed..1bf23c7432 100644 --- a/include/cutlass/gemm/warp/mma_simt.h +++ b/include/cutlass/gemm/warp/mma_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -100,6 +100,16 @@ class MmaSimt { /// Indicates class of matrix operator using OperatorClass = arch::OpClassSimt; + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Layout of threads using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value, layout::ColumnMajor, typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value, @@ -137,6 +147,9 @@ class MmaSimt { dp4a_type >; + /// Shape of the underlying instruction + using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>; + public: /// Iterates over the A operand in memory @@ -153,6 +166,9 @@ class MmaSimt { /// Storage for A tile using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + /// Iterates over the B operand in memory using IteratorB = MmaSimtTileIterator< MatrixShape, @@ -167,6 +183,9 @@ class MmaSimt { /// Storage for B tile using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + /// Iterates over the C operand in memory using IteratorC = MmaSimtTileIterator< MatrixShape, @@ -201,6 +220,15 @@ class MmaSimt { mma(d, a, b, c); } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + //TODO: Implement this + dst_A = A; + dst_B = B; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_simt_policy.h b/include/cutlass/gemm/warp/mma_simt_policy.h index 782474337f..6abd0bf6a8 100644 --- a/include/cutlass/gemm/warp/mma_simt_policy.h +++ b/include/cutlass/gemm/warp/mma_simt_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h index 1d47e8f1a6..ed1e598702 100644 --- a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index d3e0fc0f21..3eff7b9054 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,18 +31,24 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" #include "cutlass/gemm/warp/mma_tensor_op_policy.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -51,6 +57,81 @@ namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + return converter(source); + } +}; + +template +struct ConvertAndPack { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + return source; + } +}; + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + Array tmp; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); + tmp[i] = source[idx]; + } + + return converter(tmp); + } +}; + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + Array tmp; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); + tmp[i] = source[idx]; + } + + return converter(tmp); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> @@ -74,8 +155,6 @@ template < /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor = false, - /// PartitionsN indicating how many PartitionsN for multiplicand B - int PartitionsN_ = 1, /// Used for partial specialization typename Enable = bool > @@ -105,18 +184,27 @@ class MmaTensorOp { /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; + /// Architecture tag from underlying instruction + using ArchTag = typename Policy::Operator::ArchTag; + /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Shape of underlying instruction + using InstructionShape = typename Policy::Operator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; - /// PartitionsN indicating how many PartitionsN for multiplicand B - static int const kPartitionsN = PartitionsN_; - public: /// Iterates over the A operand in memory @@ -128,6 +216,10 @@ class MmaTensorOp { /// Storage for A tile using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kB, ElementB, LayoutB, @@ -137,6 +229,10 @@ class MmaTensorOp { /// Storage for B tile using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< MatrixShape, ElementC, LayoutC, @@ -155,9 +251,7 @@ class MmaTensorOp { /// Number of mma operations performed using MmaIterations = MatrixShape< Shape::kM / Policy::Operator::Shape::kM, - (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ? - Shape::kN / Policy::Operator::Shape::kN / kPartitionsN : - 1 + Shape::kN / Policy::Operator::Shape::kN >; public: @@ -179,10 +273,10 @@ class MmaTensorOp { CUTLASS_DEVICE void operator()( FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C, - int const &partitionN_idx = 0) const { + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { using MmaOperandA = typename Policy::Operator::FragmentA; using MmaOperandB = typename Policy::Operator::FragmentB; @@ -194,8 +288,7 @@ class MmaTensorOp { MmaOperandB const *ptr_B = reinterpret_cast(&B); MmaOperandC *ptr_D = reinterpret_cast(&D); - // The offset of multilicand B for current partition - const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements / kPartitionsN; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) // Serpentine visitation order maximizing reuse of Rb CUTLASS_PRAGMA_UNROLL for (int n = 0; n < MmaIterations::kColumn; ++n) { @@ -213,13 +306,94 @@ class MmaTensorOp { ptr_D[n + m_serpentine * MmaIterations::kColumn]); } else { mma( - ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow], + ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], - ptr_B[n + n_off], - ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]); + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } } } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + + // + // Define conversions from source type to instruction type + // + FloatRoundStyle const kRoundA = + PreferredRoundingMode::kRound; + FloatRoundStyle const kRoundB = + PreferredRoundingMode::kRound; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_B = + reinterpret_cast const *>(&B); + Array * + ptr_dst_B = reinterpret_cast *>(&dst_B); + + dst_A = convert_A(A); + + ptr_dst_B[0] = convert_B(ptr_B[0]); + ptr_dst_B[1] = convert_B(ptr_B[1]); + + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array * + ptr_dst_A = reinterpret_cast *>(&dst_A); + + dst_B = convert_B(B); + + ptr_dst_A[0] = convert_A(ptr_A[0]); + ptr_dst_A[1] = convert_A(ptr_A[1]); + #else + assert(0); + #endif } }; @@ -228,3 +402,5 @@ class MmaTensorOp { } // namespace warp } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h new file mode 100644 index 0000000000..85f5009d8c --- /dev/null +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -0,0 +1,428 @@ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of a warp tile + that participate in one warp-level mma operation. + + Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation. + The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into + next warp-level mma operation. + + This iterator is necessary to accomplish warp-level mma fusion where the accumulator tile is + reused as multiplicand tile for the next mma. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of the accumulation tile shape (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Output operation on the fragment + typename OutputOp_, + /// Whether beta is zero + bool IsBetaZero_ > +class MmaTensorOpFragmentIterator; + + +// Partial specialization for col-major accumulator tile +// And Element type is the same as Accumulator Element type + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Output operation on fragment + typename OutputOp_> +class MmaTensorOpFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Output operation on fragment + using OutputOp = OutputOp_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, OutputOp output_op) const { + + if (output_op.is_source_needed()) //beta must be zero + assert(0); + + AccessType src_fragment; + src_fragment.clear(); + + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; n++) { + for (int m = 0; m < MmaIterations::kRow; m++) { + int accumulator_access_offset = + (n + index_n) * AccumulatorIterations::kRow + m + index_m; + + frag_ptr[n * MmaIterations::kRow + m].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + //frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; + frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +// Partial specialization for row-major accumulator tile + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Output operation on fragment + typename OutputOp_> +class MmaTensorOpFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Accumulator Element type + using ElementAccumulator = ElementAccumulator_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Output operation on fragment + using OutputOp = OutputOp_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + using FragmentAccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, OutputOp output_op) const { + + if (output_op.is_source_needed()) //beta must be zero + assert(0); + + FragmentAccessType src_fragment; + src_fragment.clear(); + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); +// NumericArrayConverter fragmentConverter; + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; m++) { + for (int n = 0; n < MmaIterations::kColumn; n++) { + int accumulator_access_offset = + (m + index_m) * AccumulatorIterations::kColumn + n + index_n; + + frag_ptr[m * MmaIterations::kColumn + n].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) +// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]); + frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op_policy.h b/include/cutlass/gemm/warp/mma_tensor_op_policy.h index 823860111d..68b28bfff1 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_policy.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h index 836efb94a0..063c77f9cc 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -103,6 +103,18 @@ class MmaVoltaTensorOp { /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Architecture tag + using ArchTag = arch::Sm70; + + /// Underlying instruction shape + using InstructionShape = typename Policy::Operator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; @@ -201,8 +213,7 @@ class MmaVoltaTensorOp { FragmentC &D, FragmentA const &A, FragmentB const &B, - FragmentC const &C, - int const &partitionN_idx = 0) { + FragmentC const &C) { using MmaOperandA = typename Policy::Operator::FragmentA; using MmaOperandB = typename Policy::Operator::FragmentB; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 811ff60ef5..1a8fa4f915 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -199,7 +199,8 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = + Array; private: @@ -228,8 +229,11 @@ class MmaTensorOpMultiplicandTileIterator< k_group_idx_(0) { int quad_pair = (lane_id >> 3); + int quad_quad = (lane_id >> 4); int lane_in_quad = (lane_id & 3); int lane_in_quad_pair = (lane_id & 7); + int lane_in_quad_quad = (lane_id & 15); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kPointerCount; ++i) { int partition_contiguous_idx = -1; @@ -241,6 +245,24 @@ class MmaTensorOpMultiplicandTileIterator< access_contiguous_idx = (quad_pair ^ lane_in_quad); access_strided_idx = lane_in_quad_pair; } + else if (Policy::LdsmShape::kContiguous == 2 && + kOperand == Operand::kA) { + // Matrix multiply 16816 A + // Q0 Q2 + // Q1 Q3 + partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); + access_contiguous_idx = + (((quad_pair & 1) + ((i & 1) << 1)) ^ lane_in_quad); + access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); + } else if (Policy::LdsmShape::kContiguous == 2 && + kOperand == Operand::kB) { + // Matrix multiply 16816 B + // Q0 Q1 + // Q2 Q3 + partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); + access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad); + access_strided_idx = lane_in_quad_quad; + } int access_contiguous = partition_contiguous_idx * Layout::PartitionShape::kContiguous + access_contiguous_idx; @@ -435,6 +457,364 @@ class MmaTensorOpMultiplicandTileIterator< }; //////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for 32-thread MMA.TF32 NT TensorOps. It +/// uses LDS.32 to load from shared memory and therefore must be initialized +/// with a TensorRef to shared memory. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::TensorOpMultiplicandCongruous<32, 32>, InstructionShape_, + OpDelta_, 32, PartitionsK_> { + public: + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand == Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for " + "A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicandCongruous<32, 32>; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kContiguous % InstructionShape::kContiguous), + "Shape of warp-level Mma must be divisible by operator shape."); + + // Determine number of elements along outer dimension per individual LDS.32 + // op. Every one warp of LDS.32 loads 8x4 elements + static int const kLdsOpInner = Layout::TileShape::kStrided; + static int const kLdsOpOuter = kThreads / kLdsOpInner; + + static_assert(!(Shape::kContiguous % kLdsOpOuter), + "Shape of warp-level mma must be divisible by LDS.32's " + "fundamental tile size."); + + static_assert(!(Shape::kStrided % kLdsOpInner), + "Shape of warp-level mma must be divisible by LDS.32's " + "fundamental tile size."); + + /// Number of LDS.32 instructions needed by one MMA instruction + /// 1684 A 2x1 + /// 1684 B 1x1 + /// 1688 A 2x2 + /// 1688 B 1x2 + static int const LdsShapeContiguous = + InstructionShape::kContiguous / kLdsOpOuter; + static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner; + using LdsShape = + layout::PitchLinearShape; + + /// Number and arrangement of LDS instructions + using LdsIterations = layout::PitchLinearShape< + Shape::kContiguous / LdsShapeContiguous / kLdsOpOuter, 1>; + + /// Number of groups for each tile + static int const kGroupsPerTile = + Shape::kStrided / InstructionShape::kStrided; + }; + + private: + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Number of internal pointers needed to reference shared memory + static int const kPointerCount = Layout::TileShape::kContiguous * + Layout::kElementsPerAccess / + Policy::kLdsOpOuter; + + /// Vectorized access is not used + static int const kElementsPerAccess = 1; + + /// Pointer type used for accesses + using AccessType = Element; + + /// Internal counter used to jump to next K partition + int k_group_idx_; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + + private: + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_[kPointerCount]; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() : stride_(0), byte_offset_(0) {} + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) + : stride_(ref.stride(0)), byte_offset_(0), k_group_idx_(0) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPointerCount; ++i) { + int access_strided = lane_id % Policy::kLdsOpInner; + int access_contiguous = (lane_id / Policy::kLdsOpInner) + + (access_strided ^ i) * Policy::kLdsOpOuter; + + pointer_[i] = reinterpret_cast(ref.data()) + + access_contiguous + access_strided * stride_; + } + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + byte_offset_ += offset * sizeof(Element); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int contiguous_offset = tile_offset.contiguous(); + if (Shape::kContiguous == + Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) { + if (tile_offset.contiguous() % 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPointerCount / 2; ++i) { + AccessType const *tmp_pointer = pointer_[i]; + pointer_[i] = pointer_[i + kPointerCount / 2]; + pointer_[i + kPointerCount / 2] = tmp_pointer; + } + } + contiguous_offset = (tile_offset.contiguous() >> 1) << 1; + } + + int offset = (tile_offset.strided() * InstructionShape::kStrided) * stride_ + + contiguous_offset * Shape::kContiguous; + + add_pointer_offset(offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &operator++() { + add_tile_offset({0, 1}); + + if (kPartitionsK > 1) { + ++k_group_idx_; + // Jump to next stage + if (k_group_idx_ == Policy::kGroupsPerTile) { + k_group_idx_ = 0; + add_tile_offset( + {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); + } + } + + return *this; + } + + /// Advances the iterator along the opposite of the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &operator--() { + byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * + kElementsPerAccess; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of + ///< the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &operator+=( + TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of + ///< the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &operator-=( + TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + Element *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ss = 0; ss < Policy::LdsShape::kStrided; ++ss) { + CUTLASS_PRAGMA_UNROLL + for (int cc = 0; cc < Policy::LdsShape::kContiguous; ++cc) { + int access_idx = + cc + (ss + (c + s * Policy::LdsIterations::kContiguous) * + Policy::LdsShape::kStrided) * + Policy::LdsShape::kContiguous; + int access_idx_contiguous = cc + c * Policy::LdsShape::kContiguous; + int access_idx_strided = + (ss + s * Policy::LdsShape::kStrided) * Policy::kLdsOpInner; + + AccessType const *source_ptr = + pointer_[access_idx_contiguous % kPointerCount] + + Layout::TileShape::kContiguous * Layout::kElementsPerAccess * + (access_idx_contiguous / kPointerCount) + + access_idx_strided * stride_; + + char const *source_byte_ptr = + reinterpret_cast(source_ptr) + byte_offset + + byte_offset_; + + fetch_ptr[access_idx] = + *reinterpret_cast(source_byte_ptr); + } + } + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = + tile_offset.contiguous() * Shape::kContiguous / + Layout::kElementsPerAccess + + tile_offset.strided() * InstructionShape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no op + } +}; + +//////////////////////////////////////////////////////////////////////////////// + /// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared /// memory and therefore must be initialized with a TensorRef to shared memory. /// @@ -516,7 +896,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -747,7 +1127,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -1023,7 +1403,8 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = Array; private: @@ -1067,7 +1448,6 @@ class MmaTensorOpMultiplicandTileIterator< k_group_idx_(0) { // Warp level iterator at most use double buffer to hide latency. If there // are more than 2 sections, every stage should have more than 1 section. - // TODO: refactor code after every case is implemented // Turing silicon requires all 32 threads in a warp provide valid addresses // even for LDSM.1 and LDSM.2 @@ -1075,6 +1455,8 @@ class MmaTensorOpMultiplicandTileIterator< lane_id = lane_id % (Policy::LdsmShape::kCount * Policy::kLdsmOpInner); #endif + int quad_quad = (lane_id >> 4); + int quad_pair = (lane_id >> 3); int lane_in_pair = (lane_id & 1); int lane_in_quad = (lane_id & 3); int lane_in_quad_pair = (lane_id & 7); @@ -1098,6 +1480,26 @@ class MmaTensorOpMultiplicandTileIterator< (lane_in_quad_quad / Layout::kFactor)); access_strided_idx = lane_id / Layout::kFactor; } + else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kA) { + // Integer matrix multiply 16832 A + partition_contiguous_idx = lane_in_quad / factor_in_partition; + access_strided_idx = lane_in_quad_quad / Layout::kFactor; + access_contiguous_idx = + ((lane_in_pair * factor_in_partition + quad_quad) ^ + access_strided_idx); + } + else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kB) { + // Integer matrix multiply 16832 B + partition_contiguous_idx = lane_in_quad / factor_in_partition; + access_strided_idx = lane_in_quad_pair / Layout::kFactor + quad_quad * 2; + access_contiguous_idx = + ((lane_in_pair * factor_in_partition + ((lane_id & 8) >> 3)) ^ + access_strided_idx); + } } else if (Layout::kFactor == 2) { // Super Matrix multiply kBlock = 32 if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { @@ -1111,6 +1513,28 @@ class MmaTensorOpMultiplicandTileIterator< access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor); access_strided_idx = lane_id / Layout::kFactor; } + else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kA) { + // Matrix multiply 16816|1688.TF32 A + // Q0 Q2 + // Q1 Q3 + partition_contiguous_idx = (lane_id % Layout::kFactor); + access_contiguous_idx = + (quad_quad ^ (lane_in_quad_pair / Layout::kFactor)); + access_strided_idx = (lane_in_quad_quad / Layout::kFactor); + } else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kB) { + // Matrix multiply 16816|1688.TF32 B + // Q0 Q1 + // Q2 Q3 + partition_contiguous_idx = (lane_id % Layout::kFactor); + access_contiguous_idx = + ((quad_pair & 1) ^ (lane_in_quad_pair / Layout::kFactor)); + access_strided_idx = + (lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor; + } } else if (Layout::kFactor == 1) { // Super Matrix multiply kBlock = 64 if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { @@ -1122,6 +1546,25 @@ class MmaTensorOpMultiplicandTileIterator< access_contiguous_idx = lane_in_quad; access_strided_idx = lane_id; } + else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kA) { + // Matrix multiply 16816|1688.TF32 A + // Q0 Q2 + // Q1 Q3 + partition_contiguous_idx = (lane_in_quad_pair >> 2); + access_contiguous_idx = (quad_quad ^ lane_in_quad); + access_strided_idx = lane_in_quad_quad; + } else if (Policy::LdsmShape::kStrided == + (Policy::LdsmShape::kCount / 2) && + kOperand == Operand::kB) { + // Matrix multiply 16816|1688.TF32 B + // Q0 Q1 + // Q2 Q3 + partition_contiguous_idx = (lane_in_quad_pair >> 2); + access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad); + access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); + } } int access_contiguous = @@ -1151,7 +1594,49 @@ class MmaTensorOpMultiplicandTileIterator< int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; byte_offset_ ^= k_groups_delta * sizeof_bits::value * - Layout::kElementsPerAccess / 8; + Layout::kElementsPerAccess * + Policy::LdsmShape::kContiguous / 8; + pointer_ += + tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + + whole_tiles * stride_ / sections_; + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( + TensorCoord const &tile_offset) { + + int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; + int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; + if (k_groups_delta < 0) { + whole_tiles -= 1; + k_groups_delta += Policy::kGroupsPerTile; + } + + if ((Policy::kGroupsPerTile / kPartitionsK) >= 2) { + byte_offset_ ^= (k_groups_delta & 1) * Policy::LdsmShape::kContiguous * + sizeof_bits::value * + Layout::kElementsPerAccess / 8; + } + if ((Policy::kGroupsPerTile / kPartitionsK) >= 4) { + byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 1)) & 2) * + Policy::LdsmShape::kContiguous * + sizeof_bits::value * + Layout::kElementsPerAccess / 8; + } + if ((Policy::kGroupsPerTile / kPartitionsK) == 8) { + byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 3)) & 4) * + Policy::LdsmShape::kContiguous * + sizeof_bits::value * + Layout::kElementsPerAccess / 8; + } + + k_group_idx_ += k_groups_delta; + whole_tiles += k_group_idx_ / (Policy::kGroupsPerTile / kPartitionsK); + k_group_idx_ = k_group_idx_ % (Policy::kGroupsPerTile / kPartitionsK); + pointer_ += tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + whole_tiles * stride_ / sections_; @@ -1162,12 +1647,23 @@ class MmaTensorOpMultiplicandTileIterator< CUTLASS_DEVICE MmaTensorOpMultiplicandTileIterator &operator++() { + // Integer matrix multiply 16832 Interleaved-32 + // NONE + // Integer matrix multiply 16816 Interleaved-32 || Integer matrix multiply 16816 kblock=32 + // Integer matrix multiply 8816 Interleaved-32 // ^1 ^1 + // Matrix multiply 1684.TF32 kblock=16 || Integer matrix multiply 16816 kblock=64 // Matrix multiply 1688 kblock=32 || Integer matrix multiply 8816 kblock=64 // ^1 ^3 ^1 ^3 // Matrix multiply 1688 kblock=64 // ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7 + + // Matrix multiply 16816 kblock=32 | 1688.TF32 kblock=16 || Integer matrix multiply 16832 kblock=64 + // ^2 ^2 + // Matrix multiply 16816 kblock=64 | 1688.TF32 kblock=32 || Integer matrix multiply 16832 kblock=128 + // ^2 ^6 ^2 ^6 + if ((Policy::kGroupsPerTile / kPartitionsK) > 1) { int mask = ((Policy::kGroupsPerTile / kPartitionsK) == 8) ? 3 @@ -1406,7 +1902,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator @@ -1440,6 +1936,16 @@ class MmaTensorOpMultiplicandTileIterator< return *this; } + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( + TensorCoord const &tile_offset) { + iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_HOST_DEVICE MmaTensorOpMultiplicandTileIterator &operator++() { @@ -1636,7 +2142,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator @@ -1670,6 +2176,16 @@ class MmaTensorOpMultiplicandTileIterator< return *this; } + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( + TensorCoord const &tile_offset) { + iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_HOST_DEVICE MmaTensorOpMultiplicandTileIterator &operator++() { @@ -1779,6 +2295,7 @@ class MmaTensorOpMultiplicandTileIterator< }; //////////////////////////////////////////////////////////////////////////////// + template < /// Size of the matrix to load (concept: MatrixShape) typename Shape_, @@ -2679,6 +3196,7 @@ class MmaTensorOpAccumulatorTileIterator< }; //////////////////////////////////////////////////////////////////////////////// + } // namespace warp } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index be271e777f..ed6384f05a 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -165,7 +165,8 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = Array; private: @@ -473,7 +474,8 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile, needs on more time number of registers - using Fragment = Array; + using Fragment = Array; private: @@ -738,7 +740,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -962,7 +964,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -1557,7 +1559,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = + Array; private: @@ -1869,7 +1873,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator @@ -2097,7 +2101,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h new file mode 100644 index 0000000000..e43373b64f --- /dev/null +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -0,0 +1,1579 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 64b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::TensorOpMultiplicandCongruous64b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 4), "Divisibility."); + + static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 2; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<8, 4>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, + InstructionShape::kStrided / Delta::kStrided + >; + + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + + /// Internal counter used to jump to next K partition + int k_group_idx_; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), + k_group_idx_(0) { + + int access_strided = lane_id / Policy::Delta::kContiguous; + int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; + + pointer_= reinterpret_cast(ref.data()) + + access_contiguous + access_strided * stride_; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + byte_offset_ += offset * sizeof(Element); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + int offset = + (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + + tile_offset.contiguous() * Shape::kContiguous; + + add_pointer_offset(offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + add_tile_offset({0, 1}); + + return *this; + } + + /// Advances the iterator along the opposite of the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + add_tile_offset({0, -1}); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + int access_idx = c + s * Policy::Iterations::kContiguous; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c + + Policy::Delta::kStrided * s * stride_; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + + Index pointer_offset = + tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + + tile_offset.strided() * InstructionShape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCongruous64b, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.strided(), tile_offset.contiguous()}, + byte_offset); + } + + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared +/// memory and therefore must be initialized with a TensorRef to shared memory. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicandCongruous64b, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is specialized for loading 128b vectors of 64b elements. +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: PitchLinearShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: PitchLinearShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::TensorOpMultiplicand64bCrosswise, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); + + static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Load two elements per access + static int const kElementsPerAccess = 2; + + /// Policy defining internal details of tile iterator + struct Policy { + + /// Shape of one access + using Delta = layout::PitchLinearShape<4, 16>; + + /// Number of iterations to load + using Iterations = layout::PitchLinearShape< + InstructionShape::kContiguous / Delta::kContiguous, + Shape::kStrided / Delta::kStrided + >; + + }; + +private: + + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = AlignedArray; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + +private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + + /// Internal counter for tracking K-group + Index k_group_idx_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): + stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), + k_group_idx_(0) { + + int access_strided = lane_id / 8; + int access_contiguous = (lane_id % 8); + + byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); + + pointer_= reinterpret_cast(ref.data()); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + pointer_ += offset / kElementsPerAccess; + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * + stride_ * kElementsPerAccess + + tile_offset.strided() * Shape::kStrided; + + add_pointer_offset(offset); + + int old_k_group_idx = k_group_idx_; + + k_group_idx_ += tile_offset.contiguous(); + + if ((k_group_idx_ & 2) ^ (old_k_group_idx & 2)) { + byte_offset_ ^= 0x40; + } + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + pointer_ += stride_ * InstructionShape::kContiguous; + + if (k_group_idx_ & 0x1) { + // xor ptr + byte_offset_ ^= 0x40; + } + + ++k_group_idx_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + AccessType *fetch_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::Iterations::kStrided; ++s) { + + int access_idx = c + s * Policy::Iterations::kContiguous; + + AccessType const *source_ptr = pointer_ + + Policy::Delta::kContiguous * c * stride_ + + Policy::Delta::kStrided * s / kElementsPerAccess; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; + + AccessType const *source = reinterpret_cast(source_byte_ptr); + + fetch_ptr[access_idx] = *source; + } + } + + Element *exchange_ptr = reinterpret_cast(&frag); + + if (k_group_idx_ & 1) { + // exchange on 64b granularity + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Fragment::kElements; i += 2) { + Element tmp = exchange_ptr[i]; + exchange_ptr[i] = exchange_ptr[i + 1]; + exchange_ptr[i + 1] = tmp; + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = tile_offset.contiguous() * + InstructionShape::kContiguous / + Layout::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + k_group_idx_ = k_group; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicand64bCrosswise, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.strided(), tile_offset.contiguous()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Satisfies: +/// ReadableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIterator< + layout::PitchLinearShape, kOperand, Element, + layout::TensorOpMultiplicand64bCrosswise, + layout::PitchLinearShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h index 0caf6247de..64be655680 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h index bbfa2dcbd7..824e207d74 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -40,6 +40,8 @@ #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" @@ -75,8 +77,6 @@ template < typename Policy_, ///< Number of partitions along K dimension int PartitionsK_ = 1, - ///< Number of partitions along N dimension - int PartitionsN_ = 1, ///< Used for partial specialization typename Enable = bool > @@ -106,8 +106,20 @@ class MmaTensorOpWmma { /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) using Policy = Policy_; + /// Underlying instruction shape + using InstructionShape = typename Policy::Operator::Shape; + + /// Underlying architecture tag + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + using OperatorClass = arch::OpClassWmmaTensorOp; /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; @@ -115,9 +127,6 @@ class MmaTensorOpWmma { /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; - /// PartitionsN indicating how many PartitionsN for multiplicand B - static int const kPartitionsN = PartitionsN_; - public: /// Iterates over the A operand in memory @@ -154,9 +163,7 @@ class MmaTensorOpWmma { /// Number of wmma operations performed using WmmaIterations = MatrixShape< Shape::kM / Policy::Operator::Shape::kM, - (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ? - Shape::kN / Policy::Operator::Shape::kN / kPartitionsN : - 1 + Shape::kN / Policy::Operator::Shape::kN >; public: @@ -180,8 +187,7 @@ class MmaTensorOpWmma { FragmentC &D, FragmentA const &A, FragmentB const &B, - FragmentC const &C, - int const &partitionN_idx = 0) const { + FragmentC const &C) const { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < WmmaIterations::kColumn; ++n) { @@ -193,7 +199,6 @@ class MmaTensorOpWmma { } } } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/include/cutlass/gemm/warp/tile_iterator_planar_complex.h new file mode 100644 index 0000000000..a3050c4299 --- /dev/null +++ b/include/cutlass/gemm/warp/tile_iterator_planar_complex.h @@ -0,0 +1,244 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/array_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TileIteratorPlanarComplex { +public: + + /// Underlying iterator over real-valued tiles + using TileIterator = TileIterator_; + + /// Underlying element type + using Element = typename TileIterator::Element; + + /// Underlying layout type + using Layout = typename TileIterator::Layout; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename TileIterator::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Planar complex fragment + using Fragment = ArrayPlanarComplex; + +public: + + /// Underlying tile iterator + TileIterator tile_iterator_; + + /// Offset (in units of bytes) to the imaginary part of the planar complex matrix + LongIndex imaginary_offset_; + +public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex(): imaginary_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + TileIteratorPlanarComplex( + TensorRef const &ref, + int lane_id, + LongIndex imaginary_offset + ): + tile_iterator_(ref, lane_id), + imaginary_offset_((imaginary_offset * sizeof_bits::value) / 8) { } + + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) { + + tile_iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) { + + tile_iterator_.add_tile_offset(tile_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + TileIteratorPlanarComplex & operator++() { + ++tile_iterator_; + return *this; + } + + // + // WIP + // + + /// Advances the iterator along the opposite of the advance dimension + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex & operator--() { + --tile_iterator_; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) { + tile_iterator_.add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) { + tile_iterator_.add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + tile_iterator_.load_with_byte_offset(frag.real, 0); + tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + Index byte_offset = (pointer_offset * sizeof_bits::value)/8; + + tile_iterator_.load_with_byte_offset(frag.real, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0); + tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + Index byte_offset = (pointer_offset * sizeof_bits::value)/8; + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + tile_iterator_.set_kgroup_index(k_group); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/half.h b/include/cutlass/half.h index ba6f0d951d..10d00de1c2 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -78,6 +78,10 @@ enum #include +#if defined(__i386__) || defined(__x86_64__) +#include +#endif + #define F16C_ROUND_NEAREST 0 #if !defined(__CUDA_ARCH__) @@ -110,9 +114,51 @@ __inline unsigned short _cvtss_sh (float __F, const int) { // Linux #include -#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) +#if defined(__i386__) || defined(__x86_64__) +#include #endif + +#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) + +#endif // _MSC_VER + +class CpuId { + + bool f16c_enabled; + + CpuId() { + #if defined(__i386__) || defined(__x86_64__) + #if defined(_MSC_VER) + int exx[4]; + + __cpuid (exx, 1); + f16c_enabled = exx[2] & 0x20000000; + + #else + // GCC / Clang + int eax, ebx, ecx, edx; + + __cpuid (1 , eax, ebx, ecx, edx); + f16c_enabled = ecx & 0x20000000; + #endif + #else + // Arm / PowerPC etc. + f16c_enabled = false; + #endif + } + +public: + + bool is_f16c_supported() const { + return f16c_enabled; + } + + static const CpuId& instance() { + static CpuId cpu; + return cpu; + } +}; #endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -154,10 +200,15 @@ struct alignas(2) half_t { static half_t convert(float const& flt) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) return half_t(__float2half_rn(flt)); - #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); - return bitcast(u); #else + + #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C + if( CpuId::instance().is_f16c_supported() ) { + unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); + return bitcast(u); + } + #endif + // software implementation rounds toward nearest even unsigned const& s = reinterpret_cast(flt); uint16_t sign = uint16_t((s >> 16) & 0x8000); @@ -248,10 +299,15 @@ struct alignas(2) half_t { static float convert(half_t const& x) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) return __half2float(x.to_half()); - #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - unsigned short u = x.storage; - return _cvtsh_ss(u); #else + + #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C + if( CpuId::instance().is_f16c_supported() ) { + unsigned short u = x.storage; + return _cvtsh_ss(u); + } + #endif + uint16_t const &h = x.storage; int sign = ((h >> 15) & 1); int exp = ((h >> 10) & 0x1f); diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 223346c47a..6b97f8222a 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,7 +29,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/platform/platform.h" @@ -48,7 +52,7 @@ struct integer_subbyte { static bool const kSigned = Signed; /// External type - using T = typename std::conditional::type; + using T = typename platform::conditional::type; /// Storage type using Storage = uint8_t; diff --git a/include/cutlass/kernel_launch.h b/include/cutlass/kernel_launch.h index b48fd7d0b0..bd84a35781 100644 --- a/include/cutlass/kernel_launch.h +++ b/include/cutlass/kernel_launch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/layout/layout.h b/include/cutlass/layout/layout.h index dda08dafa7..775357d125 100644 --- a/include/cutlass/layout/layout.h +++ b/include/cutlass/layout/layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -38,7 +38,6 @@ #include "cutlass/layout/matrix.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/tensor.h" -#include "cutlass/layout/tensor_nhwc.h" #include "cutlass/layout/vector.h" #include "cutlass/layout/tensor_op_multiplicand_sm70.h" diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index ba0361c081..7c02f8f2c2 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -98,7 +98,7 @@ class RowMajor { /// Assumes coordinate has convention (row, column) CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord const &coord) const { - return coord.row() * stride_[0] + coord.column(); + return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column(); } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -134,7 +134,7 @@ class RowMajor { /// Compute the number of contiguous elements needed to store a tensor with the given size CUTLASS_HOST_DEVICE LongIndex capacity(MatrixCoord const &extent) const { - return extent.row() * stride_[0]; + return LongIndex(extent.row()) * LongIndex(stride_[0]); } }; @@ -191,7 +191,7 @@ class ColumnMajor { /// Assumes coordinate has convention (row, column) CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord const &coord) const { - return coord.row() + coord.column() * stride_[0]; + return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row(); } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -227,7 +227,7 @@ class ColumnMajor { /// Compute the number of contiguous elements needed to store a tensor with the given size CUTLASS_HOST_DEVICE LongIndex capacity(MatrixCoord const &extent) const { - return extent.column() * stride_[0]; + return LongIndex(extent.column()) * LongIndex(stride_[0]); } }; @@ -290,7 +290,7 @@ struct RowMajorInterleaved { LongIndex operator()(MatrixCoord const &coord) const { Index row_major = coord.row() / kInterleave; Index row_minor = coord.row() % kInterleave; - return row_major * stride_[0] + coord.column() * kInterleave + row_minor; + return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor; } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -397,7 +397,7 @@ struct ColumnMajorInterleaved { LongIndex operator()(MatrixCoord const &coord) const { Index column_major = coord.column() / kInterleave; Index column_minor = coord.column() % kInterleave; - return column_major * stride_[0] + coord.row() * kInterleave + column_minor; + return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor; } /// Inverse of layout function, mapping linear offset to logical coordinate diff --git a/include/cutlass/layout/pitch_linear.h b/include/cutlass/layout/pitch_linear.h index 2a326c7740..a6158b32a4 100644 --- a/include/cutlass/layout/pitch_linear.h +++ b/include/cutlass/layout/pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -116,6 +116,11 @@ struct PitchLinearCoord : public Coord<2, int> { return PitchLinearCoord(Base::operator-(b)); } + CUTLASS_HOST_DEVICE + PitchLinearCoord operator-() const { + return PitchLinearCoord(-at(0), -at(1)); + } + /// Element-wise multiplication CUTLASS_HOST_DEVICE PitchLinearCoord operator*(Base const& b) const { @@ -211,7 +216,7 @@ class PitchLinear { /// Assumes coordinate has convention (contiguous, strided) CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const { - return coord.contiguous() + coord.strided() * stride_[0]; + return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]); } /// Returns the logical coordinate given an offset. diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index dc2a7c889d..20d5bad777 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,7 +33,11 @@ defined in cutlass/tensor_ref.h. */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include "assert.h" +#endif #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/include/cutlass/layout/tensor_op_multiplicand_sm70.h index 26bd427e69..03f87db392 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm70.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/include/cutlass/layout/tensor_op_multiplicand_sm75.h index b4b35667ed..00870fb50f 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm75.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/include/cutlass/layout/tensor_op_multiplicand_sm80.h new file mode 100644 index 0000000000..e5963a2a80 --- /dev/null +++ b/include/cutlass/layout/tensor_op_multiplicand_sm80.h @@ -0,0 +1,1133 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace layout { + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct TensorOpMultiplicandCongruous64b { + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = PitchLinearCoord; + + /// Stride vector + using Stride = Coord; + + // + // Static constants + // + + static int const kElementSize = 64; + static int const kElementsPerAccess = 1; + + private: + + // + // Data members + // + + /// Stride data member. + Stride stride_; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCongruous64b(Index ldm = 0) : stride_(ldm) {} + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCongruous64b(Stride stride) : stride_(stride) {} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static TensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { + return TensorOpMultiplicandCongruous64b(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + + int tc = coord.contiguous() / 16; + int ts = coord.strided() / 4; + + int c = coord.contiguous() % 16; + int s = coord.strided() % 4; + + + int bank = ((((c & 1) * 4 + (c & 6) / 2)) ^ (s & 1)) * 2 + (c / 8); + int row = (c & 6) / 2; + + bank ^= ((s & 2) * 2); + + LongIndex offset = tc * 16 + bank + (ts * 4 + row) * stride_[0]; + + return offset; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Compute the number of contiguous elements needed to store a tensor with + /// the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return extent[1] * stride_[0]; + } + + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + return TensorCoord(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a column-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct ColumnMajorTensorOpMultiplicandCongruous64b { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCongruous64b; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static ColumnMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { + return ColumnMajorTensorOpMultiplicandCongruous64b(extent.row()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.row(), coord.column())); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + PitchLinearCoord coord = layout_.inverse(offset); + return MatrixCoord(coord.contiguous(), coord.strided()); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a row-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct RowMajorTensorOpMultiplicandCongruous64b { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCongruous64b; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static RowMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { + return RowMajorTensorOpMultiplicandCongruous64b(extent.column()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.column(), coord.row())); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + PitchLinearCoord coord = layout_.inverse(offset); + return MatrixCoord(coord.strided(), coord.contiguous()); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct TensorOpMultiplicand64bCrosswise { + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = PitchLinearCoord; + + /// Stride vector + using Stride = Coord; + + // + // Static constants + // + + static int const kElementSize = 64; + static int const kElementsPerAccess = 1; + + private: + + // + // Data members + // + + /// Stride data member. + Stride stride_; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicand64bCrosswise(Index ldm = 0) : stride_(ldm) {} + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicand64bCrosswise(Stride stride) : stride_(stride) {} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static TensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { + return TensorOpMultiplicand64bCrosswise(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + + int tc = coord.contiguous() / 16; + int ts = coord.strided() / 16; + + int c = coord.contiguous() % 16; + int s = coord.strided() % 16; + + int k_group = c / 4; + int access_s = s / 2; + + int row = access_s % 4; + int bank = ((k_group & 2) << 2) ^ ((s % 2) << 3) + (c % 4) * 2 + (access_s / 4) ^ (k_group & 1); + + int smem_row = (k_group * 4 + row) + tc * 16; + int smem_col = ts * 16 + bank; + + LongIndex offset = smem_row * stride_[0] + smem_col; + + return offset; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Compute the number of contiguous elements needed to store a tensor with + /// the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return extent[1] * stride_[0]; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct ColumnMajorTensorOpMultiplicand64bCrosswise { + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicand64bCrosswise; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static ColumnMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { + return ColumnMajorTensorOpMultiplicand64bCrosswise(extent.column()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.row(), coord.column())); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct RowMajorTensorOpMultiplicand64bCrosswise { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicand64bCrosswise; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static RowMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { + return RowMajorTensorOpMultiplicand64bCrosswise(extent.row()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.column(), coord.row())); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct TensorOpMultiplicandCongruous128b { + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = PitchLinearCoord; + + /// Stride vector + using Stride = Coord; + + // + // Static constants + // + + static int const kElementSize = 128; + static int const kElementsPerAccess = 1; + + private: + + // + // Data members + // + + /// Stride data member. + Stride stride_; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCongruous128b(Index ldm = 0) : stride_(ldm) {} + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCongruous128b(Stride stride) : stride_(stride) {} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static TensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { + return TensorOpMultiplicandCongruous128b(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + + Index tc = coord.contiguous() / 8; + Index ts = coord.strided() / 4; + + Index c = coord.contiguous() % 8; + Index s = coord.strided() % 4; + + Index k_index = (c / 2); + + Index bank = (((c & 1) * 4) | (s ^ k_index)); + + LongIndex offset = tc * 8 + bank + (ts * 4 + k_index) * stride_[0]; + + return offset; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Compute the number of contiguous elements needed to store a tensor with + /// the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return extent[1] * stride_[0]; + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + return TensorCoord(); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a column-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct ColumnMajorTensorOpMultiplicandCongruous128b { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCongruous128b; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static ColumnMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { + return ColumnMajorTensorOpMultiplicandCongruous128b(extent.row()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.row(), coord.column())); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + PitchLinearCoord coord = layout_.inverse(offset); + return MatrixCoord(coord.contiguous(), coord.strided()); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a row-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct RowMajorTensorOpMultiplicandCongruous128b { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCongruous128b; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static RowMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { + return RowMajorTensorOpMultiplicandCongruous128b(extent.column()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.column(), coord.row())); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + PitchLinearCoord coord = layout_.inverse(offset); + return MatrixCoord(coord.strided(), coord.contiguous()); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template based on element size (in bits) - defined in terms of pitch-linear +/// memory and Crosswise size (in elements). +struct TensorOpMultiplicandCrosswise128x4 { + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = PitchLinearCoord; + + /// Stride vector + using Stride = Coord; + + // + // Static constants + // + + static int const kElementSize = 128; + static int const kElementsPerAccess = 1; + + private: + + // + // Data members + // + + /// Stride data member. + Stride stride_; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCrosswise128x4(Index ldm = 0) : stride_(ldm) {} + + /// Ctor + CUTLASS_HOST_DEVICE + TensorOpMultiplicandCrosswise128x4(Stride stride) : stride_(stride) {} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static TensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { + return TensorOpMultiplicandCrosswise128x4(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + + Index tc = coord.contiguous() / 8; + Index ts = coord.strided() / 8; + + Index c = coord.contiguous() % 8; + Index s = coord.strided() % 8; + + Index liq = c % 4; + + Index bank = liq + ((s & 1) * 4) ^ (c & 4); + + Index k_index = (c & 4) + (s / 4) * 2 + ((s & 2) / 2); + + LongIndex offset = (tc * 8 + k_index) * stride_[0] + ts * 8 + bank; + + return offset; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Compute the number of contiguous elements needed to store a tensor with + /// the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return extent[1] * stride_[0]; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a column-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct ColumnMajorTensorOpMultiplicandCrosswise128x4 { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCrosswise128x4; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + ColumnMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static ColumnMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { + return ColumnMajorTensorOpMultiplicandCrosswise128x4(extent.column()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.row(), coord.column())); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template mapping a row-major view of pitch-linear memory to +/// TensorOpMultiplicand +struct RowMajorTensorOpMultiplicandCrosswise128x4 { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 1; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + + // + // Invariants + // + + using Base = TensorOpMultiplicandCrosswise128x4; + +private: + + // + // Data members + // + + Base layout_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } + + /// Ctor + CUTLASS_HOST_DEVICE + RowMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static RowMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { + return RowMajorTensorOpMultiplicandCrosswise128x4(extent.row()); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (contiguous, strided) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return layout_(PitchLinearCoord(coord.column(), coord.row())); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/layout/vector.h b/include/cutlass/layout/vector.h index 0700e58722..b54b6b3b18 100644 --- a/include/cutlass/layout/vector.h +++ b/include/cutlass/layout/vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/matrix_coord.h b/include/cutlass/matrix_coord.h index 8ba61a5ecf..b432665e8c 100644 --- a/include/cutlass/matrix_coord.h +++ b/include/cutlass/matrix_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/matrix_shape.h b/include/cutlass/matrix_shape.h index 1d0b4820fc..cb3118c2d6 100644 --- a/include/cutlass/matrix_shape.h +++ b/include/cutlass/matrix_shape.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/matrix_traits.h b/include/cutlass/matrix_traits.h index 8e7fe3305e..cf7002a42a 100644 --- a/include/cutlass/matrix_traits.h +++ b/include/cutlass/matrix_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 228d632796..78181ce79a 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -46,6 +46,8 @@ enum class FloatRoundStyle { round_to_nearest, ///< round to nearest even round_toward_infinity, ///< round toward infinity round_toward_neg_infinity, ///< round toward negative infinity + round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero + round_half_ulp_trunc_dntz ///< like round_half_ulp_truncate, except denorms are rounded *toward* zero }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -240,6 +242,232 @@ struct NumericConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for float <=> bfloat16_t +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for float <= bfloat16_t +template +struct NumericConverter { + + using result_type = float; + using source_type = bfloat16_t; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + return static_cast(s); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = bfloat16_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + return static_cast(s); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = bfloat16_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + uint32_t x32 = reinterpret_cast(s); + + #if defined(__CUDA_ARCH__) + if (::isfinite(s)) { + x32 += 0x8000; + } + #else + if (std::isfinite(s)) { + x32 += 0x8000; + } + #endif + + uint16_t x16 = uint16_t((x32 >> 16) & 0xffff); + return bfloat16_t::bitcast(x16); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = bfloat16_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + uint32_t x32 = reinterpret_cast(s); + uint16_t x16 = uint16_t(x32 >> 16); + + return bfloat16_t::bitcast(x16); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for float <=> tfloat32_t +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for float <= tfloat32_t +template +struct NumericConverter { + + using result_type = float; + using source_type = tfloat32_t; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + return static_cast(s); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = tfloat32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + unsigned storage = reinterpret_cast(s); + + if ((storage & 0x7f800000) != 0x7f800000) { + + bool mantissa_bit = ((storage & (1 << 13)) != 0); + bool round_bit = ((storage & (1 << 12)) != 0); + bool sticky_bit = ((storage & ((1 << 12) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { + storage += uint32_t(1 << 13); + } + + // Note, the following is intentionally commented out. TF32 + // does not define the low order bits, so they may be left in + // an undefined state. + // + // By not truncating these bit explicitly, we avoid an extra logical + // operation. + // + // TF32 may be implicitly converted to float by performing this + // operation as needed. + // + // storage = (storage & ~0x1fff); + } + else if (storage & ~0xff800000) { + storage = 0x7fffffff; + } + + return tfloat32_t::bitcast(storage); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = tfloat32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + return tfloat32_t::round_half_ulp_truncate(s); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// This rounding operation is similar to half_ulp_truncate except it rounds denorms toward zero. +/// It avoids predicated code, though it requires a temporary register. +template <> +struct NumericConverter { + using result_type = tfloat32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_trunc_dntz; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + unsigned y = reinterpret_cast(s); + y = y & 0xff800000; + float d = reinterpret_cast(y); + float z = d / float(1 << 11) + s; + + return reinterpret_cast(z); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + using result_type = tfloat32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + uint32_t x = reinterpret_cast(s); + return tfloat32_t::bitcast(x & 0xffffe000); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Conversion and Clamp operator for Integers @@ -268,8 +496,7 @@ struct NumericConverterClamp { result_type const kClamp_min = -kClamp_max - 1; bool is_int_min = !(s > kClamp_min); bool is_int_max = !(s < kClamp_max); - - return (is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s))); + return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s)); } CUTLASS_HOST_DEVICE @@ -295,15 +522,15 @@ struct NumericConverterClamp { CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - NumericConverter convert_op; + NumericConverter convert_op; - float kClamp_max = float((1 << (sizeof_bits::value - 1)) - 1); - float kClamp_min = -kClamp_max - 1; + double kClamp_max = double((1U << (sizeof_bits::value - 1)) - 1); + double kClamp_min = -kClamp_max - 1; - float source = s; + double source = s; - source = fmaxf(source, kClamp_min); - source = fminf(source, kClamp_max); + source = fmax(source, kClamp_min); + source = fmin(source, kClamp_max); return convert_op(source); } @@ -353,6 +580,24 @@ struct NumericArrayConverter { } }; +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return s; + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Array <= Array, round to nearest @@ -498,10 +743,86 @@ struct NumericArrayConverter { } }; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array, round to nearest +template <> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned d; + + asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); + + return reinterpret_cast(d); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + NumericConverter convert_element_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + if (N % 2) { + result[N - 1] = convert_element_(source[N - 1]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +#endif // if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + ///////////////////////////////////////////////////////////////////////////////////////////////// // Conditional guards to enable partial specialization for packed integers -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Partial specialization for Array <= Array template < @@ -625,12 +946,13 @@ struct NumericArrayConverter { return convert(s); } }; - -#endif // Conditional guards to enable partial specialization for packed integers +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Partial specialization for Array <= Array template < @@ -707,4 +1029,127 @@ struct NumericArrayConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// FastNumericArrayConverter only works when the source is within center range. +/// Conversion operator for Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &s) { + result_type result; + NumericArrayConverter convert_; + + return convert_(s); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int tmp = source[i] + 1262485504 /*0x4B400000*/; + result[i] = reinterpret_cast(tmp) - 12582912.0f; + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + float tmp = source[i] + 12582912.0f; + result[i] = reinterpret_cast(tmp); + } + + result[0] = __byte_perm(result[0], result[1], 0x40); + result[2] = __byte_perm(result[2], result[3], 0x40); + result[0] = __byte_perm(result[0], result[2], 0x5410); + + return reinterpret_cast(result[0]); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + static_assert(!(N % 4), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + FastNumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = + reinterpret_cast *>(&result); + Array const *source_ptr = + reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines preferred rounding mode for a pair of types +template +struct PreferredRoundingMode { + static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; +}; + +/// Defines preferred rounding mode for a pair of types +template <> +struct PreferredRoundingMode { + static FloatRoundStyle const kRound = FloatRoundStyle::round_half_ulp_truncate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index c44659f001..9479ccb08b 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" @@ -65,6 +69,10 @@ struct sizeof_bits { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "cutlass/integer_subbyte.h" + #include "cutlass/half.h" +#include "cutlass/bfloat16.h" +#include "cutlass/tfloat32.h" ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 3117cc7cc5..826b3977fc 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -95,7 +95,11 @@ // Dependencies //----------------------------------------------------------------------------- +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #if !defined(__CUDACC_RTC__) //----------------------------------------------------------------------------- diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index 4f6b123c24..9293696225 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,10 +28,13 @@ */ #pragma once -#if !defined(__CUDACC_RTC__) +#if defined(__CUDACC_RTC__) +#include +#include +#else #include -#endif #include +#endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/real.h b/include/cutlass/real.h index de1fee295c..45ab1864eb 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,6 +30,11 @@ namespace cutlass { template struct RealType { using Type = T; + +CUTLASS_HOST_DEVICE + static T from_real(double x) { + return static_cast(x); + } }; template @@ -38,4 +43,5 @@ static T from_real(double r) { return T(r); } + } // namespace cutlass diff --git a/include/cutlass/reduction/batched_reduction.h b/include/cutlass/reduction/batched_reduction.h index 83324ec012..16132a0210 100644 --- a/include/cutlass/reduction/batched_reduction.h +++ b/include/cutlass/reduction/batched_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/reduction/batched_reduction_traits.h b/include/cutlass/reduction/batched_reduction_traits.h index c44238e1e8..46157dc703 100644 --- a/include/cutlass/reduction/batched_reduction_traits.h +++ b/include/cutlass/reduction/batched_reduction_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/reduction/device/reduce_split_k.h b/include/cutlass/reduction/device/reduce_split_k.h new file mode 100644 index 0000000000..e3626f88c0 --- /dev/null +++ b/include/cutlass/reduction/device/reduce_split_k.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/device_kernel.h" +#include "cutlass/reduction/kernel/reduce_split_k.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ReductionKernel_ +> +class ReduceSplitK { +public: + using ReductionKernel = ReductionKernel_; + + using Shape = typename ReductionKernel::Shape; + using ReductionOp = typename ReductionKernel::ReductionOp; + using OutputOp = typename ReductionKernel::OutputOp; + + using ElementWorkspace = typename ReductionKernel::ElementWorkspace; + using ElementAccumulator = typename ReductionKernel::ElementAccumulator; + using ElementOutput = typename ReductionKernel::ElementOutput; + + using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; + using OutputTensorRef = typename ReductionKernel::OutputTensorRef; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + MatrixCoord problem_size; + int partitions; + size_t partition_stride; + WorkspaceTensorRef workspace; + OutputTensorRef destination; + OutputTensorRef source; + typename OutputOp::Params output; + typename ReductionOp::Params reduction; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() : + problem_size(0, 0), + partitions(1), + partition_stride(0) { } + + CUTLASS_HOST_DEVICE + Arguments( + MatrixCoord const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + MatrixCoord problem_size_, + int partitions_, + size_t partition_stride_, + WorkspaceTensorRef workspace_, + OutputTensorRef destination_, + OutputTensorRef source_, + typename OutputOp::Params output_ = typename OutputOp::Params(), + typename ReductionOp::Params reduction_ = typename ReductionOp::Params() + ): + problem_size(problem_size_), + partitions(partitions_), + partition_stride(partition_stride_), + workspace(workspace_), + destination(destination_), + source(source_), + output(output_), + reduction(reduction_) + { + + } + + }; + +private: + /// Kernel parameters object + typename ReductionKernel::Params params_; + +public: + /// Constructs Reduction SplitK + ReduceSplitK() { } + + /// Determines whether the ReduceSplitK can execute the given problem. + static Status can_implement(Arguments const &args) { + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + // needs no additional workspace + return 0; + } + + /// Initializes Reduction state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename ReductionKernel::Params( + args.problem_size, + args.partitions, + args.partition_stride, + args.workspace, + args.destination, + args.source, + args.output, + args.reduction + ); + + return Status::kSuccess; + + } + + /// Initializes Reduction kernel state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.workspace.reset(args.workspace.non_const_ref().data()); + params_.destination.reset(args.destination.non_const_ref().data()); + params_.source.reset(args.source.non_const_ref().data()); + params_.output = args.output; + params_.reduction = args.reduction; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + // + // Launch reduction kernel + // + dim3 block = ReductionKernel::block_shape(); + dim3 grid = ReductionKernel::grid_shape(params_.problem_size); + + Kernel<<< grid, block, 0, stream >>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/include/cutlass/reduction/kernel/reduce_split_k.h b/include/cutlass/reduction/kernel/reduce_split_k.h index 1869102f10..586c90d86a 100644 --- a/include/cutlass/reduction/kernel/reduce_split_k.h +++ b/include/cutlass/reduction/kernel/reduce_split_k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -129,8 +129,8 @@ class ReduceSplitK { cutlass::MatrixCoord problem_size) { return dim3( - (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn, - (problem_size.row() + Shape::kRow -1) / Shape::kRow); + (problem_size.row() + Shape::kRow - 1) / Shape::kRow, + (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn); } /// Determines the threadblock shape @@ -145,8 +145,8 @@ class ReduceSplitK { // Determine CTA position MatrixCoord thread_offset( - int(blockIdx.y) * Shape::kRow + threadIdx.y, - int(blockIdx.x) * Shape::kColumn + threadIdx.x * kElementsPerAccess + int(blockIdx.x) * Shape::kRow + threadIdx.y, + int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess ); // One guard conditional diff --git a/include/cutlass/reduction/thread/reduce.h b/include/cutlass/reduction/thread/reduce.h index ae03c82140..698b174f95 100644 --- a/include/cutlass/reduction/thread/reduce.h +++ b/include/cutlass/reduction/thread/reduce.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/reduction/thread/reduction_operators.h b/include/cutlass/reduction/thread/reduction_operators.h index af029124b8..6f9aeb6f32 100644 --- a/include/cutlass/reduction/thread/reduction_operators.h +++ b/include/cutlass/reduction/thread/reduction_operators.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -85,7 +85,13 @@ struct ReduceAdd { plus op; - return op(accumulator, element); + NumericArrayConverter< + ElementAccumulator, + Element, + kCount, + PreferredRoundingMode::kRound> converter; + + return op(accumulator, converter(element)); } }; diff --git a/include/cutlass/reduction/threadblock_swizzle.h b/include/cutlass/reduction/threadblock_swizzle.h index 6e42cadab4..2419cdf6f5 100644 --- a/include/cutlass/reduction/threadblock_swizzle.h +++ b/include/cutlass/reduction/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index cb6d68ca5e..5714fbd2fd 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -145,6 +145,28 @@ bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); } +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + bfloat16_t a, + bfloat16_t b, + bfloat16_t epsilon, + bfloat16_t nonzero_floor) { + + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + tfloat32_t a, + tfloat32_t b, + tfloat32_t epsilon, + tfloat32_t nonzero_floor) { + + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + template <> CUTLASS_HOST_DEVICE bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { diff --git a/include/cutlass/semaphore.h b/include/cutlass/semaphore.h index c032c32785..dc5523dca1 100644 --- a/include/cutlass/semaphore.h +++ b/include/cutlass/semaphore.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -66,8 +66,13 @@ class Semaphore { /// Permit fetching the synchronization mechanism early CUTLASS_DEVICE void fetch() { - - asm volatile ("ld.global.cg.s32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + if (wait_thread) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #else + asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #endif + } } /// Gets the internal state @@ -80,14 +85,8 @@ class Semaphore { CUTLASS_DEVICE void wait(int status = 0) { - if (wait_thread) { - while (state != status) { - - fetch(); - - __syncwarp(0x01); - - }; + while( __syncthreads_and(state != status) ) { + fetch(); } __syncthreads(); @@ -99,8 +98,11 @@ class Semaphore { __syncthreads(); if (wait_thread) { - - asm volatile ("st.global.cg.s32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #else + asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #endif } } }; diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index 9ce529015a..6f7aab2c6d 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/tensor_coord.h b/include/cutlass/tensor_coord.h index 043f7a569d..d7a6d0df6a 100644 --- a/include/cutlass/tensor_coord.h +++ b/include/cutlass/tensor_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/tensor_ref.h b/include/cutlass/tensor_ref.h index a28dba57df..a805107c3d 100644 --- a/include/cutlass/tensor_ref.h +++ b/include/cutlass/tensor_ref.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -379,6 +379,7 @@ template < typename Element, typename Layout > +CUTLASS_HOST_DEVICE bool TensorRef_aligned(TensorRef const &ref, int alignment) { int const kStrideRank = Layout::kStrideRank; diff --git a/include/cutlass/tensor_ref_planar_complex.h b/include/cutlass/tensor_ref_planar_complex.h new file mode 100644 index 0000000000..54611911ca --- /dev/null +++ b/include/cutlass/tensor_ref_planar_complex.h @@ -0,0 +1,368 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides, bounds, and a pointer to tensor data. +*/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct PlanarComplexReference { + + // + // Type definitions + // + + using Element = Element_; + using ComplexElement = complex; + + // + // Data members + // + + Element *real; + Element *imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + PlanarComplexReference( + Element *real_ = nullptr, + Element *imag_ = nullptr + ): + real(real_), imag(imag_) { } + + /// Loads the complex element + CUTLASS_HOST_DEVICE + operator complex() const { + return complex{*real, *imag}; + } + + /// Stores a complex element to the location pointed to by the reference + CUTLASS_HOST_DEVICE + PlanarComplexReference &operator=(complex const &rhs) { + *real = rhs.real(); + *imag = rhs.imag(); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. A TensorRef combines a pointer and a Layout concept + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRefPlanarComplex { + public: + /// Data type of individual access + using Element = Element_; + + /// Complex element type + using ComplexElement = complex; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + static_assert(sizeof_bits::value >= 8, + "Planar complex not suitable for subbyte elements at this time"); + + /// Reference type to an element + using Reference = PlanarComplexReference; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + /// Offset to imaginary part + LongIndex imaginary_stride_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + Element *ptr = nullptr, ///< pointer to start of tensor + Layout const &layout = Layout(), ///< layout object containing stride and mapping function + LongIndex imaginary_stride = 0 + ): + ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + NonConstTensorRef const &ref ///< TensorRef to non-const data + ): + ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } + + /// Returns a reference to constant-valued tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_, imaginary_stride_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef( + const_cast::type *>(ptr_), + layout_, + imaginary_stride_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { + ptr_ = ptr; + imaginary_stride_ = imaginary_stride; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { + ptr_ = ptr; + layout_ = layout; + imaginary_stride_ = imaginary_stride; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * imaginary_data() const { return ptr_ + imaginary_stride_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Gets the stride to an imaginary element + LongIndex imaginary_stride() const { + return imaginary_stride_; + } + + /// Gets the stride to an imaginary element + LongIndex &imaginary_stride() { + return imaginary_stride_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { + ptr_ += offset_; + return *this; + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator+(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator-(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_real() const { + return cutlass::TensorRef(data(), layout()); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_imag() const { + return cutlass::TensorRef(imaginary_data(), layout()); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRefPlanarComplex make_TensorRefPlanarComplex( + Element *ptr, + Layout const &layout, + int64_t imaginary_stride) { + + return TensorRefPlanarComplex(ptr, layout, imaginary_stride); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/tensor_view.h b/include/cutlass/tensor_view.h index 3efb16a5a2..a9cf569de4 100644 --- a/include/cutlass/tensor_view.h +++ b/include/cutlass/tensor_view.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -151,14 +151,20 @@ class TensorView : public TensorRef { /// Updates the pointer and layout object CUTLASS_HOST_DEVICE - void reset(Element* ptr, Layout const &layout, TensorCoord size) { + void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { Base::reset(ptr, layout); - this->resize(extent_); + this->resize(extent); + } + + /// Updates the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr) { + Base::reset(ptr); } /// Changes the size of the view without affecting pointer or layout CUTLASS_HOST_DEVICE - void resize(TensorCoord extent) { + void resize(TensorCoord const &extent) { this->extent_ = extent; } diff --git a/include/cutlass/tensor_view_planar_complex.h b/include/cutlass/tensor_view_planar_complex.h new file mode 100644 index 0000000000..bdd29829da --- /dev/null +++ b/include/cutlass/tensor_view_planar_complex.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref_planar_complex.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorViewPlanarComplex : public TensorRefPlanarComplex { + public: + + /// Base tensor reference + using Base = cutlass::TensorRefPlanarComplex; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + LongIndex imaginary_stride, ///< stride between real and imaginary part + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout, imaginary_stride), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { + Base::reset(ptr, layout, imaginary_stride); + this->resize(extent_); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). + CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. + CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + Base ref() const { + return Base(this->data(), this->layout(), this->imaginary_stride()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + return TensorViewPlanarComplex(ref(), extent.clamp(extent_ - location)).add_coord_offset(location); + } + + /// Returns the number of scalar elements needed to store tensor. + CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorViewPlanarComplex result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_real() const { + return cutlass::TensorView(this->data(), this->layout(), extent_); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_imag() const { + return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( + Element *ptr, + Layout const &layout, + typename Layout::LongIndex imaginary_stride, + typename Layout::TensorCoord const &extent) { + + return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h new file mode 100644 index 0000000000..64dc391497 --- /dev/null +++ b/include/cutlass/tfloat32.h @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing Tensor Float 32 data type. +*/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#include +#include +#endif + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor Float 32 data type +struct alignas(4) tfloat32_t { + + // + // Data members + // + + /// Storage type + uint32_t storage; + + // + // Methods + // + + /// Constructs from an unsigned int + CUTLASS_HOST_DEVICE + static tfloat32_t bitcast(uint32_t x) { + tfloat32_t h; + h.storage = x; + return h; + } + + /// Emulated rounding is fast in device code + CUTLASS_HOST_DEVICE + static tfloat32_t round_half_ulp_truncate(float const &s) { + uint32_t x = reinterpret_cast(s); + + #if defined(__CUDA_ARCH__) + if (::isfinite(s)) { + x += 0x1000u; + } + #else + if (std::isfinite(s)) { + x += 0x1000u; + } + #endif + + return tfloat32_t::bitcast(x); + } + + /// Default constructor + CUTLASS_HOST_DEVICE + tfloat32_t() { } + + /// Floating-point conversion - round toward nearest even + CUTLASS_HOST_DEVICE + explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } + + /// Floating-point conversion - round toward nearest even + CUTLASS_HOST_DEVICE + explicit tfloat32_t(double x): tfloat32_t(float(x)) { + + } + + /// Integer conversion - round toward zero + CUTLASS_HOST_DEVICE + explicit tfloat32_t(int x) { + float flt = static_cast(x); + storage = reinterpret_cast(flt); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + + // Conversions to IEEE single-precision requires clearing dont-care bits + // of the mantissa. + unsigned bits = (storage & ~0x1fffu); + + return reinterpret_cast(bits); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + operator bool() const { + return (float(*this) != 0.0f); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint32_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x80000000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 23) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7fffff); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::tfloat32_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { + return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t nan_tf32(const char*) { + // NVIDIA canonical NaN + return cutlass::tfloat32_t::bitcast(0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isnormal(cutlass::tfloat32_t const& h) { + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::tfloat32_t const& h) { + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { +#if defined(__CUDACC_RTC__) + return cutlass::tfloat32_t(sqrtf(float(h))); +#else + return cutlass::tfloat32_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { + + uint32_t a_mag = (reinterpret_cast(a) & 0x7fffffff); + uint32_t b_sign = (reinterpret_cast(b) & 0x80000000); + uint32_t result = (a_mag | b_sign); + + return reinterpret_cast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace std { + +#if !defined(__CUDACC_RTC__) +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 19; + + /// Least positive value + static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } + + /// Minimum finite value + static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } + + /// Maximum finite value + static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } + + /// Returns smallest finite value + static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace std + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) != float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) + float(rhs)); +} + + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs) { + float x = -reinterpret_cast(lhs); + return reinterpret_cast(x); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator++(tfloat32_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator--(tfloat32_t & lhs) { + float tmp(lhs); + --tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator++(tfloat32_t & lhs, int) { + tfloat32_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = tfloat32_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator--(tfloat32_t & lhs, int) { + tfloat32_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = tfloat32_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(long double x) { + return cutlass::tfloat32_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { + return cutlass::tfloat32_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/thread/matrix.h b/include/cutlass/thread/matrix.h index 1e1f3eebd3..a54b347150 100644 --- a/include/cutlass/thread/matrix.h +++ b/include/cutlass/thread/matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 71edb936f1..812dbd772b 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/thread/transpose.h b/include/cutlass/transform/thread/transpose.h index 552295d847..268e648135 100644 --- a/include/cutlass/transform/thread/transpose.h +++ b/include/cutlass/transform/thread/transpose.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/thread/unaryOp.h b/include/cutlass/transform/thread/unaryOp.h new file mode 100644 index 0000000000..de4f79b972 --- /dev/null +++ b/include/cutlass/transform/thread/unaryOp.h @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" + +namespace cutlass { +namespace transform { +namespace thread { + +namespace UnaryTransform { + struct Identity; ///< None (i.e., identity) + struct Conjugate; ///< Complex conjugate +} + +/// Element-wise unary operator that transforms one element of a fragment at a time +template< + typename FragmentIn, ///< Input Fragment + typename FragmentOut,///< Output Fragment + typename Transform> ///< Unary transform operator +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentOut execute(FragmentIn &in) + { + static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); + static_assert(std::is_same::value || + std::is_same::value, + "Unary Operator not supported."); + + FragmentOut out; + if( std::is_same::value ) + { + CUTLASS_PRAGMA_UNROLL + for(int i=0; i < FragmentIn::kElements; ++i){ + out[i] = static_cast(in[i]); + } + } + else if( std::is_same::value ) + { + for(int i=0; i < FragmentIn::kElements; ++i){ + out[i] = conj(static_cast(in[i])); + } + } + return out; + } +}; + +template +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentIn execute(FragmentIn &in) + { + static_assert(std::is_same::value || + std::is_same::value, + "Unary Operator not supported."); + + if( std::is_same::value ) + { + return in; + } + else if( std::is_same::value ) + { + for(int i=0; i < FragmentIn::kElements; ++i){ + in[i] = conj(in[i]); + } + } + return in; + } +}; +} +} +} + + diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index ff754cfaf0..c77a09ffbd 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without *modification, are permitted provided that the following conditions are met: @@ -128,13 +128,13 @@ class PredicatedTileAccessIterator::value / 8; if (kAdvanceRank) { // advance along strided dimension inc_advance_ = - Shape::kStrided * stride_ * sizeof_bits::value / 8; + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; } else { // advance along contiguous dimension inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; } - inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * stride_ * + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; }; }; @@ -216,6 +216,7 @@ class PredicatedTileAccessIterator::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + AccessType *frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL @@ -310,11 +315,15 @@ class PredicatedTileIterator(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); - if (address_iterator_.valid()) { - frag_ptr[idx] = *ptr; - } ++address_iterator_; } } @@ -323,11 +332,17 @@ class PredicatedTileIterator::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -340,8 +355,11 @@ class PredicatedTileIterator(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + if (address_iterator_.valid()) { - *(address_iterator_.get() + pointer_offset) = frag_ptr[idx]; + *access_ptr = frag_ptr[idx]; } ++address_iterator_; } @@ -351,7 +369,7 @@ class PredicatedTileIterator +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : byte_offset_(0) { + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast( + ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + AccessType *access_ptr = pointer_; + + int access_offset = + (iteration_strided_ * ThreadMap::Delta::kStrided * Layout::kInterleavedK + + iteration_contiguous_ * ThreadMap::Delta::kContiguous) / ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kCount); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// + +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + cutlass::MatrixShape, + Element, + layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, + (kAdvanceRank == 1 ? 0 : 1), + ThreadMap + >; + + private: + + /// Element type per access + using AccessType = Array; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace transform } // namespace cutlass diff --git a/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h new file mode 100644 index 0000000000..5a0c74fdc6 --- /dev/null +++ b/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h @@ -0,0 +1,1522 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + Index stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers - two pointers are needed if making more than 4 iterations along + ///< strided dimension + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + Index stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_[Detail::kPointerCount]; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data()); + + byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); + + if (Detail::kPointerCount == 2) { + byte_offset_[1] = byte_offset_[0] ^ 8; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + // Map the logical contiguous and strided access to the internal swizzled structure. + int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16; + + char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); + + int byte_offset; + + // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) + // in the strided dimension + if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { + byte_offset = byte_offset_[1]; + } + else { + byte_offset = byte_offset_[0]; + } + + return reinterpret_cast(access_byte_ptr + byte_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + Index stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + + static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + Index stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); + int offset_s = (iteration_strided_ / 2) * 8; + + int access_offset = offset_c * stride_ + offset_s; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous * stride_ + + coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator.h b/include/cutlass/transform/threadblock/regular_tile_iterator.h index 8445b83664..d7928ac00a 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index 93849c6586..c3f0b5249b 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h index 4ea4729386..85d702fec6 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h index 23c6d946c0..c7f0690779 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -149,6 +149,12 @@ class RegularTileIterator< /// Loads a fragment from memory CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType *frag_ptr = reinterpret_cast(&frag); @@ -157,7 +163,11 @@ class RegularTileIterator< CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { int access_idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + frag_ptr[access_idx] = *access_ptr; ++address_iterator_; } } @@ -172,6 +182,11 @@ class RegularTileIterator< /// Store a fragment to memory CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -180,7 +195,11 @@ class RegularTileIterator< CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { int access_idx = c + s * ThreadMap::Iterations::kContiguous; - *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; ++address_iterator_; } } @@ -189,7 +208,7 @@ class RegularTileIterator< /// Store a fragment to memory CUTLASS_DEVICE void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); + store_with_byte_offset(frag, 0); } }; @@ -567,6 +586,11 @@ class RegularTileIterator::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -575,7 +599,11 @@ class RegularTileIterator(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; ++address_iterator_; } } @@ -803,6 +831,271 @@ class RegularTileIterator +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_pointer_offset(Shape::kCount); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// + +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + cutlass::MatrixShape, + Element, + layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, + (kAdvanceRank == 1 ? 0 : 1), + ThreadMap + >; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace transform } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h index 2a57936e84..82c8842ec0 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -1146,7 +1146,7 @@ class RegularTileIterator< void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { AccessType *frag_ptr = reinterpret_cast(&frag); - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { @@ -1185,13 +1185,14 @@ class RegularTileIterator< void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { AccessType const *frag_ptr = reinterpret_cast(&frag); - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; - access_ptr += 16 * (s / 2); + access_ptr += 16 * (s / 2) + vec_pointer_offset; CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { @@ -1199,8 +1200,7 @@ class RegularTileIterator< for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { int access_offset = - c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + - vec_pointer_offset + i * line_size; + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * Detail::kIterarionsPerAccess + i; diff --git a/include/cutlass/util/debug.h b/include/cutlass/util/debug.h deleted file mode 100644 index 9941b41a17..0000000000 --- a/include/cutlass/util/debug.h +++ /dev/null @@ -1,122 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -/** - * \file - * \brief Debugging and logging functionality - */ - -#include - -namespace cutlass { - -/****************************************************************************** - * Debug and logging macros - ******************************************************************************/ - -/** - * Formats and prints the given message to stdout - */ -#if !defined(CUDA_LOG) -#if !defined(__CUDA_ARCH__) -#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) -#else -#define CUDA_LOG(format, ...) \ - printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ - blockIdx.x, \ - blockIdx.y, \ - blockIdx.z, \ - threadIdx.x, \ - threadIdx.y, \ - threadIdx.z, \ - __VA_ARGS__); -#endif -#endif - -/** - * Formats and prints the given message to stdout only if DEBUG is defined - */ -#if !defined(CUDA_LOG_DEBUG) -#ifdef DEBUG -#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) -#else -#define CUDA_LOG_DEBUG(format, ...) -#endif -#endif - -/** - * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) - * along with the supplied source context. - * - * \return The CUDA error. - */ -__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, - const char* filename, - int line) { - (void)filename; - (void)line; - if (error) { -#if !defined(__CUDA_ARCH__) - fprintf( - stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error)); - fflush(stderr); -#else - printf("CUDA error %d [%s, %d]\n", error, filename, line); -#endif - } - return error; -} - -/** - * \brief Perror macro - */ -#ifndef CUDA_PERROR -#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__) -#endif - -/** - * \brief Perror macro with exit - */ -#ifndef CUDA_PERROR_EXIT -#define CUDA_PERROR_EXIT(e) \ - if (cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)) { \ - exit(1); \ - } -#endif - -/** - * \brief Perror macro only if DEBUG is defined - */ -#ifndef CUDA_PERROR_DEBUG -#ifdef DEBUG -#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) -#else -#define CUDA_PERROR_DEBUG(e) (e) -#endif -#endif - -} // namespace cutlass diff --git a/include/cutlass/wmma_array.h b/include/cutlass/wmma_array.h index 7758309ec5..e80961394d 100644 --- a/include/cutlass/wmma_array.h +++ b/include/cutlass/wmma_array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/media/docs/code_organization.md b/media/docs/code_organization.md index ffab354ec4..9a00d3056f 100644 --- a/media/docs/code_organization.md +++ b/media/docs/code_organization.md @@ -88,6 +88,7 @@ tools/ cutlass/ library/ # header files for CUTLASS Deliverables Library (in cutlass::library:: namespace) + handle.h # implements a host-side API for launching kernels, similar to cuBLAS library.h # defines enums and structs to describe the tiled structure of operator instances manifest.h # collection of all instances @@ -175,6 +176,14 @@ examples/ 07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores 08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores + + 10_planar_complex/ # example demonstrating planar complex GEMM kernels + + 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes + + 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu + + 13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel ``` ## Media @@ -211,7 +220,7 @@ of tests run may vary over time as more are added. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/doxygen_mainpage.md b/media/docs/doxygen_mainpage.md index 6b8e09dd40..15656d25e5 100644 --- a/media/docs/doxygen_mainpage.md +++ b/media/docs/doxygen_mainpage.md @@ -120,7 +120,7 @@ cudaError_t cutlass_sgemm_nn( # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index d601ff5a63..7a1a6ae7f4 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -216,6 +216,7 @@ participating warps - since each warp now owns a partial sum (since they compute The following additional resources describe design and implementation details of GEMMs targeting NVIDIA GPUs. +- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100.](https://www.nvidia.com/en-us/gtc) (SR 21745) - [CUTLASS: Fast Linear Algebra in CUDA C++](https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/) - [CUTLASS: SOFTWARE PRIMITIVES FOR DENSE LINEAR ALGEBRA AT ALL LEVELS AND SCALES WITHIN CUDA](https://on-demand-gtc.gputechconf.com/gtcnew/sessionview.php?sessionName=s8854-cutlass%3a+software+primitives+for+dense+linear+algebra+at+all+levels+and+scales+within+cuda) - [Programming Tensor Cores: NATIVE VOLTA TENSOR CORES WITH CUTLASS](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9593-cutensor-high-performance-tensor-operations-in-cuda-v2.pdf) @@ -224,7 +225,7 @@ targeting NVIDIA GPUs. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/functionality.md b/media/docs/functionality.md index 26171ca231..465fae7d1f 100644 --- a/media/docs/functionality.md +++ b/media/docs/functionality.md @@ -27,6 +27,16 @@ Hyperlinks to relevant unit tests demonstrate how specific template instances ma | **TensorOp** | 75 | 10.2+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu) | | **TensorOp** | 75 | 10.2+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | | **TensorOp** | 75 | 10.2+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `bf16 * bf16 + f32 => {bf16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_bf16n_bf16t_bf16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `tf32 * tf32 + f32 => f32`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | +| **TensorOp** | 80 | 11.0+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | ## Warp-level Matrix Multiply with Tensor Cores @@ -36,9 +46,13 @@ The following table summarizes supported warp level shapes for each TensorOp ins |-----------------|-----------------------|--------------------------------------------| | **TensorOp** | 8-by-8-by-4 | 32x32x4, 32x64x4, 64x32x4, 64x64x4 | | **TensorOp** | 16-by-8-by-8 | 32x32x8, 32x64x8, 64x32x8, 64x64x8 | +| **TensorOp** | 16-by-8-by-16 | 32x32x16, 32x64x16, 64x32x16, 64x64x16 | | **TensorOp** | 8-by-8-by-16 | 32x32x16, 32x64x16, 64x32x16, 64x64x16 | | **TensorOp** | 8-by-8-by-32 | 32x32x32, 32x64x32, 64x32x32, 64x64x32 | +| **TensorOp** | 16-by-8-by-32 | 32x32x32, 32x64x32, 64x32x32, 64x64x32 | +| **TensorOp** | 16-by-8-by-64 | 32x32x64, 32x64x64, 64x32x64, 64x64x64 | | **TensorOp** | 8-by-8-by-128 | 32x32x128, 32x64x128, 64x32x128, 64x64x128 | +| **TensorOp** | 16-by-8-by-256 | 32x32x256, 32x64x256, 64x32x256, 64x64x256 | TensorOp instructions depend on a permuted shared memory layout that can be efficiently loaded from. The following tables summarize the destination shared memory layout that @@ -67,6 +81,38 @@ from global memory with layout specified in the column "GMEM Layout." | **C** | `half_t` | `RowMajor` | `RowMajor` | | **C** | `float` | `RowMajor` | `RowMajor` | +**TensorOp 16-by-8-by-8.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `tfloat32_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<32>` | +| **A** | `tfloat32_t` | `RowMajor` | `RowMajorTensorOpCrosswise<32>` | +| **B** | `tfloat32_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<32>` | +| **B** | `tfloat32_t` | `RowMajor` | `RowMajorTensorOpCongruous<32>` | +| **C** | `float` | `RowMajor` | `RowMajor` | + + +**TensorOp 16-by-8-by-16.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `half_t`, `bfloat16_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<16>` | +| **A** | `half_t`, `bfloat16_t` | `RowMajor` | `RowMajorTensorOpCrosswise<16>` | +| **B** | `half_t`, `bfloat16_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<16>` | +| **B** | `half_t`, `bfloat16_t` | `RowMajor` | `RowMajorTensorOpCongruous<16>` | +| **C** | `half_t` | `RowMajor` | `RowMajor` | +| **C** | `float` | `RowMajor` | `RowMajor` | + +**TensorOp 8-by-8-by-4.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `double` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<64>` | +| **A** | `double` | `RowMajor` | `RowMajorTensorOpCrosswise<64>` | +| **B** | `double` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<64>` | +| **B** | `double` | `RowMajor` | `RowMajorTensorOpCongruous<64>` | +| **C** | `double` | `RowMajor` | `RowMajor` | + **TensorOp 8-by-8-by-16.** |**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | @@ -75,6 +121,14 @@ from global memory with layout specified in the column "GMEM Layout." | **B** | `int8_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<8>` | | **C** | `int32_t` | `RowMajor` | `RowMajor` | +**TensorOp 16-by-8-by-32.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `int8_t` | `RowMajor` | `RowMajorTensorOpCrosswise<8>` | +| **B** | `int8_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<8>` | +| **C** | `int32_t` | `RowMajor` | `RowMajor` | + **TensorOp 8-by-8-by-32.** |**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | @@ -83,6 +137,14 @@ from global memory with layout specified in the column "GMEM Layout." | **B** | `int4b_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<4>` | | **C** | `int32_t` | `RowMajor` | `RowMajor` | +**TensorOp 16-by-8-by-64.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `int4b_t` | `RowMajor` | `RowMajorTensorOpCrosswise<4>` | +| **B** | `int4b_t` | `ColumnMajor` | `ColumnMajorTensorOpCongruous<4>` | +| **C** | `int32_t` | `RowMajor` | `RowMajor` | + **TensorOp 8-by-8-by-128.** |**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | @@ -118,7 +180,7 @@ CUDA exposes warp-level matrix operations in the CUDA C++ WMMA API. The CUDA C++ # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/fundamental_types.md b/media/docs/fundamental_types.md index 9837402271..7556cd45dc 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/fundamental_types.md @@ -16,6 +16,8 @@ Most types in CUTLASS are usable in both host code and device code. Moreover, th CUTLASS defines classes for the following numeric data types. * `half_t`: IEEE half-precision floating point (exponent: 5b, mantissa: 10b; literal suffix `_hf`) +* `bfloat16_t`: BFloat16 data type (exponent: 8b, mantissa: 7b; literal suffix `_bf16`) +* `tfloat32_t`: Tensor Float 32 data type (exponent: 8b, mantissa: 10b; literal suffix `_tf32`) * `int4_t`, `uint4_t`: 4b signed and unsigned integer (literal suffx `_s4`, `_u4`) * `bin1_t`: 1b binary numeric type (literal suffix `_b1`) * `complex`: defines complex-valued data type based on the supplied real-valued numeric type @@ -182,6 +184,39 @@ AlignedArray *ptr = reinterpret_cast *>(smem_ AlignedArray x = ptr[threadIdx.x]; // 128b shared memory load ``` +### Numeric Conversion + +CUTLASS defines procedures for performing numeric conversion between data types in `cutlass/numeric_conversion.h`. +Where possible, these target hardware acceleration on the target architecture and support multiple rounding modes. + +```c++ +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +NumericConverter convert_f32_to_f16; +NumericConverter convert_f32_to_tf32; + +half_t x = convert_f32_to_f16(3.14159f); +tfloat32_t y = convert_f32_to_tf32(3.14159f); +``` + +Recent GPU architectures such as NVIDIA Turing and Ampere combine numeric conversion with efficient packing +into bit vectors. Consequently, CUTLASS defines conversion on both scalars and `Array<>` objects to implement +the optimal code sequence on all architectures. + +```c++ +// +// Example: convert and pack 32b signed integers to a vector of packed signed 8-bit integers. +// +int const kN = 16; +Array destination; +Array source; + +NumericConverter convert; + +destination = convert(source); +``` + ### Coord ```c++ @@ -311,7 +346,7 @@ support on current and future NVIDIA GPUs. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/gemm_api.md b/media/docs/gemm_api.md index 0d58cd36fe..759b1cd417 100644 --- a/media/docs/gemm_api.md +++ b/media/docs/gemm_api.md @@ -514,7 +514,7 @@ to inline PTX. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/layout.md b/media/docs/layout.md index fc36a27619..bacec0e442 100644 --- a/media/docs/layout.md +++ b/media/docs/layout.md @@ -267,7 +267,7 @@ Permuted Shared Memory Layouts: # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 34051651d1..7d2356c558 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -15,10 +15,12 @@ $ make cutlass_profiler -j To limit compilation time, only one tile size (128x128) is instantiated for each data type, math instruction, and layout. To instantiate all sizes, set the following environment variable when running CMake from an empty `build/` directory. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON ... $ make cutlass_profiler -j ``` +Enabling the unity build places multiple kernel instances in one compilation unit, thereby reducing size of the compiled +binary and avoiding linker limitations on some platforms. The CUTLASS Profiler sources are stored in ```bash @@ -102,7 +104,7 @@ Report: --verbose= If true (default), prints human-readable text to stdout. About: - --version CUTLASS 2.0.0 built on Nov 19 2019 at 13:01:00 + --version CUTLASS 2.2.0 built on Jun 8 2020 at 07:59:33 Operations: --operation= Specifies a particular operation to run or print the usage statement. @@ -191,29 +193,34 @@ Test your changes to gemm kernels with a quick functional test and save results Example command line for profiling SGEMM kernels is as follows: ```bash -$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096 +$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096 + + ============================= Problem ID: 1 - Provider: CUTLASS - Operation: cutlass_simt_sgemm_128x128_nn + Provider: CUTLASS + OperationKind: gemm + Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1 + + Status: Success + Verification: ON + Disposition: Passed - Disposition: Passed - Status: Success + cuBLAS: Passed - Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \ - --split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \ - --stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \ - --max_cc=1024 + Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \ + --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \ + --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024 - Bytes: 52428800 bytes - FLOPs: 146064539648 flops + Bytes: 180355072 bytes + FLOPs: 115992428544 flops - Runtime: 10.5424 ms - Memory: 4.63158 GiB/s + Runtime: 6.73655 ms + Memory: 24.934 GiB/s - Math: 13854.9 GFLOP/s + Math: 17218.4 GFLOP/s ``` Note, the arguments which appear in the output may be used as command line parameters for subsequent invocations. @@ -224,31 +231,34 @@ Note, the arguments which appear in the output may be used as command line param To execute kernels targeting Tensor Core operations, supply the flag `--op_class=tensorop` in the command line. ```bash -$ ./tools/profiler/cutlass_profiler --op_class=tensorop +$ ./tools/profiler/cutlass_profiler --op_class=tensorop --m=3456 --n=4096 --k=8192 + + ============================= Problem ID: 1 - Provider: CUTLASS - Operation: cutlass_turing_h1688gemm_128x128_nt - - Disposition: Passed - Status: Success + Provider: CUTLASS + OperationKind: gemm + Operation: cutlass_tensorop_s16816gemm_f16_256x128_32x3_nn_align8 - Arguments: --m=4352 --n=4096 --k=4096 --A=f16:column --B=f16:row --C=f16:column --alpha=1 --beta=0 \ - --op_class=tensorop --accum=f16 --cta_m=128 --cta_n=128 --cta_k=32 --stages=2 \ - --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 \ - --min_cc=75 --max_cc=1024 + Status: Success + Verification: ON + Disposition: Passed + cuBLAS: Passed - Bytes: 52428800 bytes - FLOPs: 146064539648 flops + Arguments: --m=3456 --n=4096 --k=8192 --A=f16:column --B=f16:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \ + --batch_count=1 --op_class=tensorop --accum=f32 --cta_m=256 --cta_n=128 --cta_k=32 --stages=3 --warps_m=4 \ + --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=16 --min_cc=80 --max_cc=1024 - Runtime: 1.51255 ms - Memory: 32.2821 GiB/s + Bytes: 180355072 bytes + FLOPs: 231956545536 flops - Math: 96568.7 GFLOP/s + Runtime: 0.98647 ms + Memory: 170.272 GiB/s + Math: 235138 GFLOP/s ``` ## Covering the problem space @@ -271,7 +281,7 @@ with the `--output=` command line option as shown: ```bash $ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sgemm_128x128_nn \ - --m=4352 --n=4096 --k=8:4096:8 --output=report.csv + --m=3456 --n=4096 --k=8:4096:8 --output=report.csv ``` To faclitate generation of pivot tables and charts, additional columns may be prepended with the @@ -279,13 +289,13 @@ To faclitate generation of pivot tables and charts, additional columns may be pr ```bash $ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sgemm_128x128_nn \ - --m=4352 --n=4096 --k=8:4096:8 --output=report.csv \ - --tags=cutlass:2.0,date:2019-11-19 + --m=3456 --n=4096 --k=8:4096:8 --output=report.csv \ + --tags=cutlass:2.2,date:2020-06-08 ``` # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/programming_guidelines.md b/media/docs/programming_guidelines.md index 5ce16af1de..0cf7ea257f 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/programming_guidelines.md @@ -104,6 +104,14 @@ for (int idx = 0; idx < kN; ++idx) { // Loop has constant number of iterati ## Style +### C++ Style + +CUTLASS source code follows the +[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) with exceptions and extensions. + +Design choices should be consistent with the +[CppCoreGuidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md) recommendations by Stroustrup and Sutter. + ### CUDA Built-in Variables Avoid direct access to CUDA built-in variables `threadIdx`, `blockIdx`, `blockDim`, and `gridDim` within @@ -132,14 +140,6 @@ In particular, be sure to use: Avoid defining alternative implementations of the same functionality. Instead, prefer to enhance or extend additional components where it makes sense. -### C++ Style - -CUTLASS source code follows the -[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) with exceptions and extensions. - -Design choices should be consistent with the -[CppCoreGuidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md) recommendations by Stroustrup and Sutter. - ### Classes and Structs Type names use `CapitalLetters` except when implementations are a _perfect_ drop-in replacement for @@ -178,9 +178,10 @@ Members within classes and structures should be organized as follows: 3. Constructors 4. Other methods -This convention follows the [CUB library](https://nvlabs.github.io/cub/), -and it also approximates the usual order of Systems and Controls textbooks. That is, they start by -(1.) identifying relevant constants, (2.) define a state-space representation of the dynamical system +This convention follows the [CUB library](https://nvlabs.github.io/cub/) and is also described by +[Howard Hinnant](https://howardhinnant.github.io/classdecl.html). Unsurprisingly, it approximates +the usual ordering of chapters in a typical Systems and Controls textbook. That is, +(1.) identify relevant constants, (2.) define a state-space representation of the dynamical system under study (i.e. the data members), and (3.) devote subsequent chapters to definining dynamical behavior of the system (i.e. the methods). @@ -291,7 +292,7 @@ Github's pretty printer. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 6db005dd36..082b4c10b4 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -7,7 +7,7 @@ ## Prerequisites CUTLASS requires: -- NVIDIA CUDA Toolkit (9.2 or later required, 10.2 recommended) +- NVIDIA CUDA Toolkit (9.2 or later required, [11.0](https://developer.nvidia.com/cuda-toolkit) recommended) - CMake 3.12+ - host compiler supporting C++11 or greater (g++ 7.3.0 or Microsoft Visual Studio 2015 recommended) - Python 3.6+ @@ -20,23 +20,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc $ mkdir build && cd build -$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture -``` - -## Clang - -For experimental purposes, CUTLASS may be compiled with -[clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the -[CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-10.0-download-archive). -At this time, compiling with clang enables the CUTLASS SIMT GEMM kernels (sgemm, dgemm, hgemm, igemm) -but does not enable TensorCores. - -```bash -$ mkdir build && cd build - -$ cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ .. - -$ make test_unit -j +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture ``` ## Build and run the CUTLASS Profiler @@ -120,6 +104,53 @@ $ make test_unit_gemm_warp -j [100%] Built target test_unit_gemm_warp ``` +## Building for Multiple Architectures + +To minimize compilation time, specific GPU architectures can be enabled via the CMake command, +selected by [CUDA Compute Capability.](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities) + +**NVIDIA Ampere Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture +``` + +**NVIDIA Turing Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA Turing GPU architecture +``` + +**NVIDIA Volta Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=70 # compiles for NVIDIA Volta GPU architecture +``` + +**NVIDIA Pascal Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="60;61" # compiles for NVIDIA Pascal GPU architecture +``` + +**NVIDIA Maxwell Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="50;53" # compiles for NVIDIA Maxwell GPU architecture +``` + +## Clang + +For experimental purposes, CUTLASS may be compiled with +[clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the +[CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-10.0-download-archive). +At this time, compiling with clang enables the CUTLASS SIMT GEMM kernels (sgemm, dgemm, hgemm, igemm) +but does not enable TensorCores. + +```bash +$ mkdir build && cd build + +$ cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ .. + +$ make test_unit -j +``` + + ## Using CUTLASS within other applications Applications should list [`/include`](/include) within their include paths. They must be @@ -130,6 +161,7 @@ compiled as C++11 or greater. #include #include #include +#include int main() { @@ -141,12 +173,15 @@ int main() { } ``` -## Launching a GEMM kernel +## Launching a GEMM kernel in CUDA -**Example:** launch a mixed-precision GEMM targeting Volta Tensor Cores. +**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores. + +_Note, this example uses CUTLASS Utilities. Be sure `tools/util/include` is listed as an include path._ ```c++ #include -#include +#include + #include int main() { @@ -161,7 +196,7 @@ int main() { cutlass::layout::ColumnMajor, // LayoutOutput float, // ElementAccumulator cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores - cutlass::arch::Sm70 // tag indicating target GPU compute architecture + cutlass::arch::Sm75 // tag indicating target GPU compute architecture >; Gemm gemm_op; @@ -193,7 +228,7 @@ int main() { int lda = A.device_ref().stride(0); int ldb = B.device_ref().stride(0); int ldc = C.device_ref().stride(0); - int ldd = D.device_ref().stride(0); + int ldd = C.device_ref().stride(0); // // Launch GEMM on the device // @@ -235,9 +270,180 @@ Note, the above could be simplified as follows using helper methods defined in ` }); ``` +# CUTLASS Library + +The [CUTLASS Library](./tools/library) defines an API for managing and executing collections of compiled +kernel instances and launching them from host code without template instantiations in client code. + +The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its +kernel selection procedure is intended only to be functionally sufficient. It may not launch the +optimal tile size for a given problem. It chooses the first available kernel whose data types, +layouts, and alignment constraints satisfy the given problem. Kernel instances and a data structure +describing them are completely available to client applications which may choose to implement their +own selection logic. + +[cuBLAS](https://developer.nvidia.com/cublas) offers the best performance and functional coverage +for dense matrix computations on NVIDIA GPUs. + +The CUTLASS Library is used by the CUTLASS Profiler to manage kernel instances, and it is also used +by several SDK examples. + +* [10_planar_complex](/examples/10_planar_complex/planar_complex.cu) +* [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu) + +The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor +layouts, math operation classes, complex transformations, and more. + +Client applications should specify [`tools/library/include`](/tools/library/include) in their +include paths and link against libcutlas_lib.so. + +The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies +its dependency on the CUTLASS Library with the following CMake command. +``` +target_link_libraries( + 10_planar_complex + PRIVATE + cutlass_lib + cutlass_tools_util_includes +) +``` + +A sample kernel launch from host-side C++ is shown as follows. + +```c++ +#include "cutlass/library/library.h" +#include "cutlass/library/handle.h" + +int main() { + + // + // Define the problem size + // + int M = 512; + int N = 256; + int K = 128; + + float alpha = 1.25f; + float beta = -1.25f; + + // + // Allocate device memory + // + + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B({K, N}); + cutlass::HostTensor C({M, N}); + + float const *ptrA = A.device_data(); + float const *ptrB = B.device_data(); + float const *ptrC = C.device_data(); + float *ptrD = C.device_data(); + + int lda = A.device_ref().stride(0); + int ldb = B.device_ref().stride(0); + int ldc = C.device_ref().stride(0); + int ldd = D.device_ref().stride(0); + + // + // CUTLASS Library call to execute device GEMM + // + + cutlass::library::Handle handle; + + // + // Launch GEMM on CUDA device. + // + + cutlass::Status status = handle.gemm( + M, + N, + K, + + cutlass::library::NumericTypeID::kF32, // data type of internal accumulation + cutlass::library::NumericTypeID::kF32, // data type of alpha/beta scalars + + &alpha, // pointer to alpha scalar + + cutlass::library::NumericTypeID::kF32, // data type of A matrix + cutlass::library::LayoutTypeID::kColumnMajor, // layout of A matrix + ptrA, // pointer to A matrix in device memory + lda, // leading dimension of A matrix + + cutlass::library::NumericTypeID::kF32, // data type of B matrix + cutlass::library::LayoutTypeID::kColumnMajor, // layout of B matrix + ptrB, // pointer to B matrix in device memory + ldb, // leading dimension of B matrix + + &beta, // pointer to beta scalar + + cutlass::library::NumericTypeID::kF32, // data type of C and D matrix + + ptrC, // pointer to C matrix in device memory + ldc, // leading dimension fo C matrix + + ptrD, // pointer to D matrix in device memory + ldd // leading dimension of D matrix + ); + + if (status != cutlass::Status::kSuccess) { + return -1; + } + + return 0; +} +``` + +Kernels can be selectively included in the CUTLASS Library by specifying filter strings when +executing CMake. For example, only single-precision GEMM kernels can be instantiated as follows. + +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=sgemm +``` + +Compling only the kernels desired reduces compilation time. + +To instantiate kernels of all tile sizes, data types, and alignment constraints, specify +`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`. + +Several recipes are defined below for convenience. They may be combined as a comma-delimited list. + +**Example.** All GEMM kernels targeting NVIDIA Ampere Tensor Cores. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm +``` + +**Example.** All kernels for NVIDIA Volta, Turing, and Ampere architectures. Enabling +the "unity build" instantiates multiple kernel instances in each compilation unit, thereby +reducing binary size and avoiding linker limitations on some platforms. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all \ + -DCUTLASS_UNITY_BUILD_ENABLED=ON +``` + +**Example.** All GEMM kernels targeting Turing Tensor Cores. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm +``` + +**Example.** All GEMM kernels with single-precision accumulation. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=s*gemm +``` + +**Example.** All kernels which expect A and B to be column-major. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=gemm*nn +``` + +**Example.** All planar complex GEMM variants. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=planar_complex +``` + + # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/terminology.md b/media/docs/terminology.md index 1ef0b3839a..07464143cb 100644 --- a/media/docs/terminology.md +++ b/media/docs/terminology.md @@ -74,7 +74,7 @@ contiguous and strided dimensions of a tile. # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/tile_iterator_concept.md b/media/docs/tile_iterator_concept.md index 4fd068f894..061ff90734 100644 --- a/media/docs/tile_iterator_concept.md +++ b/media/docs/tile_iterator_concept.md @@ -466,7 +466,7 @@ struct WriteableReadableRandomAccessContiguousTileIteratorConcept { # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/docs/utilities.md b/media/docs/utilities.md index e3d2a52c03..b9ddc79a70 100644 --- a/media/docs/utilities.md +++ b/media/docs/utilities.md @@ -111,8 +111,8 @@ std::cout << tensor.host_view() << std::endl; ## Device Allocations -To strictly allocate memory on the device using the smart pointers to manage allocation and deallocation, -use `cutlass::device_memory::allocation<>`. +To strictly allocate memory on the device using the smart pointer pattern to manage allocation and deallocation, +use `cutlass::DeviceAllocation<>`. **Example:** allocating an array in device memory. ```c++ @@ -128,7 +128,7 @@ int main() { size_t N = 1024; - cutlass::device_memory::allocation device_alloc(N); + cutlass::DeviceAllocation device_alloc(N); // Call a CUDA kernel passing device memory as a pointer argument kernel<<< grid, block >>>(alloc.get()); @@ -340,8 +340,9 @@ used throughout the unit tests. ```c++ #include #include -#include + #include +#include int main() { @@ -378,7 +379,7 @@ int main() { # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/media/images/cutlass-performance-plot.png b/media/images/cutlass-performance-plot.png index 1d76a7e64b..9caf022349 100644 Binary files a/media/images/cutlass-performance-plot.png and b/media/images/cutlass-performance-plot.png differ diff --git a/media/images/gemm-hierarchy-with-epilogue-no-labels.png b/media/images/gemm-hierarchy-with-epilogue-no-labels.png index 59bc99fb90..b87e8e2ecb 100644 Binary files a/media/images/gemm-hierarchy-with-epilogue-no-labels.png and b/media/images/gemm-hierarchy-with-epilogue-no-labels.png differ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1ada4b079f..35994ba6d8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 5243906e9b..610eee0112 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -33,7 +33,7 @@ target_link_libraries( PUBLIC CUTLASS cutlass_tools_util_includes - $<$:cublas> + $<$:nvidia::cublas> gtest ) @@ -48,6 +48,9 @@ target_link_libraries( PUBLIC cutlass_test_unit_infra ) + +set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables") +set(CUTLASS_TEST_EXECUTION_ENVIRONMENT "" CACHE BOOL "Environment in which to invoke unit test executables") function(cutlass_test_unit_add_executable) @@ -65,7 +68,7 @@ function(cutlass_test_unit_add_executable) PRIVATE cutlass_test_unit_infra cutlass_test_unit_infra_lib - ) + ) string(REGEX REPLACE cutlass_ "" NAME_STEM ${NAME}) @@ -74,12 +77,19 @@ function(cutlass_test_unit_add_executable) add_custom_target( ${NAME_STEM} COMMAND - $ + ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ DEPENDS ${NAME} ) - # message(STATUS "cutlass_test_unit_add_executable(${NAME} c${NAME_STEM} ${NAME_STEM})") + if (CUTLASS_INSTALL_TESTS) + + install( + TARGETS ${NAME} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + + endif() endfunction() diff --git a/test/unit/common/cutlass_unit_test.h b/test/unit/common/cutlass_unit_test.h index ddbd186b69..81908265fa 100644 --- a/test/unit/common/cutlass_unit_test.h +++ b/test/unit/common/cutlass_unit_test.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 7a6aced023..0c548bdf86 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -64,37 +64,23 @@ void FilterArchitecture() { /// Maximum compute capability for which the kernels are enabled int max_compute_capability; - - /// If true, architecture is assumed to be silicon - bool silicon; - } test_filters[] = { - { "SM50*", 50, kMaxDevice, true}, - { "SM60*", 60, kMaxDevice, true}, - { "SM61*", 61, kMaxDevice, true}, - { "SM70*", 70, 75, true}, - { "SM75*", 75, kMaxDevice, true}, + { "SM50*", 50, kMaxDevice}, + { "SM60*", 60, kMaxDevice}, + { "SM61*", 61, kMaxDevice}, + { "SM70*", 70, 75}, + { "SM75*", 75, kMaxDevice}, + { "SM80*", 80, kMaxDevice}, { 0, 0, false } }; - bool running_on_silicon = false; - for (int i = 0; test_filters[i].filter; ++i) { - if (deviceMajorMinor == test_filters[i].min_compute_capability) { - running_on_silicon = test_filters[i].silicon; - break; - } - } - // Set negative test filters std::stringstream ss; ss << "-"; for (int i = 0, j = 0; test_filters[i].filter; ++i) { - if (!running_on_silicon && deviceMajorMinor != test_filters[i].min_compute_capability) { - ss << (j++ ? ":" : "") << test_filters[i].filter; - } - else if (deviceMajorMinor < test_filters[i].min_compute_capability || + if (deviceMajorMinor < test_filters[i].min_compute_capability || deviceMajorMinor > test_filters[i].max_compute_capability) { ss << (j++ ? ":" : "") << test_filters[i].filter; diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index a7d0e21165..d72f42fb03 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -24,6 +24,8 @@ cutlass_test_unit_add_executable( cutlass_test_unit_core array.cu half.cu + bfloat16.cu + tfloat32.cu complex.cu predicate_vector.cu tensor_ref.cu diff --git a/test/unit/core/array.cu b/test/unit/core/array.cu index 72f5b5a833..5a8cc855b0 100644 --- a/test/unit/core/array.cu +++ b/test/unit/core/array.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -228,6 +228,14 @@ TEST(Array, Float16x8) { } #endif +TEST(Array, FloatBF16x8) { + TestArray().run(); +} + +TEST(Array, FloatTF32x4) { + TestArray().run(); +} + TEST(Array, Float32x4) { TestArray().run(); } diff --git a/test/unit/core/bfloat16.cu b/test/unit/core/bfloat16.cu new file mode 100644 index 0000000000..9fa99ebb7f --- /dev/null +++ b/test/unit/core/bfloat16.cu @@ -0,0 +1,209 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. +*/ + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void convert_bf16_f32(cutlass::bfloat16_t *output, float const *input, int N) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < N) { + output[tid] = static_cast(input[tid]); + } +} + +__global__ void convert_and_pack_bf16(cutlass::bfloat16_t *output, float const *input, int N) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid * 2 < N) { + + cutlass::NumericArrayConverter convert; + + cutlass::Array *dst_ptr = + reinterpret_cast *>(output + tid * 2); + + cutlass::Array const *src_ptr = + reinterpret_cast const *>(input + tid * 2); + + *dst_ptr = convert(*src_ptr); + } +} + +TEST(bfloat16_t, device_conversion) { + using T = cutlass::bfloat16_t; + using S = float; + + int const N = 256; + + cutlass::HostTensor destination({N, 1}); + cutlass::HostTensor source({N, 1}); + + for (int i = 0; i < N; ++i) { + source.at({i, 0}) = float(i - 128); + destination.at({i, 0}) = T(0); + } + + source.sync_device(); + destination.sync_device(); + + convert_bf16_f32<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); + + ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; + + destination.sync_host(); + + int errors = 0; + for (int i = 0; i < N; ++i) { + T got = destination.at({i, 0}); + S expected = source.at({i, 0}); + + if (S(got) != expected) { + ++errors; + if (errors < 10) { + std::cerr << "Basic conversion error - [" << i << "] - got " << got << ", expected " << expected << "\n"; + } + } + + destination.at({i, 0}) = T(0); + } + + destination.sync_device(); + + convert_and_pack_bf16<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); + + ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; + + destination.sync_host(); + + for (int i = 0; i < N; ++i) { + T got = destination.at({i, 0}); + S expected = source.at({i, 0}); + + if (S(got) != expected) { + ++errors; + if (errors < 10) { + std::cerr << "Convert and pack error - [" << i << "] - got " << got << ", expected " << expected << "\n"; + } + } + } + + EXPECT_EQ(errors, 0); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Host +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(bfloat16_t, host_conversion) { + for (int i = -128; i < 128; ++i) { + float f = static_cast(i); + + cutlass::bfloat16_t x = static_cast(i); + cutlass::bfloat16_t y = static_cast(f); + + EXPECT_TRUE(static_cast(x) == i); + EXPECT_TRUE(static_cast(y) == f); + } + + // Try out user-defined literals + EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16); + EXPECT_TRUE(7 == static_cast(7_bf16)); +} + +TEST(bfloat16_t, host_arithmetic) { + + for (int i = -100; i < 100; ++i) { + for (int j = -100; j < 100; ++j) { + + cutlass::bfloat16_t x = static_cast(i); + cutlass::bfloat16_t y = static_cast(j); + + EXPECT_TRUE(static_cast(x + y) == (i + j)); + } + } +} + +TEST(bfloat16_t, host_round) { + + struct { + uint32_t f32_bits; + uint16_t expected; + } tests[] = { + {0x40040000, 0x4004}, // M=0, R=0, S=0 => rtz + {0x40048000, 0x4004}, // M=0, R=1, S=0 => rtz + {0x40040001, 0x4004}, // M=0, R=1, S=1 => +inf + {0x4004c000, 0x4005}, // M=0, R=1, S=1 => +inf + {0x4004a000, 0x4005}, // M=0, R=1, S=1 => +inf + {0x40050000, 0x4005}, // M=1, R=0, S=0 => rtz + {0x40054000, 0x4005}, // M=1, R=0, S=1 => rtz + {0x40058000, 0x4006}, // M=1, R=1, S=0 => +inf + {0x40058001, 0x4006}, // M=1, R=1, S=1 => +inf + {0x7f800000, 0x7f80}, // +inf + {0xff800000, 0xff80}, // -inf + {0x7fffffff, 0x7fff}, // canonical NaN + {0x7ff00001, 0x7fff}, // NaN -> canonical NaN + {0xfff00010, 0x7fff}, // Nan -> canonical NaN + {0, 0} + }; + + bool running = true; + for (int i = 0; running; ++i) { + + float f32 = reinterpret_cast(tests[i].f32_bits); + + cutlass::bfloat16_t bf16 = cutlass::bfloat16_t(f32); + + bool passed = (tests[i].expected == bf16.raw()); + + EXPECT_TRUE(passed) + << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits + << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << bf16.raw(); + + if (!tests[i].f32_bits) { + running = false; + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Device +// +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/complex.cu b/test/unit/core/complex.cu index 946e2f262b..9f70708d37 100644 --- a/test/unit/core/complex.cu +++ b/test/unit/core/complex.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/functional.cu b/test/unit/core/functional.cu index 2bdbb5e093..ab843154ef 100644 --- a/test/unit/core/functional.cu +++ b/test/unit/core/functional.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -347,13 +347,13 @@ TEST(Functional, divides_f16x17) { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -void Functional_multiply_add_f16xN() { +template +void Functional_multiply_add_TxN() { - using Element = cutlass::Array; + using Element = cutlass::Array; using Operator = cutlass::multiply_add; - using Tensor = cutlass::HostTensor; + using Tensor = cutlass::HostTensor; Tensor D({1, kN}); Tensor A({1, kN}); @@ -361,10 +361,10 @@ void Functional_multiply_add_f16xN() { Tensor C({1, kN}); for (int i = 0; i < kN; ++i) { - A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); - B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); - C.host_data()[i] = cutlass::half_t((i * 3 + 11) % 11); - D.host_data()[i] = cutlass::half_t(0); + A.host_data()[i] = T((i * 2 + 1) % 5); + B.host_data()[i] = T((i * 4 + 8) % 7); + C.host_data()[i] = T((i * 3 + 11) % 11); + D.host_data()[i] = T(0); } D.sync_device(); @@ -399,12 +399,25 @@ void Functional_multiply_add_f16xN() { EXPECT_TRUE(some_d_nonzero); } +///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(Functional, multiply_add_f16x16) { - Functional_multiply_add_f16xN<16>(); + Functional_multiply_add_TxN(); } TEST(Functional, multiply_add_f16x17) { - Functional_multiply_add_f16xN<17>(); + Functional_multiply_add_TxN(); } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Functional, multiply_add_bf16x16) { + Functional_multiply_add_TxN(); +} + +TEST(Functional, multiply_add_bf16x17) { + Functional_multiply_add_TxN(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/core/half.cu b/test/unit/core/half.cu index a0dcd96698..be5e9b433d 100644 --- a/test/unit/core/half.cu +++ b/test/unit/core/half.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/matrix_coord.cu b/test/unit/core/matrix_coord.cu index 676bd2c03f..841d4cb72a 100644 --- a/test/unit/core/matrix_coord.cu +++ b/test/unit/core/matrix_coord.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index ea062b737d..5f8f383987 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/predicate_vector.cu b/test/unit/core/predicate_vector.cu index 17de2cd2d4..f9a0675c01 100644 --- a/test/unit/core/predicate_vector.cu +++ b/test/unit/core/predicate_vector.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/tensor_ref.cu b/test/unit/core/tensor_ref.cu index aa8a5633e3..6bedddc577 100644 --- a/test/unit/core/tensor_ref.cu +++ b/test/unit/core/tensor_ref.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/tensor_view.cu b/test/unit/core/tensor_view.cu index b660b3d67b..b35fc426b8 100644 --- a/test/unit/core/tensor_view.cu +++ b/test/unit/core/tensor_view.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/test_unit_core.cpp b/test/unit/core/test_unit_core.cpp index 3823bd76e1..a6dfbf4bbc 100644 --- a/test/unit/core/test_unit_core.cpp +++ b/test/unit/core/test_unit_core.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/core/tfloat32.cu b/test/unit/core/tfloat32.cu new file mode 100644 index 0000000000..32155df7c4 --- /dev/null +++ b/test/unit/core/tfloat32.cu @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. +*/ + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Host +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(tfloat32_t, host_conversion) { + for (int i = -1024; i < 1024; ++i) { + float f = static_cast(i); + + cutlass::tfloat32_t x = static_cast(i); + cutlass::tfloat32_t y = static_cast(f); + + EXPECT_TRUE(static_cast(x) == i); + EXPECT_TRUE(static_cast(y) == f); + } + + // Try out user-defined literals + EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32); + EXPECT_TRUE(7 == static_cast(7_tf32)); +} + +TEST(tfloat32_t, host_arithmetic) { + + for (int i = -100; i < 100; ++i) { + for (int j = -100; j < 100; ++j) { + + cutlass::tfloat32_t x = static_cast(i); + cutlass::tfloat32_t y = static_cast(j); + + EXPECT_TRUE(static_cast(x + y) == (i + j)); + } + } +} + +TEST(tfloat32_t, host_round_nearest) { + + struct { + uint32_t f32_bits; + uint32_t expected; + } tests[] = { + {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz + {0x40001000, 0x40000000}, // M=0, R=1, S=0 => rtz + {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz + {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf + {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz + {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz + {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf + {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf + {0x7f800000, 0x7f800000}, // +inf + {0xff800000, 0xff800000}, // -inf + {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN + {0x7f800001, 0x7fffffff}, // NaN to canonical NaN + {0xff800001, 0x7fffffff}, // NaN to canonical NaN + {0, 0} + }; + + bool running = true; + for (int i = 0; running; ++i) { + + float f32 = reinterpret_cast(tests[i].f32_bits); + + cutlass::NumericConverter< + cutlass::tfloat32_t, + float, + cutlass::FloatRoundStyle::round_to_nearest> converter; + + cutlass::tfloat32_t tf32 = converter(f32); + + // note, we must explicitly truncate the low-order bits since they are not defined in TF32. + if (cutlass::isfinite(tf32)) { + tf32.storage &= 0xffffe000; + } + + bool passed = (tests[i].expected == tf32.raw()); + + EXPECT_TRUE(passed) + << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits + << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); + + if (!tests[i].f32_bits) { + running = false; + } + } +} + +namespace test { +namespace core { + +__global__ void convert_tf32_half_ulp(cutlass::tfloat32_t *out, float const *in) { + + cutlass::NumericConverter< + cutlass::tfloat32_t, + float, + cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; + + *out = convert(*in); +} + +} +} + + +TEST(tfloat32_t, host_round_half_ulp) { + + struct { + uint32_t f32_bits; + uint32_t expected; + } tests[] = { + {0x40001fff, 0x40002000}, + {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz + {0x40001000, 0x40002000}, // M=0, R=1, S=0 => rtz - this difers from RNE + {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz + {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf + {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz + {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz + {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf + {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf + {0x7f800000, 0x7f800000}, // +inf + {0xff800000, 0xff800000}, // -inf + {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN + {0x7f800001, 0x7f800001}, // NaN to NaN + {0xff800001, 0xff800001}, // NaN to NaN + {0, 0} + }; + + cutlass::NumericConverter< + cutlass::tfloat32_t, + float, + cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; + + bool running = true; + for (int i = 0; running; ++i) { + + float f32 = reinterpret_cast(tests[i].f32_bits); + + cutlass::tfloat32_t tf32 = convert(f32); + + // note, for this test, we must explicitly truncate the low-order bits since they are not + // defined in TF32. + if (cutlass::isfinite(tf32)) { + tf32.storage &= 0xffffe000; + } + + bool passed = (tests[i].expected == tf32.raw()); + + EXPECT_TRUE(passed) + << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits + << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); + + if (!tests[i].f32_bits) { + running = false; + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Device +// +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/CMakeLists.txt b/test/unit/epilogue/CMakeLists.txt old mode 100644 new mode 100755 index 1948a8ab73..9de2d56edb --- a/test/unit/epilogue/CMakeLists.txt +++ b/test/unit/epilogue/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -38,4 +38,4 @@ add_custom_target( test_unit_epilogue_thread test_unit_epilogue_warp test_unit_epilogue_threadblock - ) + ) diff --git a/test/unit/epilogue/thread/CMakeLists.txt b/test/unit/epilogue/thread/CMakeLists.txt index b719784cb4..9b04f7752a 100644 --- a/test/unit/epilogue/thread/CMakeLists.txt +++ b/test/unit/epilogue/thread/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -23,4 +23,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_epilogue_thread linear_combination.cu - ) + linear_combination_planar_complex.cu +) diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index cf0d1ea56c..6518e98738 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/test/unit/epilogue/thread/linear_combination_planar_complex.cu new file mode 100644 index 0000000000..89d1be5e02 --- /dev/null +++ b/test/unit/epilogue/thread/linear_combination_planar_complex.cu @@ -0,0 +1,280 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF32F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float>; + +__global__ void epilogue_thread_functor_planar_complex_f32_f32( + float *output_ptr, + float const *accum_ptr, + float const *source_ptr, + typename FunctorPlanarComplexF32F32::Params params) { + + FunctorPlanarComplexF32F32 linear_combination_op(params); + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + *reinterpret_cast*>(output_ptr) = linear_combination_op(accum, source); +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f32) { + + using Element = float; + using ElementOutput = float; + int const kCount = 4; + + using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF16F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 4, + float, + float>; + +__global__ void epilogue_thread_functor_planar_complex_f16_f32( + cutlass::half_t *output_ptr, + float const *accum_ptr, + cutlass::half_t const *source_ptr, + typename FunctorPlanarComplexF16F32::Params params, + int N) { + + FunctorPlanarComplexF16F32 linear_combination_op(params); + + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + #pragma unroll 1 + for (int n = 0; n < N; ++n) { + source = linear_combination_op(accum, source); + } + + *reinterpret_cast*>(output_ptr) = source; +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { + + using Element = float; + using ElementOutput = cutlass::half_t; + int const kCount = 4; + + using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF16F16 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 4, + cutlass::half_t, + cutlass::half_t>; + +__global__ void epilogue_thread_functor_planar_complex_f16_f16( + cutlass::half_t *output_ptr, + cutlass::half_t const *accum_ptr, + cutlass::half_t const *source_ptr, + typename FunctorPlanarComplexF16F16::Params params, + int N) { + + FunctorPlanarComplexF16F16 linear_combination_op(params); + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + #pragma unroll 1 + for (int n = 0; n < N; ++n) { + source = linear_combination_op(accum, source); + } + + *reinterpret_cast*>(output_ptr) = source; +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f16_f16) { + + using Element = cutlass::half_t; + using ElementOutput = cutlass::half_t; + int const kCount = 8; + + using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/CMakeLists.txt b/test/unit/epilogue/threadblock/CMakeLists.txt old mode 100644 new mode 100755 index 4785d734f0..cb8b7a62d5 --- a/test/unit/epilogue/threadblock/CMakeLists.txt +++ b/test/unit/epilogue/threadblock/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -30,4 +30,5 @@ cutlass_test_unit_add_executable( epilogue_tensor_op.cu epilogue_volta_tensor_op.cu epilogue_wmma_tensor_op_sm70.cu - ) + epilogue_planar_complex.cu +) diff --git a/test/unit/epilogue/threadblock/epilogue_planar_complex.cu b/test/unit/epilogue/threadblock/epilogue_planar_complex.cu new file mode 100644 index 0000000000..76b70f5069 --- /dev/null +++ b/test/unit/epilogue/threadblock/epilogue_planar_complex.cu @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +// Tensor Op +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +// Volta Tensor Op +#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" +#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" + +// Simt +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" + +// Epilogue components + +#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "testbed_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f32_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f16_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_volta_tensor_op_64x64_32x32x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; + using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + cutlass::gemm::GemmShape<16, 16, 4>, + 32, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::layout::RowMajor, + cutlass::arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> + >; + + using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + ElementAccumulator, + cutlass::layout::RowMajor, + Policy + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f32_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using Element = float; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + + using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<4, 4, 1> + > + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaSimt, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f64_64x64_16x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + + using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<4, 4, 1> + > + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaSimt, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/epilogue_simt.cu b/test/unit/epilogue/threadblock/epilogue_simt.cu index 0d4f9ae5bc..935a812426 100644 --- a/test/unit/epilogue/threadblock/epilogue_simt.cu +++ b/test/unit/epilogue/threadblock/epilogue_simt.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu b/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu index 3dd0fdd6ca..25cd8933c5 100644 --- a/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu +++ b/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu b/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu index 0151f1d8e8..fcc8426ca3 100644 --- a/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu +++ b/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/epilogue_tensor_op.cu b/test/unit/epilogue/threadblock/epilogue_tensor_op.cu index 6662213d20..db8e68a3a5 100644 --- a/test/unit/epilogue/threadblock/epilogue_tensor_op.cu +++ b/test/unit/epilogue/threadblock/epilogue_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -34,6 +34,7 @@ #include "cutlass/half.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/gemm/warp/default_mma_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" @@ -45,6 +46,541 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_32x32x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x64_64x32x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x128_32x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_32x128_32x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x32_64x32x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 32, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_256x128_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x256_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 256, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { // @@ -54,11 +590,70 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { using ElementOutput = int8_t; using ElementAccumulator = int; using ElementCompute = float; - int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = int8_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; int const kPartitionsK = 1; using Shape = cutlass::gemm::GemmShape<64, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using Element = ElementOutput; using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< @@ -104,7 +699,7 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { EXPECT_TRUE(passed); } -TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) { // // Define the warp-level matrix multiply @@ -113,11 +708,11 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { using ElementOutput = int8_t; using ElementAccumulator = int; using ElementCompute = float; - int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; int const kPartitionsK = 1; - using Shape = cutlass::gemm::GemmShape<64, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using Shape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using Element = ElementOutput; using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< @@ -163,7 +758,7 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { EXPECT_TRUE(passed); } -TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) { +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_64x64x16) { // // Define the warp-level matrix multiply @@ -1980,6 +2575,249 @@ TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x64_64x32x8) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x64_32x32x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x64_64x32x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x128_32x64x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x128_32x64x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x128_64x64x8) { diff --git a/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu b/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu index 99b7ae1175..88fa98cf03 100644 --- a/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu +++ b/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu b/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu index 3d1fdf0dd3..24752a1df0 100644 --- a/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu +++ b/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/output_tile_threadmap.cu b/test/unit/epilogue/threadblock/output_tile_threadmap.cu index 549e6e4d40..6e6e96e71f 100644 --- a/test/unit/epilogue/threadblock/output_tile_threadmap.cu +++ b/test/unit/epilogue/threadblock/output_tile_threadmap.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/predicated_tile_iterator.cu b/test/unit/epilogue/threadblock/predicated_tile_iterator.cu index 7fcdd8e463..40874f7bf1 100644 --- a/test/unit/epilogue/threadblock/predicated_tile_iterator.cu +++ b/test/unit/epilogue/threadblock/predicated_tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/testbed.h b/test/unit/epilogue/threadblock/testbed.h index c888b9a2d2..1dc9baa317 100644 --- a/test/unit/epilogue/threadblock/testbed.h +++ b/test/unit/epilogue/threadblock/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/threadblock/testbed_planar_complex.h b/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 0000000000..6afa603293 --- /dev/null +++ b/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/warp/CMakeLists.txt b/test/unit/epilogue/warp/CMakeLists.txt index 89d693e3e8..dbd7ee65b5 100644 --- a/test/unit/epilogue/warp/CMakeLists.txt +++ b/test/unit/epilogue/warp/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu b/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu index 4881e5cc90..9e94616f72 100644 --- a/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu +++ b/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu b/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu index a89ec49c89..3522c9e925 100644 --- a/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu +++ b/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu b/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu index a3a406dc7c..4931d93718 100644 --- a/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu +++ b/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/CMakeLists.txt b/test/unit/gemm/CMakeLists.txt index 4d42c000fe..4ac245716f 100644 --- a/test/unit/gemm/CMakeLists.txt +++ b/test/unit/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 4750dd8b0d..f536b1136f 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -23,9 +23,71 @@ cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device + BATCH_SOURCES ON + BATCH_SIZE 4 + + gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu + gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu + gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu + + gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu + gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu + gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu + gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu + + gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu + gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu + + gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu + gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu + + gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu + gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu + + gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu + gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu + gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu + gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu + gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu + gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu + gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu + gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu + gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu + gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu + gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu + gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu + gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu + gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu + gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu + + gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu + gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu + + simt_sgemm_nt_sm80.cu + simt_sgemm_tn_sm80.cu + + gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu + gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu + gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu + gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu + gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu + gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu + gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu + + gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu + gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu + + gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu + gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu + + gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu + gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu + + gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu - gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu + gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu + gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu @@ -90,6 +152,7 @@ cutlass_test_unit_add_executable( simt_zgemm_tn_sm50.cu simt_zgemm_tt_sm50.cu + gemm_splitk_serial_tensor_op_sm75.cu gemm_splitk_tensor_op_sm75.cu gemm_splitk_tensor_op_sm70.cu gemm_splitk_simt_sm50.cu @@ -144,6 +207,5 @@ cutlass_test_unit_add_executable( gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu -) - +) diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu index fb7fe985df..fc887bce36 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..d8b9072736 --- /dev/null +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu @@ -0,0 +1,373 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 1024>, + cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 1024>, + cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 512>, + cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 512>, + cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu index 099c463986..03f0b75251 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc >; @@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu index f88a73d9a3..77777a66f3 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..f6862b0d2d --- /dev/null +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu @@ -0,0 +1,374 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x1024_64x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x1024_64x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x1024_64x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x1024_64x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x1024_64x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x1024_32x64x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 1024>, + cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x1024_64x32x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 1024>, + cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x1024_32x32x1024, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 512>, + cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 512>, + cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, + false, cutlass::arch::OpXorPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu index 1254a19b34..b4fb7eba02 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc >; @@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; @@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; diff --git a/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..3da9cdbb58 --- /dev/null +++ b/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, + cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..b0dbbdc856 --- /dev/null +++ b/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu @@ -0,0 +1,337 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu new file mode 100644 index 0000000000..b15af10764 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Operands data type: complex +// Rounding: float -> tfloat32_t (half_ulp_truncate) +// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) +// Math instruction: MMA.1688.F32.TF32 +// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) +// Output data type: complex +///////////////////////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { + + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { + + using Element = cutlass::complex;; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { + + using Element = cutlass::complex;; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu new file mode 100644 index 0000000000..cec5ce60a5 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Operands data type: complex +// Rounding: float -> tfloat32_t (round to nearest) +// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) +// Math instruction: MMA.1688.F32.TF32 +// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) +// Output data type: complex +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { + + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { + + using Element = cutlass::complex;; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { + + using Element = cutlass::complex;; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { + + using Element = cutlass::complex;; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu new file mode 100644 index 0000000000..c7df15d140 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu @@ -0,0 +1,192 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<16, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu new file mode 100644 index 0000000000..5113d2f800 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<16, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { + + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu new file mode 100644 index 0000000000..427c1e0e13 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 16, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu new file mode 100644 index 0000000000..74fbc1f549 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu index b40f294536..ea3da85d52 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu index 479004e519..167949d8c6 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu index 6e42c5de21..ae72cade2f 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..858fd301fe --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu index 1ea87c43f5..2dc224ab2e 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu index 67f9598742..71f21444cf 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu index 6e07cc8c33..bb1665062e 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -320,7 +320,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -354,7 +354,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -388,7 +388,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu index 6b6d66f55e..3e8b96584f 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..cd6e48a3a2 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu @@ -0,0 +1,337 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu index c42771b987..a9f9ea9978 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu index d94a7f0df2..d797ed5577 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..7cf1fad244 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu index abe553224e..be764f5282 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu index ab15f1c59f..25d3e5bee8 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu index 5dd4e2f878..f7c8fb23f2 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu index 81ee6d7147..2798007695 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu similarity index 96% rename from test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu rename to test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu index 30ddd06a96..b4114ffe51 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu new file mode 100644 index 0000000000..6ca8ada8a5 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu index 3f96597bcf..64b697af81 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu new file mode 100644 index 0000000000..cff5070599 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64> , + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..8a760b02ab --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + /* + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); + */ +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu index dbf02b24bc..9f2c2c542d 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu index 031e226836..aa92606167 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu index 235c139699..dac3675b84 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu index 41824839b1..74434cc9fa 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu index 38337c6426..176112d10f 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..47e927d450 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu index d2f58b1ca4..de19ca0047 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu index b5ff3b9936..0b83c6cbb7 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu index 3bfe6d8fe3..a81684241b 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x256x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x64x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x128x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x32_32 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x64_32 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x64_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu index 7455a1bddb..585b1df179 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu index a2374a618c..ab030e5a97 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x16_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -140,7 +140,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu index 5629dc98c8..b8fa4dad8e 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x256x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x64x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x128x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x32_32 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x64_32 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x64_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75_slicedk.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu similarity index 96% rename from test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75_slicedk.cu rename to test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu index d78d34e68b..358aacecd9 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75_slicedk.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu new file mode 100644 index 0000000000..957bcd2ab0 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu index 8463e9e314..7c0f3b406a 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu new file mode 100644 index 0000000000..972756bba8 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x32_64x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu index 68d551a1c2..14030b1d41 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu index 6a66888f21..9a1918db44 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu index a7c61a1a4d..51a09194e4 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -319,7 +319,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -353,7 +353,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -387,7 +387,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu index 34859eddc2..74d64af70d 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu index ca63f26df0..d4bc720bca 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -74,7 +74,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x64x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x128x32_6 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -138,7 +138,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x64x32_32 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; @@ -209,7 +209,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_ ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu index f941832da0..dd0976d9f7 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..83c5cd1479 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu index 90e44ee51f..6d78dc9a9b 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu index 05374010b7..5ea2f9ce00 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu index 3f922ebad7..0f773de4f2 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu index c4ab9f4df5..54d6229a0d 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -140,7 +140,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu index 748f64d19f..d123931e1a 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu index 037efb8238..b1286accd1 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -320,7 +320,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -354,7 +354,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -388,7 +388,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu index d7474d87af..5a511540fa 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..26f41ac2b7 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu index da55acbda2..06498afb9a 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu index 30bb55833e..e377980bbf 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..96f5dcc947 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu index 8418381c7f..0f94d589c6 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu index 2d9d41678d..2163711b84 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu new file mode 100644 index 0000000000..91095a945d --- /dev/null +++ b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface using BF16. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 4, + 4, + false, + cutlass::arch::OpMultiplyAddFastBF16 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..2108eeb4e4 --- /dev/null +++ b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu new file mode 100644 index 0000000000..64fe313c50 --- /dev/null +++ b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu @@ -0,0 +1,212 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu new file mode 100644 index 0000000000..63c765c551 --- /dev/null +++ b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu @@ -0,0 +1,212 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu new file mode 100644 index 0000000000..99303712e5 --- /dev/null +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-level GEMM API for Planar Complex. +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "testbed_planar_complex.h" + +#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s884_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s884_tn : gemm_planar_complex_s884_tn_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s884_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s884_nt : gemm_planar_complex_s884_nt_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu new file mode 100644 index 0000000000..993b0b9d5a --- /dev/null +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu @@ -0,0 +1,217 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-level GEMM API for Planar Complex. +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" +#include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "testbed_planar_complex.h" + + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s1688_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s1688_tn : gemm_planar_complex_s1688_tn_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s1688_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s1688_hc : gemm_planar_complex_s1688_hc_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s1688_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s1688_nt : gemm_planar_complex_s1688_nt_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s1688_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s1688_ch : gemm_planar_complex_s1688_ch_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu new file mode 100644 index 0000000000..25fd50cfc3 --- /dev/null +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu @@ -0,0 +1,216 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-level GEMM API for Planar Complex. +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "testbed_planar_complex.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s16816_tn : gemm_planar_complex_s16816_tn_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s16816_hc : gemm_planar_complex_s16816_hc_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s16816_nt : gemm_planar_complex_s16816_nt_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s16816_ch : gemm_planar_complex_s16816_ch_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu index 832981f9cd..4cc4068170 100644 --- a/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -107,7 +107,7 @@ TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -143,7 +143,7 @@ TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -179,7 +179,7 @@ TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..d53e3c0768 --- /dev/null +++ b/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu @@ -0,0 +1,213 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "multistage_testbed_interleaved.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::ColumnMajorInterleaved<64>, + cutlass::int4b_t, + cutlass::layout::RowMajorInterleaved<64>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<64>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 32, + 32, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::ColumnMajorInterleaved<64>, + cutlass::int4b_t, + cutlass::layout::RowMajorInterleaved<64>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<64>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 32, + 32, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::ColumnMajorInterleaved<64>, + cutlass::int4b_t, + cutlass::layout::RowMajorInterleaved<64>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<64>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 32, + 32, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::ColumnMajorInterleaved<64>, + cutlass::int4b_t, + cutlass::layout::RowMajorInterleaved<64>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<64>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 32, + 32, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu index a737cf5c68..983dff337f 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +97,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -129,13 +129,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -161,13 +161,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -193,13 +193,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,13 +225,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..8dd541838f --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu @@ -0,0 +1,354 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x256_32x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 256>, + cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x256_64x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif //#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu index ebef12f6f4..01a65b32a5 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -66,13 +66,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -98,13 +98,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -130,13 +130,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -162,13 +162,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,13 +194,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -226,13 +226,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x3 cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu index 165d404b76..33f3b07a2a 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128) { cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +97,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128) { cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -129,13 +129,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128) { cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -161,13 +161,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x128_32x64x128) { cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -193,13 +193,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128) { cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,13 +225,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128) { cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..1a3f7dba85 --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu @@ -0,0 +1,357 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 256>, + cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_32x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu index 70e69dea0f..857df472a7 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -66,13 +66,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -98,13 +98,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -130,13 +130,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -162,13 +162,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -194,13 +194,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -226,13 +226,13 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x3 cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu new file mode 100644 index 0000000000..51d182cd66 --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu new file mode 100644 index 0000000000..90fe6bcfd8 --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu index 7f4772b934..393e68bfd6 100644 --- a/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 32x64x64_16x32x64) { cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,13 +99,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -137,13 +133,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -173,13 +167,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -209,13 +201,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -245,13 +235,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -281,13 +269,11 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..c4900e489e --- /dev/null +++ b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "multistage_testbed_interleaved.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x256x64_64x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x64x64_64x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + test::gemm::device::MultistageInterleavedTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu index 4a9906b481..6ac9b71bf2 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +97,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -129,13 +129,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -161,13 +161,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -193,13 +193,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,13 +225,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu index 1328be0b57..cc6e4c3a5d 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16 cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -96,13 +96,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -130,13 +130,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu index ce640e82a3..86a678d22b 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +97,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -129,13 +129,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -161,13 +161,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -193,13 +193,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -225,13 +225,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..a86dc2442e --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu @@ -0,0 +1,355 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x64_64x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x64_64x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64, { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu index b0001dbf3e..d53571a2d7 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16 cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +97,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -131,13 +131,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,13 +165,13 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu index 6317fd7d49..024cba0a49 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -57,10 +57,9 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -77,11 +76,10 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; - + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) @@ -96,10 +94,9 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -116,15 +113,80 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); + +} ) + +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); + +} ) + ///////////////////////////////////////////////////////////////////////////////////////////////// #endif diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..2d6db336f6 --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu @@ -0,0 +1,368 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "multistage_testbed.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x128_32x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x128_64x32x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x128_32x32x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu index 0f9ee12b9f..ac5757e0ee 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -96,13 +94,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -130,13 +126,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -164,13 +158,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu index 2a6f6da35d..93642e64b6 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -57,10 +57,9 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) @@ -76,10 +75,9 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) @@ -95,10 +93,9 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -115,11 +112,52 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; - + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + + test::gemm::device::Testbed testbed; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 32 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + + test::gemm::device::Testbed testbed; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..197e69b710 --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu @@ -0,0 +1,368 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "multistage_testbed.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x128_64x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x128_32x64x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x128_64x32x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x128_32x32x128, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 64 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu index c756def25b..719e2ac760 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -65,13 +65,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -97,13 +95,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -131,13 +127,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,13 +159,11 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu b/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu new file mode 100644 index 0000000000..e7a01bed61 --- /dev/null +++ b/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmSplitKSerial_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + static const int kStages = 2; + + static const int kAlignmentA = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ElementA, + ElementB, + ElementOutput, + ElementAccumulator>::kAlignmentA; + + static const int kAlignmentB = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ElementA, + ElementB, + ElementOutput, + ElementAccumulator>::kAlignmentB; + + static const bool kSplitKSerial = true; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, + cutlass::layout::ColumnMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/test/unit/gemm/device/gemm_splitk_simt_sm50.cu b/test/unit/gemm/device/gemm_splitk_simt_sm50.cu index c35535dd6b..39b5f10a70 100644 --- a/test/unit/gemm/device/gemm_splitk_simt_sm50.cu +++ b/test/unit/gemm/device/gemm_splitk_simt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu b/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu index c350b063e3..42e991ed09 100644 --- a/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu +++ b/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -42,6 +42,7 @@ #include "testbed_splitk.h" +// These operators are assert(0) unless extended PTX is used. #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu b/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu index d78164e090..3381f1703a 100644 --- a/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu +++ b/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -42,6 +42,7 @@ #include "testbed_splitk.h" +// These operators are assert(0) unless extended PTX is used. #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..78c6e8657e --- /dev/null +++ b/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,549 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..11af88897f --- /dev/null +++ b/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,549 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..a28101f3d5 --- /dev/null +++ b/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..a1a0fd7e31 --- /dev/null +++ b/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,550 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + cutlass::tfloat32_t, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu index 4d31c08962..a63163680b 100644 --- a/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -102,7 +102,7 @@ TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -136,7 +136,7 @@ TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -170,7 +170,7 @@ TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x1 ElementAccumulator, ElementAccumulator >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..e32441941d --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf32n_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf32n_cf32h_cf32n_tensor_op_f32, 64x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf32h_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf32h_cf32c_cf32n_tensor_op_f32, 64x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu new file mode 100644 index 0000000000..301cce7851 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64_gaussian, 64x32x32_32x16x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64_gaussian, 64x64x32_32x16x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu new file mode 100644 index 0000000000..df28110a33 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64, 64x64x32_32x32x32) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmUniversal< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 3, + 1, + 1, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..e7b4405a08 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 2>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + + +TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 2, + 1, + 1>; + + EXPECT_TRUE(test::gemm::device::TestGemmUniversal( + {128, 128, 2}, + cutlass::gemm::GemmUniversalMode::kGemm, + 15)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h new file mode 100644 index 0000000000..bdc4b77081 --- /dev/null +++ b/test/unit/gemm/device/multistage_testbed.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageTestbed { + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = + typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope = (cutlass::sizeof_bits::value == 8) ? 2 : 8; + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, + -scope, 0); + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), + view.capacity()); + } else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor + tensor_A(problem_size.mk()); + + cutlass::HostTensor + tensor_B(problem_size.kn()); + + cutlass::HostTensor + tensor_C(problem_size.mn()); + + cutlass::HostTensor + tensor_D(problem_size.mn()); + + cutlass::HostTensor + reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, tensor_A.device_ref(), tensor_B.device_ref(), + tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, + reference_D.host_ref(), ElementAccumulator(0)); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + + fname << "error_Gemm_device_" << problem_size.m() << "x" + << problem_size.n() << "x" << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN + << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM + << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK + << ".txt"; + + std::ofstream file(fname.str()); + + file << "problem: " << problem_size << ", alpha: " << alpha + << ", beta: " << beta << "\n\n"; + + file << "A =\n" + << tensor_A.host_view() << "\nB =\n" + << tensor_B.host_view() << "\nC =\n" + << tensor_C.host_view() << "\n\nReference =\n" + << reference_D.host_view() << "\nComputed =\n" + << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = {16, 528}; + + int problem_size_n[] = {16, 528}; + + int problem_size_k[] = {Gemm::InstructionShape::kK, + Gemm::ThreadblockShape::kK * Gemm::kStages + + Gemm::InstructionShape::kK}; + + double problem_alpha[] = {1.0}; + + // TODO Try non zero beta value after multistaged epilogue is implemented + double problem_beta[] = {0.0}; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + passed = + run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/multistage_testbed_interleaved.h b/test/unit/gemm/device/multistage_testbed_interleaved.h new file mode 100644 index 0000000000..c98264de01 --- /dev/null +++ b/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageInterleavedTestbed { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageInterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 0.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_cgemm_nn_sm50.cu b/test/unit/gemm/device/simt_cgemm_nn_sm50.cu index d399b766a0..5aabfca587 100644 --- a/test/unit/gemm/device/simt_cgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_cgemm_nt_sm50.cu b/test/unit/gemm/device/simt_cgemm_nt_sm50.cu index 7c1922416e..c5265ce2b9 100644 --- a/test/unit/gemm/device/simt_cgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_cgemm_tn_sm50.cu b/test/unit/gemm/device/simt_cgemm_tn_sm50.cu index 89728ba201..9db96c996a 100644 --- a/test/unit/gemm/device/simt_cgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_cgemm_tt_sm50.cu b/test/unit/gemm/device/simt_cgemm_tt_sm50.cu index 8d4c9fddc8..0ac7b4c9f8 100644 --- a/test/unit/gemm/device/simt_cgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_dgemm_nn_sm50.cu b/test/unit/gemm/device/simt_dgemm_nn_sm50.cu index 3d5c52ed90..1efa9d0446 100644 --- a/test/unit/gemm/device/simt_dgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_dgemm_nt_sm50.cu b/test/unit/gemm/device/simt_dgemm_nt_sm50.cu index 05fa3c94a0..886c0f9c74 100644 --- a/test/unit/gemm/device/simt_dgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_dgemm_tn_sm50.cu b/test/unit/gemm/device/simt_dgemm_tn_sm50.cu index f0f2530075..a43d0afd5d 100644 --- a/test/unit/gemm/device/simt_dgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_dgemm_tt_sm50.cu b/test/unit/gemm/device/simt_dgemm_tt_sm50.cu index 38066b946e..0175978d00 100644 --- a/test/unit/gemm/device/simt_dgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_hgemm_nn_sm50.cu b/test/unit/gemm/device/simt_hgemm_nn_sm50.cu index 79af9b47d2..a3aa5ce840 100644 --- a/test/unit/gemm/device/simt_hgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_hgemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x128x8_32x128x1_8x16_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x64x8_64x64x1_16x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x32x8_128x32x1_16x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x256x8_32x128x1_8x16_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x128x8_64x64x1_16x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x64x8_64x64x1_16x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x32x8_128x32x1_16x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x256x8_32x128x1_8x16_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x128x8_64x64x1_16x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x64x8_128x32x1_16x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x256x8_64x64x1_16x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1777,7 +1777,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1807,7 +1807,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nn, 256x128x8_64x64x1_16x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1837,7 +1837,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1867,7 +1867,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1897,7 +1897,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1927,7 +1927,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1957,7 +1957,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1987,7 +1987,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2017,7 +2017,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2047,7 +2047,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2077,7 +2077,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2107,7 +2107,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x256x8_32x64x1_8x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2137,7 +2137,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2167,7 +2167,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x128x8_64x32x1_8x8_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_hgemm_nt_sm50.cu b/test/unit/gemm/device/simt_hgemm_nt_sm50.cu index 1401d2fa9e..d5541939e9 100644 --- a/test/unit/gemm/device/simt_hgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_hgemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x128x8_32x128x1_8x16_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x64x8_64x64x1_16x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x32x8_128x32x1_16x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x256x8_32x128x1_8x16_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x128x8_64x64x1_16x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x64x8_64x64x1_16x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x32x8_128x32x1_16x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x256x8_32x128x1_8x16_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x128x8_64x64x1_16x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x64x8_128x32x1_16x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x256x8_64x64x1_16x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1777,7 +1777,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1807,7 +1807,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_nt, 256x128x8_64x64x1_16x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1837,7 +1837,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1867,7 +1867,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1897,7 +1897,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1927,7 +1927,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1957,7 +1957,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1987,7 +1987,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2017,7 +2017,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2047,7 +2047,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2077,7 +2077,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2107,7 +2107,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x256x8_32x64x1_8x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2137,7 +2137,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2167,7 +2167,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x128x8_64x32x1_8x8_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_hgemm_tn_sm50.cu b/test/unit/gemm/device/simt_hgemm_tn_sm50.cu index f1b7a043f2..526bc01a4c 100644 --- a/test/unit/gemm/device/simt_hgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_hgemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x128x8_32x128x1_8x16_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x64x8_64x64x1_16x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x32x8_128x32x1_16x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x256x8_32x128x1_8x16_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x128x8_64x64x1_16x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x64x8_64x64x1_16x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x32x8_128x32x1_16x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x256x8_32x128x1_8x16_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x128x8_64x64x1_16x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x64x8_128x32x1_16x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x256x8_64x64x1_16x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1777,7 +1777,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1807,7 +1807,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tn, 256x128x8_64x64x1_16x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1837,7 +1837,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1867,7 +1867,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1897,7 +1897,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1927,7 +1927,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1957,7 +1957,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1987,7 +1987,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2017,7 +2017,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2047,7 +2047,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2077,7 +2077,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2107,7 +2107,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x256x8_32x64x1_8x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2137,7 +2137,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2167,7 +2167,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x128x8_64x32x1_8x8_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_hgemm_tt_sm50.cu b/test/unit/gemm/device/simt_hgemm_tt_sm50.cu index 4c1b591366..ad464b3018 100644 --- a/test/unit/gemm/device/simt_hgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_hgemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x128x8_32x128x1_8x16_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x64x8_64x64x1_16x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x32x8_128x32x1_16x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x256x8_32x128x1_8x16_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x128x8_64x64x1_16x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x64x8_64x64x1_16x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x32x8_128x32x1_16x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x256x8_32x128x1_8x16_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x128x8_64x64x1_16x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x64x8_128x32x1_16x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x256x8_64x64x1_16x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1777,7 +1777,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1807,7 +1807,7 @@ CUTLASS_TEST_L0(SM50_device_hgemm_tt, 256x128x8_64x64x1_16x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1837,7 +1837,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1867,7 +1867,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1897,7 +1897,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1927,7 +1927,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1957,7 +1957,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1987,7 +1987,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2017,7 +2017,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2047,7 +2047,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2077,7 +2077,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2107,7 +2107,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x256x8_32x64x1_8x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2137,7 +2137,7 @@ CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -2167,7 +2167,7 @@ CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x128x8_64x32x1_8x8_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_igemm_nn_sm50.cu b/test/unit/gemm/device/simt_igemm_nn_sm50.cu index 59a8dbfe15..3db133ebfd 100644 --- a/test/unit/gemm/device/simt_igemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_igemm_nt_sm50.cu b/test/unit/gemm/device/simt_igemm_nt_sm50.cu index 7ff0c5cd28..01f56ea030 100644 --- a/test/unit/gemm/device/simt_igemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_igemm_tn_sm50.cu b/test/unit/gemm/device/simt_igemm_tn_sm50.cu index 392db59e20..3692ec2c3b 100644 --- a/test/unit/gemm/device/simt_igemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_igemm_tt_sm50.cu b/test/unit/gemm/device/simt_igemm_tt_sm50.cu index 3fdc8e2719..2254669b36 100644 --- a/test/unit/gemm/device/simt_igemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L0(SM50_device_igemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_int8_igemm_sm61.cu b/test/unit/gemm/device/simt_int8_igemm_sm61.cu index d1a8821ace..1364a38cff 100644 --- a/test/unit/gemm/device/simt_int8_igemm_sm61.cu +++ b/test/unit/gemm/device/simt_int8_igemm_sm61.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -69,7 +69,7 @@ ElementAccumulator, \ ElementCompute \ >, \ - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, \ + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ 2 \ >; \ EXPECT_TRUE(test::gemm::device::TestAllGemm()); \ diff --git a/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu b/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu index 0c1449e0d4..4e4308ff37 100644 --- a/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu +++ b/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -109,7 +109,7 @@ TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -145,7 +145,7 @@ TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -181,7 +181,7 @@ TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu b/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu index 9e1c21e9ec..88c72aee4c 100644 --- a/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu +++ b/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -69,7 +69,7 @@ TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -101,7 +101,7 @@ TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -133,7 +133,7 @@ TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -165,7 +165,7 @@ TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -197,7 +197,7 @@ TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -229,7 +229,7 @@ TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -261,7 +261,7 @@ TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; @@ -293,7 +293,7 @@ TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { ElementAccumulator, ElementCompute >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 >; diff --git a/test/unit/gemm/device/simt_sgemm_nn_sm50.cu b/test/unit/gemm/device/simt_sgemm_nn_sm50.cu index a81dd4dbd5..0412d751c3 100644 --- a/test/unit/gemm/device/simt_sgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_sgemm_nt_sm50.cu b/test/unit/gemm/device/simt_sgemm_nt_sm50.cu index 81c21edaf0..1adb9b5ae4 100644 --- a/test/unit/gemm/device/simt_sgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_sgemm_nt_sm80.cu b/test/unit/gemm/device/simt_sgemm_nt_sm80.cu new file mode 100644 index 0000000000..7d2ab45b6f --- /dev/null +++ b/test/unit/gemm/device/simt_sgemm_nt_sm80.cu @@ -0,0 +1,249 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 32x64x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x64x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x128x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x64x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_64x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x256x8_64x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_sgemm_tn_sm50.cu b/test/unit/gemm/device/simt_sgemm_tn_sm50.cu index 20a2eddbeb..0c00e56084 100644 --- a/test/unit/gemm/device/simt_sgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_sgemm_tn_sm80.cu b/test/unit/gemm/device/simt_sgemm_tn_sm80.cu new file mode 100644 index 0000000000..00461d2e0f --- /dev/null +++ b/test/unit/gemm/device/simt_sgemm_tn_sm80.cu @@ -0,0 +1,249 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 32x64x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x64x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x128x8_32x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x64x8_64x32x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_64x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x256x8_64x64x1) { + + using Element = float; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_sgemm_tt_sm50.cu b/test/unit/gemm/device/simt_sgemm_tt_sm50.cu index 22e846b97b..ce7ab9a7e0 100644 --- a/test/unit/gemm/device/simt_sgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -817,7 +817,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -847,7 +847,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -877,7 +877,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -907,7 +907,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -937,7 +937,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -967,7 +967,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -997,7 +997,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1027,7 +1027,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1057,7 +1057,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1087,7 +1087,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1117,7 +1117,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1147,7 +1147,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1177,7 +1177,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1207,7 +1207,7 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1237,7 +1237,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1267,7 +1267,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1297,7 +1297,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1327,7 +1327,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1357,7 +1357,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1387,7 +1387,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1417,7 +1417,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1447,7 +1447,7 @@ CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1477,7 +1477,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1507,7 +1507,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1537,7 +1537,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1567,7 +1567,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1597,7 +1597,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1627,7 +1627,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1657,7 +1657,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1687,7 +1687,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1717,7 +1717,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -1747,7 +1747,7 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_sm50.py b/test/unit/gemm/device/simt_sm50.py index ba6ec3c291..f53dae2715 100644 --- a/test/unit/gemm/device/simt_sm50.py +++ b/test/unit/gemm/device/simt_sm50.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -123,7 +123,7 @@ # write file header out.write("/***************************************************************************************************\n" -" * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n" +" * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.\n" " *\n" " * Redistribution and use in source and binary forms, with or without modification, are permitted\n" " * provided that the following conditions are met:\n" diff --git a/test/unit/gemm/device/simt_zgemm_nn_sm50.cu b/test/unit/gemm/device/simt_zgemm_nn_sm50.cu index 7145b39535..7731559a81 100644 --- a/test/unit/gemm/device/simt_zgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_zgemm_nn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_zgemm_nt_sm50.cu b/test/unit/gemm/device/simt_zgemm_nt_sm50.cu index ffe8c0ddaf..17ea98203a 100644 --- a/test/unit/gemm/device/simt_zgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_zgemm_nt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_zgemm_tn_sm50.cu b/test/unit/gemm/device/simt_zgemm_tn_sm50.cu index 2d4799eb95..175c312868 100644 --- a/test/unit/gemm/device/simt_zgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_zgemm_tn_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/simt_zgemm_tt_sm50.cu b/test/unit/gemm/device/simt_zgemm_tt_sm50.cu index ba2447bcef..544e626c5a 100644 --- a/test/unit/gemm/device/simt_zgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_zgemm_tt_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,7 +67,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -97,7 +97,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -127,7 +127,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -157,7 +157,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -187,7 +187,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -217,7 +217,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -247,7 +247,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -277,7 +277,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -307,7 +307,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -337,7 +337,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -367,7 +367,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -397,7 +397,7 @@ CUTLASS_TEST_L0(SM50_device_zgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -427,7 +427,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -457,7 +457,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -487,7 +487,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -517,7 +517,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -547,7 +547,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -577,7 +577,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -607,7 +607,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -637,7 +637,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -667,7 +667,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -697,7 +697,7 @@ CUTLASS_TEST_L1(SM50_device_zgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -727,7 +727,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -757,7 +757,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -787,7 +787,7 @@ CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index 63d88e9bd4..b8c739a7e9 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include @@ -41,20 +43,7 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/gemm.h" -inline char const *to_string(cutlass::Status status) { - - switch (status) { - case cutlass::Status::kSuccess: return "kSuccess"; - case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; - case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; - case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; - case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; - case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; - case cutlass::Status::kErrorInternal: return "kErrorInternal"; - case cutlass::Status::kInvalid: return "kInvalid"; - } - return "invalid"; -} +#include "testbed_utils.h" namespace test { namespace gemm { @@ -185,9 +174,12 @@ struct Testbed { EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); @@ -341,18 +333,12 @@ bool TestAllGemm() { (cutlass::platform::is_same::value || cutlass::platform::is_same::value) ? 4 : kAlignment; - - int problem_size_m[] = { - kAlignmentM, 512 - 3*kAlignmentM - }; + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; int problem_size_k[] = { - kAlignmentK, Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK - }; + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; int split_k_slices[] = { 1, 2, 3 @@ -379,6 +365,10 @@ bool TestAllGemm() { continue; } + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + for (auto alpha : problem_alpha) { for (auto beta : problem_beta) { diff --git a/test/unit/gemm/device/testbed_complex.h b/test/unit/gemm/device/testbed_complex.h index e3372cddaf..65c0fdfb4c 100644 --- a/test/unit/gemm/device/testbed_complex.h +++ b/test/unit/gemm/device/testbed_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include @@ -90,6 +92,7 @@ struct TestbedComplex : public Testbed { this->tensor_B.host_ref(), Gemm::kTransformB, beta, + this->tensor_C.host_ref(), this->reference_D.host_ref(), ElementAccumulator(0) ); diff --git a/test/unit/gemm/device/testbed_interleaved.h b/test/unit/gemm/device/testbed_interleaved.h index 34d61383be..3cbd720bd4 100644 --- a/test/unit/gemm/device/testbed_interleaved.h +++ b/test/unit/gemm/device/testbed_interleaved.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/device/testbed_planar_complex.h b/test/unit/gemm/device/testbed_planar_complex.h new file mode 100644 index 0000000000..0e4e561e42 --- /dev/null +++ b/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TestbedPlanarComplex { +public: + + using ElementA = typename Gemm::ElementA; + using LayoutA = typename Gemm::LayoutA; + using ElementB = typename Gemm::ElementB; + using LayoutB = typename Gemm::LayoutB; + using ElementC = typename Gemm::ElementC; + using LayoutC = typename Gemm::LayoutC; + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::HostTensorPlanarComplex tensor_A; + cutlass::HostTensorPlanarComplex tensor_B; + cutlass::HostTensorPlanarComplex tensor_C; + cutlass::HostTensorPlanarComplex tensor_D; + cutlass::HostTensorPlanarComplex tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { + + tensor_A.reset({problem_size.m(), problem_size.k()}); + tensor_B.reset({problem_size.k(), problem_size.n()}); + tensor_C.reset({problem_size.m(), problem_size.n()}); + tensor_D.reset({problem_size.m(), problem_size.n()}); + tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); + } + + void initialize() { + + uint64_t seed = 1073; + + int scope_max = 8; + int scope_min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + int lda = tensor_A.layout().stride(0); + int ldb = tensor_B.layout().stride(0); + int ldc = tensor_C.layout().stride(0); + int ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 264, + }; + + int N[] = { + 16, 248, + }; + + int K[] = { + 8, 96, + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/test/unit/gemm/device/testbed_sanity.h b/test/unit/gemm/device/testbed_sanity.h new file mode 100644 index 0000000000..025fb3874d --- /dev/null +++ b/test/unit/gemm/device/testbed_sanity.h @@ -0,0 +1,233 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/core_io.h" + +#include "testbed.h" + + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// List of Gemm internal paramters this testbed supports user verification +// +enum class ParameterID { + + // Threadblock-level parameters + kSmemASize, + kSmemBSize, + + // Warp-level parameters + kWarpFragmentASize, + kWarpFragmentBSize, + kWarpFragmentCSize, + kInvalid +}; + +struct Reference { + ParameterID parameter_id; + + union { + int value; + + struct { + int m, n, k; + } gemm_shape; + + struct { + int row, column; + } matrix_shape; + }; + + std::string error_msg; + + Reference( + ParameterID parameter_id_, + int value_=-1, + std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} +}; + + +template +struct TestbedSanity { + + // + // Type definitions (All Gemm types top down) + // + + // Unpacking Gemm types in the following order + // Kernel-level > Threadblock-level > Warp-level > Instruction-level + + // kernel-level cutlass Gemm + using GemmKernel = typename Gemm::GemmKernel; + + // + // Threadblock-level gemm types + // + using MmaThreadBlock = typename GemmKernel::Mma; + + // Threadblock-level gemm shape covering one stage + using ThreadblockShape = typename MmaThreadBlock::Shape; + + // Shared memory size covering all stages + using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; + using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; + using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; + using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; + + + /// Number of stages + static int const kStages = MmaThreadBlock::Base::kStages; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; + + + // + // Warp-level gemm types + // + + // Warp-level gemm operator + using MmaWarp = typename MmaThreadBlock::Operator; + + // Warp-level gemm shape covering all kgroups + using WarpShape = typename MmaWarp::Shape; + + // Warp-level framents holding operands A & B operand and destination C + using WarpFragmentA = typename MmaWarp::FragmentA; + using WarpFragmentB = typename MmaWarp::FragmentB; + using WarpFragmentC = typename MmaWarp::FragmentC; + + // + // Instruction-level gemm types + // + + // Instruction-level gemm operator + using MmaInstruction = typename MmaWarp::Policy::Operator; + + // Instruction shape + using InstructionShape = typename MmaInstruction::Shape; + + // Instruction-level framents holding operands A & B operand and destination C + using InstructionFragmentA = typename MmaInstruction::FragmentA; + using InstructionFragmentB = typename MmaInstruction::FragmentB; + using InstructionFragmentC = typename MmaInstruction::FragmentC; + + // + // Testbed types + // + + // Vector of values holding user provided reference + using ReferenceVector = std::vector; + + // + // Data members + // + ReferenceVector references; + + // + // Methods + // + + TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } + + // verify all parameter in ReferenceVector + bool verify() { + for(auto ref : references) + verify_parameter(ref); + return true; + } + + // verify parameter of type Reference + void verify_parameter(Reference const& ref) { + switch(ref.parameter_id) { + case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Overload output operators for TesbedSanity +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { + + + out << "Gemm internal parameters" << std::endl + << " Threadblock-level parameters:" << std::endl + << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl + << " kStages = " << TestbedSanity::kStages << std::endl + << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl + <<" Shared memory sizes:" << std::endl + <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl + <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl + <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl + <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl + <<" Warp-level parameters" << std::endl + <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl + <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl + <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl + <<" Instruction-level parameters" << std::endl + <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl + <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl + <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; + + return out; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/testbed_splitk.h b/test/unit/gemm/device/testbed_splitk.h index 19f2d1fea5..792d73923a 100644 --- a/test/unit/gemm/device/testbed_splitk.h +++ b/test/unit/gemm/device/testbed_splitk.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h new file mode 100644 index 0000000000..a83c27cda6 --- /dev/null +++ b/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,480 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/testbed_utils.h b/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 0000000000..9325b40fe3 --- /dev/null +++ b/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,47 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + default: break; + } + return "invalid"; +} diff --git a/test/unit/gemm/thread/CMakeLists.txt b/test/unit/gemm/thread/CMakeLists.txt index 11d450c7d1..48ca115728 100644 --- a/test/unit/gemm/thread/CMakeLists.txt +++ b/test/unit/gemm/thread/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/gemm/thread/gemm_sm50.cu b/test/unit/gemm/thread/gemm_sm50.cu index 969580f587..4265922841 100644 --- a/test/unit/gemm/thread/gemm_sm50.cu +++ b/test/unit/gemm/thread/gemm_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/thread/gemm_sm60.cu b/test/unit/gemm/thread/gemm_sm60.cu index 19b8461923..b0b9fdb5b7 100644 --- a/test/unit/gemm/thread/gemm_sm60.cu +++ b/test/unit/gemm/thread/gemm_sm60.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/thread/gemm_sm61.cu b/test/unit/gemm/thread/gemm_sm61.cu index f8cbf2b81c..f6e7724dd8 100644 --- a/test/unit/gemm/thread/gemm_sm61.cu +++ b/test/unit/gemm/thread/gemm_sm61.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/thread/host/CMakeLists.txt b/test/unit/gemm/thread/host/CMakeLists.txt index 75f76c9285..c58540264d 100644 --- a/test/unit/gemm/thread/host/CMakeLists.txt +++ b/test/unit/gemm/thread/host/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/gemm/thread/host/gemm_sm60_host.cu b/test/unit/gemm/thread/host/gemm_sm60_host.cu index df2d233a6b..346b80cbe2 100644 --- a/test/unit/gemm/thread/host/gemm_sm60_host.cu +++ b/test/unit/gemm/thread/host/gemm_sm60_host.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/thread/host/testbed_host.h b/test/unit/gemm/thread/host/testbed_host.h index d2835efec5..4d5e441dd5 100644 --- a/test/unit/gemm/thread/host/testbed_host.h +++ b/test/unit/gemm/thread/host/testbed_host.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/thread/testbed.h b/test/unit/gemm/thread/testbed.h index 1b1082a5dc..bdfb8278f4 100644 --- a/test/unit/gemm/thread/testbed.h +++ b/test/unit/gemm/thread/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/CMakeLists.txt b/test/unit/gemm/threadblock/CMakeLists.txt index 7ec75510a6..f208b9ef17 100644 --- a/test/unit/gemm/threadblock/CMakeLists.txt +++ b/test/unit/gemm/threadblock/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/batched_gemv.cu b/test/unit/gemm/threadblock/batched_gemv.cu index 79b5ac4e53..94ae947bd2 100644 --- a/test/unit/gemm/threadblock/batched_gemv.cu +++ b/test/unit/gemm/threadblock/batched_gemv.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/epilogue_workspace.cu b/test/unit/gemm/threadblock/epilogue_workspace.cu index c1967e43f6..1301aeb4dd 100644 --- a/test/unit/gemm/threadblock/epilogue_workspace.cu +++ b/test/unit/gemm/threadblock/epilogue_workspace.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_pipelined_simt.cu b/test/unit/gemm/threadblock/mma_pipelined_simt.cu index b5c1a58b77..522b029adb 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_simt.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_simt.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_pipelined_sm70.cu b/test/unit/gemm/threadblock/mma_pipelined_sm70.cu index b9302ef332..c9c714bcf6 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_sm70.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_pipelined_sm75.cu b/test/unit/gemm/threadblock/mma_pipelined_sm75.cu index 5585f23f66..e4125eb4f0 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_sm75.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -231,6 +231,7 @@ TEST(SM75_gemm_threadblock_congruous, } //////////////////////////////////////////////////////////////////////////////// + TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x32_64x64x32_16x8x8) { using ElementA = cutlass::half_t; using LayoutA = cutlass::layout::RowMajor; @@ -562,6 +563,7 @@ TEST(SM75_gemm_threadblock_crosswise, } //////////////////////////////////////////////////////////////////////////////// + TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x32x64_16x16x64_8x8x16) { using ElementA = uint8_t; using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; @@ -1785,4 +1787,337 @@ TEST(SM75_gemm_threadblock_interleaved, } //////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_64x64x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x512_16x16x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(32, 32, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 512>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x512_32x16x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 32, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x512_16x32x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(32, 64, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_32x32x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x512_64x32x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore component + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x512_32x64x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x512_64x64x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 2048); + + using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, + multicta_256x256x1536_128x128x512_64x64x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 1536); + + using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_gemm_threadblock_crosswise, + multicta_512x256x6144_256x128x512_64x64x512_8x8x128) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 6144); + + using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; + + float alpha = 1.f; + float beta = 0.f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, + cutlass::arch::OpXorPopc>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif diff --git a/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/test/unit/gemm/threadblock/mma_pipelined_testbed.h index 498ca4967d..8190c50a41 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_testbed.h +++ b/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without *modification, are permitted provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu index 3c1720a1a4..4fb964c1ae 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu index e3d900d53e..fd2ae356fa 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 0000000000..148e34d959 --- /dev/null +++ b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,345 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, int ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = typename Mma::IteratorB; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for (int i = 0; i < matrix_A.capacity() * 2; ++i) { + matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + // TODO: Implement the rest + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu index ba54249d92..8c687f8810 100644 --- a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu +++ b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu index d1c0608342..262269b75d 100644 --- a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu +++ b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/CMakeLists.txt b/test/unit/gemm/warp/CMakeLists.txt index 96cfc29b9d..695508fa5a 100644 --- a/test/unit/gemm/warp/CMakeLists.txt +++ b/test/unit/gemm/warp/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -27,8 +27,10 @@ cutlass_test_unit_add_executable( gemm_sm61.cu gemm_sm70.cu gemm_sm75.cu + gemm_sm80.cu + gemm_complex_sm80.cu + gemm_gaussian_complex_sm80.cu wmma_sm70.cu wmma_sm72.cu wmma_sm75.cu - testbed.h ) diff --git a/test/unit/gemm/warp/gemm_complex_sm80.cu b/test/unit/gemm/warp/gemm_complex_sm80.cu new file mode 100644 index 0000000000..3fcd70c8d0 --- /dev/null +++ b/test/unit/gemm/warp/gemm_complex_sm80.cu @@ -0,0 +1,635 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for thread-level GEMM +*/ + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// complex * complex => complex +// Input data type: complex +// Math instruction: MMA.884.F64.F64 +// Output data type: complex +/////////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<8, 8, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x32x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x16x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nh) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_ct) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<8, 8, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// complex * complex => complex +// Input data type: complex +// Math instruction: MMA.1688.F32.TF32 +// Output data type: complex +// Shared memory layout: Congrous +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x32x8_16x8x8_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<16, 32, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x16x8_16x16x8_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 16, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 16, 8> >() + .run(); +} + + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nh) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_ct) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// complex * complex => complex +// Input data type: complex +// Math instruction: MMA.1688.F32.TF32 +// Output data type: complex +// Shared memory layout: Crosswise +//////////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() + .run(); +} + +// TEST FAILS crosswise complex TN MMA.1688.F32.TF32 test fails for k = 2*8 = 16 +TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_tn) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x64x8_16x8x8_tn) { + + using Shape = cutlass::gemm::GemmShape<32, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 64, 8> >() + .run(); +} + +TEST(SM80_warp_gemm_complex_tensor_op_f32, 64x32x8_16x8x8_tn) { + + using Shape = cutlass::gemm::GemmShape<64, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<64, 32, 8> >() + .run(); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu b/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu new file mode 100644 index 0000000000..43ad2dfd85 --- /dev/null +++ b/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for thread-level GEMM +*/ + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 8x8x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<8, 8, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x32x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x16x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nh) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_ct) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/warp/gemm_sm50.cu b/test/unit/gemm/warp/gemm_sm50.cu index f6410d1d43..bb4ba5be58 100644 --- a/test/unit/gemm/warp/gemm_sm50.cu +++ b/test/unit/gemm/warp/gemm_sm50.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/gemm_sm60.cu b/test/unit/gemm/warp/gemm_sm60.cu index cf59d442e6..4f2f3f1582 100644 --- a/test/unit/gemm/warp/gemm_sm60.cu +++ b/test/unit/gemm/warp/gemm_sm60.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/gemm_sm61.cu b/test/unit/gemm/warp/gemm_sm61.cu index 98a16046e5..63e07165b6 100644 --- a/test/unit/gemm/warp/gemm_sm61.cu +++ b/test/unit/gemm/warp/gemm_sm61.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/gemm_sm70.cu b/test/unit/gemm/warp/gemm_sm70.cu index d97effeabb..16f1427e55 100644 --- a/test/unit/gemm/warp/gemm_sm70.cu +++ b/test/unit/gemm/warp/gemm_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/gemm_sm75.cu b/test/unit/gemm/warp/gemm_sm75.cu index 7c32de4ac2..144475cae4 100644 --- a/test/unit/gemm/warp/gemm_sm75.cu +++ b/test/unit/gemm/warp/gemm_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -109,6 +109,8 @@ TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x8) { .run(); } +//////////////////////////////////////////////////////////////////////////////// + TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x8) { using Shape = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; @@ -317,6 +319,8 @@ TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x8) { .run(); } +//////////////////////////////////////////////////////////////////////////////// + TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_8x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; diff --git a/test/unit/gemm/warp/gemm_sm80.cu b/test/unit/gemm/warp/gemm_sm80.cu new file mode 100644 index 0000000000..377e760c6b --- /dev/null +++ b/test/unit/gemm/warp/gemm_sm80.cu @@ -0,0 +1,1782 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for thread-level GEMM +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x16x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 16, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_16x16x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<16, 16, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<16, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x64x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x32x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x32x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x16x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_16x16x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 16, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_16x16x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<16, 16, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_64x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_64x64x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_32x32x16_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x8) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_tn, tf32_round_128x128x32_16x16x32_16x8x8) { + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = float; + using ElementC = float; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +TEST(SM80_warp_gemm_tensor_op_nt, tf32_round_128x128x32_16x16x32_16x8x8) { + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = float; + using ElementC = float; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<16, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<16, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_16x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<16, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x64x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x16x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_16x16x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<16, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = int8_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x16x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_16x16x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<16, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x64x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x32x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x32x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x16x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 16, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_16x16x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<16, 16, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_16x8x256) { + using Shape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x32x512_16x8x256) { + using Shape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x32x512_16x8x256) { + using Shape = cutlass::gemm::GemmShape<32, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x16x512_16x8x256) { + using Shape = cutlass::gemm::GemmShape<32, 16, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_16x8x256) { + using Shape = cutlass::gemm::GemmShape<16, 16, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 512>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x64x1024_16x8x256) { + using Shape = cutlass::gemm::GemmShape<64, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x32x1024_16x8x256) { + using Shape = cutlass::gemm::GemmShape<64, 32, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x32x1024_16x8x256) { + using Shape = cutlass::gemm::GemmShape<32, 32, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x16x1024_16x8x256) { + using Shape = cutlass::gemm::GemmShape<32, 16, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_16x16x1024_16x8x256) { + using Shape = cutlass::gemm::GemmShape<16, 16, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + using Element = cutlass::uint1b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 1024>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_8x8x4) { + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 64, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_8x8x4) { + using Shape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_8x8x4) { + using Shape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_16x16x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<16, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x16x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x64x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = cutlass::int4b_t; + using ElementC = int; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index 47ab7bf0cb..8a565fd9fd 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -102,6 +102,7 @@ __global__ void kernel( FragmentA frag_A; FragmentB frag_B; + FragmentC accum; Mma mma; @@ -306,13 +307,22 @@ struct Testbed { if (!passed) { - cutlass::TensorView tensor_A_physical(tensor_A.host_data(), tensor_A.stride(), tensor_A.extent()); - cutlass::TensorView tensor_B_physical(tensor_B.host_data(), tensor_B.stride(), tensor_B.extent()); + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride(), + tensor_A.extent()); + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; @@ -459,6 +469,484 @@ struct TestbedComplex { tensor_B.host_ref(), Mma::kTransformB, ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride(), + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride(), + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), tensor_D_reference.host_ref() ); @@ -486,13 +974,15 @@ struct TestbedComplex { tensor_B.stride(), tensor_B.extent()); + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; std::cout << "C:\n" << tensor_C.host_view() << "\n\n" diff --git a/test/unit/gemm/warp/wmma_sm70.cu b/test/unit/gemm/warp/wmma_sm70.cu index d5e1107c1b..5b9ce63db1 100644 --- a/test/unit/gemm/warp/wmma_sm70.cu +++ b/test/unit/gemm/warp/wmma_sm70.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/wmma_sm72.cu b/test/unit/gemm/warp/wmma_sm72.cu index 4f81bbe265..89bfbb5945 100644 --- a/test/unit/gemm/warp/wmma_sm72.cu +++ b/test/unit/gemm/warp/wmma_sm72.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/gemm/warp/wmma_sm75.cu b/test/unit/gemm/warp/wmma_sm75.cu index a041610db2..3818793e84 100644 --- a/test/unit/gemm/warp/wmma_sm75.cu +++ b/test/unit/gemm/warp/wmma_sm75.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/layout/CMakeLists.txt b/test/unit/layout/CMakeLists.txt index ab34df0caa..29ebdbdd30 100644 --- a/test/unit/layout/CMakeLists.txt +++ b/test/unit/layout/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/layout/matrix.cu b/test/unit/layout/matrix.cu index 0adddb891f..2f8d0ea2be 100644 --- a/test/unit/layout/matrix.cu +++ b/test/unit/layout/matrix.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/layout/tensor.cu b/test/unit/layout/tensor.cu index a6b3f7cfff..b4a43fb3a9 100644 --- a/test/unit/layout/tensor.cu +++ b/test/unit/layout/tensor.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/layout/tensor_nhwc.cu b/test/unit/layout/tensor_nhwc.cu index 697f753daa..46482b2b2f 100644 --- a/test/unit/layout/tensor_nhwc.cu +++ b/test/unit/layout/tensor_nhwc.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/nvrtc/CMakeLists.txt b/test/unit/nvrtc/CMakeLists.txt index 7261da9688..668ea35ebe 100644 --- a/test/unit/nvrtc/CMakeLists.txt +++ b/test/unit/nvrtc/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/nvrtc/cutlass/nvrtc/environment.h b/test/unit/nvrtc/cutlass/nvrtc/environment.h index e3d493ab9b..27e999348c 100644 --- a/test/unit/nvrtc/cutlass/nvrtc/environment.h +++ b/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/test/unit/nvrtc/kernel/thread/testbed_kernel.h index c758235167..500870581d 100644 --- a/test/unit/nvrtc/kernel/thread/testbed_kernel.h +++ b/test/unit/nvrtc/kernel/thread/testbed_kernel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/nvrtc/stdlib/stdint.h b/test/unit/nvrtc/stdlib/stdint.h index d066380e7e..380216811b 100644 --- a/test/unit/nvrtc/stdlib/stdint.h +++ b/test/unit/nvrtc/stdlib/stdint.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,3 +33,91 @@ typedef int int32_t; typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; + +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. */ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. */ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/test/unit/nvrtc/thread/CMakeLists.txt b/test/unit/nvrtc/thread/CMakeLists.txt index f1d2b7a125..2e12ccfa8c 100644 --- a/test/unit/nvrtc/thread/CMakeLists.txt +++ b/test/unit/nvrtc/thread/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/nvrtc/thread/gemm_nvrtc.cu b/test/unit/nvrtc/thread/gemm_nvrtc.cu index bf57f1d3da..785ebcb2ce 100644 --- a/test/unit/nvrtc/thread/gemm_nvrtc.cu +++ b/test/unit/nvrtc/thread/gemm_nvrtc.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/nvrtc/thread/testbed.h b/test/unit/nvrtc/thread/testbed.h index 69bf81f473..41ba503ad5 100644 --- a/test/unit/nvrtc/thread/testbed.h +++ b/test/unit/nvrtc/thread/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/reduction/CMakeLists.txt b/test/unit/reduction/CMakeLists.txt index ba1b2a99e3..7b4f267069 100644 --- a/test/unit/reduction/CMakeLists.txt +++ b/test/unit/reduction/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/reduction/kernel/CMakeLists.txt b/test/unit/reduction/kernel/CMakeLists.txt index 9ef27c84ee..e1983153d1 100644 --- a/test/unit/reduction/kernel/CMakeLists.txt +++ b/test/unit/reduction/kernel/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/reduction/kernel/reduce_splitk.cu b/test/unit/reduction/kernel/reduce_splitk.cu index f4a7f07dba..b169cb60f1 100644 --- a/test/unit/reduction/kernel/reduce_splitk.cu +++ b/test/unit/reduction/kernel/reduce_splitk.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/reduction/kernel/reduce_splitk_testbed.h b/test/unit/reduction/kernel/reduce_splitk_testbed.h index c5cbbd58d0..8e70407063 100644 --- a/test/unit/reduction/kernel/reduce_splitk_testbed.h +++ b/test/unit/reduction/kernel/reduce_splitk_testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/reduction/thread/CMakeLists.txt b/test/unit/reduction/thread/CMakeLists.txt index f42276f76e..0641590e8c 100644 --- a/test/unit/reduction/thread/CMakeLists.txt +++ b/test/unit/reduction/thread/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/reduction/thread/reduction_thread.cu b/test/unit/reduction/thread/reduction_thread.cu index ece4934598..f71e30f53c 100644 --- a/test/unit/reduction/thread/reduction_thread.cu +++ b/test/unit/reduction/thread/reduction_thread.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/reduction/thread/testbed.h b/test/unit/reduction/thread/testbed.h index 3646e5bf07..919839b3d6 100644 --- a/test/unit/reduction/thread/testbed.h +++ b/test/unit/reduction/thread/testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/test_unit.cpp b/test/unit/test_unit.cpp index fc386250c8..3bb8ac1387 100644 --- a/test/unit/test_unit.cpp +++ b/test/unit/test_unit.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/transform/CMakeLists.txt b/test/unit/transform/CMakeLists.txt index ee865cd4aa..a7b881ae20 100644 --- a/test/unit/transform/CMakeLists.txt +++ b/test/unit/transform/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/transform/threadblock/CMakeLists.txt b/test/unit/transform/threadblock/CMakeLists.txt index e849dc8a4d..0d5e5c44a0 100644 --- a/test/unit/transform/threadblock/CMakeLists.txt +++ b/test/unit/transform/threadblock/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/test/unit/transform/threadblock/predicated_tile_iterator.cu b/test/unit/transform/threadblock/predicated_tile_iterator.cu index 70502f73a7..562c7888a2 100644 --- a/test/unit/transform/threadblock/predicated_tile_iterator.cu +++ b/test/unit/transform/threadblock/predicated_tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu b/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu index e032383ee6..e52af8edf9 100644 --- a/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu +++ b/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/test/unit/util/complex.cu b/test/unit/util/complex.cu index e4867e19e3..319bbb2aa4 100644 --- a/test/unit/util/complex.cu +++ b/test/unit/util/complex.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 0aa594b0cc..5c140a9a76 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index d32b6fd350..37bb89901e 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -52,13 +52,15 @@ install( # cutlass_add_library( - cutlass_lib - SHARED - src/library.cu + cutlass_library_objs + OBJECT + src/handle.cu src/manifest.cpp + src/operation_table.cu + src/singleton.cu + src/util.cu + ) -add_library(nvidia::cutlass::library ALIAS cutlass_lib) -set_target_properties(cutlass_lib PROPERTIES EXPORT_NAME library) file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/*.py) @@ -66,16 +68,19 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU # auto-instantiation of CUTLASS kernels # +# set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit. +set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) + execute_process( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generator.py - --operations all + --operations "${CUTLASS_LIBRARY_OPERATIONS}" --build-dir ${PROJECT_BINARY_DIR} --curr-build-dir ${CMAKE_CURRENT_BINARY_DIR} --generator-target library --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" - --cuda-version "${CMAKE_CUDA_COMPILER_VERSION}" + --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log @@ -95,35 +100,70 @@ else() endif() target_include_directories( - cutlass_lib + cutlass_library_objs PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_BINARY_DIR}/include ) -set_target_properties( - cutlass_lib - PROPERTIES - OUTPUT_NAME cutlass - WINDOWS_EXPORT_ALL_SYMBOLS 1 - ) - target_link_libraries( - cutlass_lib + cutlass_library_objs PUBLIC - cutlass_library_includes + cutlass_library_includes ) +function(cutlass_add_cutlass_library) + + set(options) + set(oneValueArgs NAME TYPE EXPORT_NAME) + set(multiValueArgs) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + cutlass_add_library( + ${__NAME} + ${__TYPE} + EXPORT_NAME ${__EXPORT_NAME} + $ + ) + + target_link_libraries( + ${__NAME} + PUBLIC + cutlass_library_includes + ) + + set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX}) + + set(OUTPUT_NAME cutlass) + + if (WIN32 AND ${__TYPE} STREQUAL "STATIC") + set(OUTPUT_NAME "${OUTPUT_NAME}.static") + endif() + + set_target_properties( + ${__NAME} + PROPERTIES + OUTPUT_NAME ${OUTPUT_NAME} + WINDOWS_EXPORT_ALL_SYMBOLS 1 + ) + +endfunction() + +cutlass_add_cutlass_library(NAME cutlass_lib TYPE SHARED EXPORT_NAME library) +cutlass_add_cutlass_library(NAME cutlass_library_static TYPE STATIC EXPORT_NAME library_static) + install( - DIRECTORY - ${CMAKE_CURRENT_SOURCE_DIR}/include/ - DESTINATION - ${CMAKE_INSTALL_INCLUDEDIR} + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) install( - TARGETS cutlass_lib cutlass_library_includes + TARGETS + cutlass_lib + cutlass_library_static + cutlass_library_includes EXPORT NvidiaCutlass RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ) diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h new file mode 100644 index 0000000000..58c6b30c7c --- /dev/null +++ b/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// Provider of operations + Provider provider_; + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the current provider + Provider get_provider() const; + + /// Sets the provider of operations + void set_provider(Provider provider); + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int ldd /// Leading dimension of D matrix + ); + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. + // + // Supports batched-strided, batched array or split-K serial or split-K parallel. + // + Status gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int ldd, /// Leading dimension of D matrix + + int batch_count = 1, /// Batch count or number of split-K slices + + int64_t batch_stride_A = 0, /// Batch stride of A operand + int64_t batch_stride_B = 0, /// Batch stride of B operand + int64_t batch_stride_C = 0, /// Batch stride of C operand + int64_t batch_stride_D = 0 /// Batch stride of D operand + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. + /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices + + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices + + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 787da7cb25..d093b6118c 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -44,6 +44,7 @@ #include #include #include +#include #include #include "cutlass/cutlass.h" @@ -68,6 +69,10 @@ enum class LayoutTypeID { kRowMajorInterleavedK4, kColumnMajorInterleavedK16, kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, kTensorNCHW, kTensorNHWC, kInvalid @@ -89,10 +94,14 @@ enum class NumericTypeID { kS32, kS64, kF16, + kBF16, + kTF32, kF32, kF64, kCF16, + kCBF16, kCF32, + kCTF32, kCF64, kCS4, kCS8, @@ -110,12 +119,27 @@ enum class NumericTypeID { /// Enumeraed type describing a transformation on a complex value. enum class ComplexTransform { kNone, - kConjugate + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kInvalid }; +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Enumeration indicating the kind of operation enum class OperationKind { kGemm, + kEqGemm, + kReduction, kInvalid }; @@ -143,6 +167,16 @@ enum class OpcodeClassID { kInvalid }; +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddComplex, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Enumeration indicating what kind of GEMM operation to perform @@ -150,88 +184,25 @@ enum class GemmKind { kGemm, kBatched, kArray, + kUniversal, kPlanarComplex, - kPlanarComplexBatched, + kPlanarComplexArray, kInvalid }; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Lexical cast from string -template T from_string(std::string const &); - -/// Converts a NumericType enumerant to a string -char const *to_string(OperationKind type, bool pretty = false); - -/// Parses a NumericType enumerant from a string -template <> OperationKind from_string(std::string const &str); - -/// Converts a NumericType enumerant to a string -char const *to_string(NumericTypeID type, bool pretty = false); - -/// Parses a NumericType enumerant from a string -template <> NumericTypeID from_string(std::string const &str); - -/// Returns the size of a data type in bits -int sizeof_bits(NumericTypeID type); - -/// Returns true if the numeric type is a complex data type or false if real-valued. -bool is_complex_type(NumericTypeID type); - -/// Returns the real-valued type underlying a type (only different from 'type' if complex) -NumericTypeID get_real_type(NumericTypeID type); - -/// Returns true if numeric type is integer -bool is_integer_type(NumericTypeID type); - -/// Returns true if numeric type is signed -bool is_signed_type(NumericTypeID type); - -/// Returns true if numeric type is a signed integer -bool is_signed_integer(NumericTypeID type); - -/// returns true if numeric type is an unsigned integer -bool is_unsigned_integer(NumericTypeID type); - -/// Returns true if numeric type is floating-point type -bool is_float_type(NumericTypeID type); - -/// To string method for cutlass::Status -char const *to_string(Status status, bool pretty = false); - -/// Converts a LayoutTypeID enumerant to a string -char const *to_string(LayoutTypeID layout, bool pretty = false); - -/// Parses a LayoutType enumerant from a string -template <> LayoutTypeID from_string(std::string const &str); - -/// Returns the rank of a layout's stride base on the LayoutTypeID -int get_layout_stride_rank(LayoutTypeID layout_id); - -/// Converts a OpcodeClassID enumerant to a string -char const *to_string(OpcodeClassID type, bool pretty = false); - -/// Converts a OpcodeClassID enumerant from a string -template <> -OpcodeClassID from_string(std::string const &str); - -/// Lexical cast from int64_t to string -std::string lexical_cast(int64_t int_value); - -/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); +/// Mode of Universal GEMM +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; -/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. -std::string lexical_cast(std::vector &bytes, NumericTypeID type); - -/// Casts from a signed int64 to the destination type. Returns true if successful. -bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); - -/// Casts from an unsigned int64 to the destination type. Returns true if successful. -bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); - -/// Casts from a real value represented as a double to the destination type. Returns true if successful. -bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -246,6 +217,9 @@ struct MathInstructionDescription { /// Classification of math instruction OpcodeClassID opcode_class; + /// Type of math operation performed + MathOperationID math_operation; + // // Methods // @@ -253,9 +227,29 @@ struct MathInstructionDescription { MathInstructionDescription( cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator = NumericTypeID::kInvalid, - OpcodeClassID opcode_class = OpcodeClassID::kInvalid + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd ): - instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {} + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } }; @@ -298,6 +292,24 @@ struct TileDescription { math_instruction(math_instruction), minimum_compute_capability(minimum_compute_capability), maximum_compute_capability(maximum_compute_capability) { } + + // Equality operator + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } }; /// High-level description of an operation @@ -306,6 +318,9 @@ struct OperationDescription { /// Unique identifier describing the operation char const * name; + /// Operation provider + Provider provider; + /// Kind of operation OperationKind kind; @@ -317,6 +332,7 @@ struct OperationDescription { // OperationDescription( char const * name = "unknown", + Provider Provider = Provider::kInvalid, OperationKind kind = OperationKind::kInvalid, TileDescription const & tile_description = TileDescription() ): @@ -340,10 +356,11 @@ struct TensorDescription { /// log2() of the maximum value each relevant stride may have int log_stride_range; - + // // Methods // + TensorDescription( NumericTypeID element = NumericTypeID::kInvalid, LayoutTypeID layout = LayoutTypeID::kInvalid, @@ -355,7 +372,7 @@ struct TensorDescription { layout(layout), alignment(alignment), log_extent_range(log_extent_range), - log_stride_range(log_stride_range) { } + log_stride_range(log_stride_range) { } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -411,10 +428,24 @@ struct GemmDescription : public OperationDescription { transform_B(transform_B) {} }; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Base class for all device-wide operations +/// Base class for all operations class Operation { public: @@ -435,7 +466,7 @@ class Operation { virtual Status initialize( void const *configuration, void *host_workspace, - void *device_workspace, + void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; virtual Status run( @@ -443,6 +474,7 @@ class Operation { void *host_workspace, void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -551,11 +583,18 @@ using GemmBatchedArguments = GemmArguments; struct GemmArrayConfiguration { gemm::GemmCoord problem_size; + + /// Leading dimension of A matrix + int64_t lda; + + /// Leading dimension of B matrix + int64_t ldb; + + /// Leading dimension of C matrix + int64_t ldc; - int64_t const *lda; - int64_t const *ldb; - int64_t const *ldc; - int64_t const *ldd; + /// Leading dimension of D matrix + int64_t ldd; int batch_count; }; @@ -573,56 +612,140 @@ struct GemmArrayArguments { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Complex valued GEMM in which real and imaginary parts are separated by a stride +/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex // // OperationKind: Gemm -// GemmKind: Planar complex +// GemmKind: Universal -struct GemmPlanarComplexConfiguration { +struct GemmUniversalConfiguration { + GemmUniversalMode mode; gemm::GemmCoord problem_size; + int batch_count; int64_t lda; int64_t ldb; int64_t ldc; int64_t ldd; - - int64_t imag_stride_A; - int64_t imag_stride_B; - int64_t imag_stride_C; - int64_t imag_stride_D; }; -using GemmPlanarComplexArgments = GemmArguments; +struct GemmUniversalArguments { + + void const *A; + void const *B; + void const *C; + void *D; + + void const *alpha; + void const *beta; + ScalarPointerMode pointer_mode; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Batched complex valued GEMM in which real and imaginary parts are separated by a stride +/// Complex valued GEMM in which real and imaginary parts are separated by a stride // // OperationKind: Gemm -// GemmKind: Planar complex batched -// -struct GemmPlanarComplexBatchedConfiguration { +// GemmKind: Planar complex + +struct GemmPlanarComplexConfiguration { + GemmUniversalMode mode; gemm::GemmCoord problem_size; + int batch_count; - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t ldd; + int64_t lda_real; + int64_t lda_imag; - int64_t imag_stride_A; - int64_t imag_stride_B; - int64_t imag_stride_C; - int64_t imag_stride_D; + int64_t ldb_real; + int64_t ldb_imag; - int64_t batched_stride_A; - int64_t batched_stride_B; - int64_t batched_stride_C; - int64_t batched_stride_D; + int64_t ldc_real; + int64_t ldc_imag; + + int64_t ldd_real; + int64_t ldd_imag; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real; + void const *A_imag; + + void const *B_real; + void const *B_imag; + + void const *C_real; + void const *C_imag; + + void *D_real; + void *D_imag; + + void const *alpha; + void const *beta; + ScalarPointerMode pointer_mode; + + int64_t batch_stride_A_real; + int64_t batch_stride_A_imag; + + int64_t batch_stride_B_real; + int64_t batch_stride_B_imag; + + int64_t batch_stride_C_real; + int64_t batch_stride_C_imag; + + int64_t batch_stride_D_real; + int64_t batch_stride_D_imag; }; -using GemmPlanarComplexBatchedArguments = GemmArguments; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. +struct GemmPlanarComplexArrayConfiguration { + + gemm::GemmCoord problem_size; + int batch_count; + + int64_t lda_real; + int64_t lda_imag; + + int64_t ldb_real; + int64_t ldb_imag; + + int64_t ldc_real; + int64_t ldc_imag; + + int64_t ldd_real; + int64_t ldd_imag; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M; + int const *N; + int const *K; + + void const * const * A_real; + void const * const * A_imag; + void const * const * B_real; + void const * const * B_imag; + void const * const * C_real; + void const * const * C_imag; + void * const * D_real; + void * const * D_imag; + + void const * alpha; + void const * beta; + ScalarPointerMode pointer_mode; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/manifest.h b/tools/library/include/cutlass/library/manifest.h index 7746fc3dac..54e51c1fd0 100644 --- a/tools/library/include/cutlass/library/manifest.h +++ b/tools/library/include/cutlass/library/manifest.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -45,6 +45,13 @@ namespace cutlass { namespace library { /////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declaration +class Manifest; + +// init and insert all cutlass gemm and conv2d op in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); + +///////////////////////////////////////////////////////////////////////////////////////////////////////// /// List of operations using OperationVector = std::vector>; @@ -55,10 +62,14 @@ using OperationVector = std::vector>; class Manifest { private: + /// Operation provider + Provider provider_; + /// Global list of operations OperationVector operations_; public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } /// Top-level initialization Status initialize(); diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 0000000000..3821f65acb --- /dev/null +++ b/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. +*/ + +#pragma once +#include +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct GemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + + // + // Methods + // + + inline + GemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = NumericTypeID::kF16 + ): + provider(provider), + gemm_kind(gemm_kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8 - shl)); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.element_compute)), 3) ^ + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm + // provider (kCUTLASS) + GemmOperationFunctionalMap gemm_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/tools/library/include/cutlass/library/singleton.h b/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 0000000000..591ad78f48 --- /dev/null +++ b/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/include/cutlass/library/util.h b/tools/library/include/cutlass/library/util.h new file mode 100644 index 0000000000..526f836b2b --- /dev/null +++ b/tools/library/include/cutlass/library/util.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a GemmKind enumerant to a string +char const *to_string(GemmKind type, bool pretty = false); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. +bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutType enumerant from a string +template <> LayoutTypeID from_string(std::string const &str); + +/// Returns the rank of a layout's stride base on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts a OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str); + + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode split_k_mode, bool pretty = false); + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. +bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 616587026f..66ecc05e69 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -22,7 +22,9 @@ # class GemmOperation: # - def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue): + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + self.operation_kind = OperationKind.Gemm self.arch = arch self.tile_description = tile_description @@ -31,29 +33,78 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue) self.B = B self.C = C self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + # def core_name(self): ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + } + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - else: - inst_shape = '' + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - return "%s%s%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], inst_shape, GemmKindNames[self.gemm_kind]) + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) # def extended_name(self): ''' Append data types if they differ from compute type. ''' - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: + if self.is_complex(): extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" extended_name = SubstituteTemplate(extended_name, { 'element_a': DataTypeNames[self.A.element], @@ -63,28 +114,32 @@ def extended_name(self): return extended_name + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + # def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - if self.tile_description.stages > 2: - threadblock = "%dx%d_%dx%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - self.tile_description.stages - ) - else: - threadblock = "%dx%d" % (self.tile_description.threadblock_shape[0], self.tile_description.threadblock_shape[1]) + threadblock = self.tile_description.procedural_name() opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) + return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}", + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", { 'opcode_class': opcode_class_name, 'extended_name': self.extended_name(), 'threadblock': threadblock, - 'layout': "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]), + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, } ) @@ -104,7 +159,7 @@ class EmitGemmInstance: ''' Responsible for emitting a CUTLASS template definition''' def __init__(self): - self.template = """ + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::Gemm< ${element_a}, ${layout_a}, @@ -116,14 +171,45 @@ def __init__(self): cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombination< + ${epilogue_functor}< ${element_c}, ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue} >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, - ${stages} + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} >; """ @@ -135,6 +221,8 @@ def emit(self, operation): epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + residual = '' + values = { 'operation_name': operation.procedural_name(), 'element_a': DataTypeTag[operation.A.element], @@ -143,7 +231,7 @@ def emit(self, operation): 'layout_b': LayoutTag[operation.B.layout], 'element_c': DataTypeTag[operation.C.element], 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], 'arch': "cutlass::arch::Sm%d" % operation.arch, 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), @@ -157,23 +245,34 @@ def emit(self, operation): 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages) + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual } - return SubstituteTemplate(self.template, values) + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) ################################################################################################### # -class EmitGemmBatchedInstance: +class EmitGemmUniversalInstance: ''' Responsible for emitting a CUTLASS template definition''' def __init__(self): - self.template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::GemmBatched< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand ${element_c}, ${layout_c}, ${element_accumulator}, ${opcode_class}, @@ -181,36 +280,90 @@ def __init__(self): cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombination< + ${epilogue_functor}< ${element_c}, ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue} >, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + ${swizzling_functor}, ${stages}, - ${align_a}, - ${align_b} - >; + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; +""" + self.gemm_template_interleaved = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; """ def emit(self, operation): - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - #warp_shape[2] = operation.tile_description.math_instruction.instruction_shape[2] - warp_shape[2] = operation.tile_description.threadblock_shape[2] + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + warp_shape[2] = operation.tile_description.threadblock_shape[2] epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + if operation.A.layout in transpose_layouts.keys() and \ + operation.B.layout in transpose_layouts.keys() and \ + operation.C.layout in transpose_layouts.keys(): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + gemm_template = self.gemm_template_interleaved + # + values = { 'operation_name': operation.procedural_name(), 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], + 'layout_a': LayoutTag[instance_layout_A], 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], + 'layout_b': LayoutTag[instance_layout_B], 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], 'arch': "cutlass::arch::Sm%d" % operation.arch, 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), @@ -224,137 +377,167 @@ def emit(self, operation): 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], 'stages': str(operation.tile_description.stages), 'align_a': str(operation.A.alignment), 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] } - return SubstituteTemplate(self.template, values) + return SubstituteTemplate(gemm_template, values) ################################################################################################### -# -# Generator functions for all layouts -# -################################################################################################### - -# -def GenerateGemmSimt(gemm_kind, manifest, tile_descriptions, min_cc): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - # for each tile configuration, emit a GEMM - for tile in tile_descriptions: - for layout in layouts: - - A = TensorDescription(tile.math_instruction.element_a, layout[0], 1) - B = TensorDescription(tile.math_instruction.element_b, layout[1], 1) - C = TensorDescription(tile.math_instruction.element_accumulator, layout[2], 1) - - manifest.append(GemmOperation(gemm_kind, 50, tile, A, B, C, tile.math_instruction.element_accumulator)) # -def GenerateGemmTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]): - - # Canonical matrix layouts - canonical_layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - # Interleaved matrix layouts - interleaved_layouts = { - 8: [ - #(LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ], - 4: [ - #(LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - } - - # for each tile configuration, emit a GEMM - for align in minimum_alignment: - for tile in tile_descriptions: - - min_input_size = min(DataTypeSize[tile.math_instruction.element_a], DataTypeSize[tile.math_instruction.element_a]) - - # If the data type is large enough, use canonical layouts. - if min_input_size >= 16: - layouts = canonical_layouts - else: - layouts = interleaved_layouts[min_input_size] - - for layout in layouts: - - # - output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] +class EmitGemmPlanarComplexInstance: + ''' Responsible for emitting a CUTLASS template definition''' - align_a = align // DataTypeSize[tile.math_instruction.element_a] - align_b = align // DataTypeSize[tile.math_instruction.element_b] + def __init__(self): + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmKernel; + struct ${operation_name} : + public Operation_${operation_name} { }; +""" - for output_type in output_types: + def emit(self, operation): - rows_per_warp = 8 // tile.warp_count[1] - align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32) + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a) - B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b) - C = TensorDescription(output_type, layout[2], max(1, align_c)) + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] - element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \ - else tile.math_instruction.element_accumulator + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } - manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue)) + return SubstituteTemplate(self.template, values) +################################################################################################### # -def GenerateGemmWmmaTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]): - - # Wmma supported matrix layouts - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' - # for each tile configuration, emit a GEMM - for align in minimum_alignment: - for tile in tile_descriptions: - for layout in layouts: + def __init__(self): + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmArrayKernel; - # - output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] + struct ${operation_name} : public Operation_${operation_name} { }; +""" - align_a = align // DataTypeSize[tile.math_instruction.element_a] - align_b = align // DataTypeSize[tile.math_instruction.element_b] + def emit(self, operation): + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - for output_type in output_types: + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] - rows_per_warp = 8 // tile.warp_count[1] - align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32) + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } - A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a) - B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b) - C = TensorDescription(output_type, layout[2], max(1, align_c)) + return SubstituteTemplate(self.template, values) - element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \ - else tile.math_instruction.element_accumulator +################################################################################################### - manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue)) ################################################################################################### # @@ -369,21 +552,49 @@ def __init__(self, operation_path, configuration_name): self.instance_emitter = { GemmKind.Gemm: EmitGemmInstance, - GemmKind.Batched: EmitGemmBatchedInstance + GemmKind.Universal: EmitGemmUniversalInstance, + GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, + GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance } self.gemm_kind_wrappers = { GemmKind.Gemm: 'GemmOperation', - GemmKind.Batched: 'GemmBatchedOperation', + GemmKind.Universal: 'GemmUniversalOperation', + GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', + GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation' } self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" - self.instance_template = """ + self.instance_template = { + GemmKind.Gemm: """ ${compile_guard_start} manifest.append(new ${gemm_kind}("${operation_name}")); ${compile_guard_end} +""", + GemmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""", + GemmKind.PlanarComplex: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""", + GemmKind.PlanarComplexArray: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} """ + } + self.header_template = """ /* Generated by gemm_operation.py - Do not edit. @@ -398,6 +609,14 @@ def __init__(self, operation_path, configuration_name): #include "library_internal.h" #include "gemm_operation.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass { namespace library { @@ -421,9 +640,11 @@ def __init__(self, operation_path, configuration_name): def __enter__(self): self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(SubstituteTemplate(self.header_template, { - 'configuration_name': self.configuration_name - })) + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + self.operations = [] return self @@ -431,8 +652,10 @@ def emit(self, operation): emitter = self.instance_emitter[operation.gemm_kind]() self.operations.append(operation) - self.configuration_file.write(emitter.emit(operation)) - self.configuration_file.write(SubstituteTemplate(self.instance_template, { + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], { 'configuration_name': self.configuration_name, 'operation_name': operation.procedural_name(), 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], @@ -443,8 +666,22 @@ def emit(self, operation): })) def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + self.configuration_file.write(self.epilogue_template) self.configuration_file.close() ################################################################################################### ################################################################################################### + diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 6c09f180d7..2957864568 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -16,208 +16,1759 @@ # def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): - if semantic_ver_string == '': - cuda_version = [10, 2, 0] - else: - cuda_version = [int(x) for x in semantic_ver_string.split('.')] - return cuda_version >= [major, minor, patch] + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +################################################################################################### ################################################################################################### # -def GenerateSM50(manifest, args): +def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] + + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for gemm_kind in gemm_kinds: + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + manifest.append(GemmOperation(gemm_kind, \ + tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue)) + return + +################################################################################################### +################################################################################################### +################################################################################################### + +# +def GenerateSM50_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 50 max_cc = 1024 - stages = 2 - # single-precision - inst = MathInstruction([1, 1, 1], DataType.f32, DataType.f32, DataType.f32, OpcodeClass.Simt) - tile_descriptions = [ - TileDescription([128, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, 8], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([256, 128, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 4, 1], inst, min_cc, max_cc), - ] + alignment_constraints = [1,] - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmSimt(GemmKind.Batched, manifest, tile_descriptions, min_cc) + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] - # double precision - inst = MathInstruction([1, 1, 1], DataType.f64, DataType.f64, DataType.f64, OpcodeClass.Simt) - tile_descriptions = [ - TileDescription([128, 128, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 4, 1], inst, min_cc, max_cc), - ] + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmSimt(GemmKind.Batched, manifest, tile_descriptions, min_cc) +# +def GenerateSM50(manifest, args): + GenerateSM50_Simt(manifest, args) ################################################################################################### +################################################################################################### # -def GenerateSM60(manifest, args): +def GenerateSM60_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 60 max_cc = 1024 - stages = 2 - math_instructions = [ - MathInstruction([1, 1, 1], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.Simt), - ] + alignment_constraints = [1,] - tile_descriptions = [] + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] - for inst in math_instructions: - tile_descriptions += [ - TileDescription([256, 256, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [2, 1, 1], inst, min_cc, max_cc), + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) +# +def GenerateSM60(manifest, args): + GenerateSM60_Simt(manifest, args) ################################################################################################### +################################################################################################### # -def GenerateSM61(manifest, args): +def GenerateSM61_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 4], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 61 max_cc = 1024 - stages = 2 - math_instructions = [ - MathInstruction([1, 1, 4], DataType.s8, DataType.s8, DataType.s32, OpcodeClass.Simt), - ] + alignment_constraints = [1,] - tile_descriptions = [] + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] - for inst in math_instructions: - tile_descriptions += [ - TileDescription([128, 256, 32], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([256, 128, 32], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 128, 32], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 32], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 32], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 32], stages, [1, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 32], stages, [2, 1, 1], inst, min_cc, max_cc), + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) +# +def GenerateSM61(manifest, args): + GenerateSM61_Simt(manifest, args) ################################################################################################### +################################################################################################### # -def GenerateSM70(manifest, args): +def GenerateSM70_TensorOp_884(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] min_cc = 70 max_cc = 75 - stages = 2 - k_groups = 8 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +def GenerateSM70_PlanarComplexTensorOp_884(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] math_instructions = [ - MathInstruction([8, 8, 4], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.TensorOp), - MathInstruction([8, 8, 4], DataType.f16, DataType.f16, DataType.f32, OpcodeClass.TensorOp), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), ] - tile_descriptions = [] + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 2, 1] - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([256, 128, kblock], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([128, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - if CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): - GenerateGemmTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmTensorOp(GemmKind.Batched, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] - # wmma tensor op SM70 Gemm kernels - stages = 2 - k_groups = 2 + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + + +# +def GenerateSM70_WmmaTensorOp_161616(manifest, args): + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] math_instructions = [ - MathInstruction([16, 16, 16], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.WmmaTensorOp), - MathInstruction([16, 16, 16], DataType.f16, DataType.f16, DataType.f32, OpcodeClass.WmmaTensorOp), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), ] - tile_descriptions = [] + min_cc = 70 + max_cc = 1024 - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([128, 128, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + alignment_constraints = [8,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - GenerateGemmWmmaTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +def GenerateSM70(manifest, args): + GenerateSM70_TensorOp_884(manifest, args) + GenerateSM70_PlanarComplexTensorOp_884(manifest, args) + + # To limit build size, WMMA GEMMs are disabled for now. + # + #GenerateSM70_WmmaTensorOp_161616(manifest, args) ################################################################################################### +################################################################################################### # -def GenerateSM75(manifest, args): +def GenerateSM75_TensorOp_1688(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM75_PlanarComplexTensorOp_1688(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM75_TensorOp_8816_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s8, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8816_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s8, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8832_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] min_cc = 75 max_cc = 1024 - stages = 2 - k_groups = 4 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s4, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 8 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 4 + +# + +# +def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] math_instructions = [ - MathInstruction([16, 8, 8], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.TensorOp), - MathInstruction([16, 8, 8], DataType.f16, DataType.f16, DataType.f32, OpcodeClass.TensorOp), - MathInstruction([8, 8, 16], DataType.s8, DataType.s8, DataType.s32, OpcodeClass.TensorOp), - MathInstruction([8, 8, 32], DataType.s4, DataType.s4, DataType.s32, OpcodeClass.TensorOp) + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), ] - tile_descriptions = [] + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s4, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 + +# + +# +def GenerateSM75_WmmaTensorOp_161616(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([256, 128, kblock], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([128, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - if CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): - GenerateGemmTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmTensorOp(GemmKind.Batched, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM75(manifest, args): + GenerateSM75_TensorOp_1688(manifest, args) + GenerateSM75_PlanarComplexTensorOp_1688(manifest, args) + GenerateSM75_TensorOp_8816_TN(manifest, args) + GenerateSM75_TensorOp_8816_Interleaved(manifest, args) + GenerateSM75_TensorOp_8832_TN(manifest, args) + GenerateSM75_TensorOp_8832_Interleaved(manifest, args) + #GenerateSM75_WmmaTensorOp_161616(manifest, args) ################################################################################################### ################################################################################################### +# +def GenerateSM80_TensorOp_16816(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, 4, 2] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM80_PlanarComplexTensorOp_16816(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_16832_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 128], 3, [1, 4, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, DataType.s32, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s8, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM80_TensorOp_16832_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s8, DataType.f32] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 + +# + +# +def GenerateSM80_TensorOp_16864_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, DataType.s32, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s4, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 8 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 4 +# + +# +def GenerateSM80_TensorOp_16864_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s4, DataType.f32] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM80_TensorOp_168256(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + +# + +# +def GenerateSM80_TensorOp_1688(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [1, 4, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM80_TensorOp_1688_fast_math(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f16), + MathInstruction( \ + [16, 8, 8], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_bf16) + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [1, 4, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM80_TensorOp_1688_complex(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_complex(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_884_complex_gaussian(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +################################################################################################### + +# +def GenerateSM80_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 5, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +################################################################################################### + +# +def GenerateSM80(manifest, args): + + GenerateSM80_TensorOp_16816(manifest, args) + GenerateSM80_PlanarComplexTensorOp_16816(manifest, args) + GenerateSM80_TensorOp_1688(manifest, args) + GenerateSM80_TensorOp_1688_fast_math(manifest, args) + GenerateSM80_TensorOp_1688_complex(manifest, args) + GenerateSM80_TensorOp_884(manifest, args) + GenerateSM80_TensorOp_884_complex(manifest, args) + GenerateSM80_TensorOp_884_complex_gaussian(manifest, args) + GenerateSM80_TensorOp_16832_TN(manifest, args) + GenerateSM80_TensorOp_16832_Interleaved(manifest, args) + GenerateSM80_TensorOp_16864_TN(manifest, args) + GenerateSM80_TensorOp_16864_Interleaved(manifest, args) + GenerateSM80_TensorOp_168256(manifest, args) + GenerateSM80_Simt(manifest, args) +# + +################################################################################################### + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") - parser.add_argument("--operations", default="gemm", help="Specifies the operation to generate (gemm, all)") + parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='50 60 61 75', help="Target compute architectures") + parser.add_argument("--architectures", default='53;60;61;70;75;80', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.') - parser.add_argument("--cuda-version", default="10.2.0", help="Semantic version string of CUDA Toolkit") - + parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") args = parser.parse_args() @@ -228,6 +1779,8 @@ def GenerateSM75(manifest, args): GenerateSM61(manifest, args) GenerateSM70(manifest, args) GenerateSM75(manifest, args) + GenerateSM80(manifest, args) + if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index b9ceeb4ff5..bdc4348308 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -4,14 +4,32 @@ # \brief Generates the CUTLASS Library's instances # -import enum import re ################################################################################################### +import enum + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. +# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + +################################################################################################### + # class GeneratorTarget(enum.Enum): - Library = enum.auto() + Library = enum_auto() # GeneratorTargetNames = { GeneratorTarget.Library: 'library' @@ -22,33 +40,37 @@ class GeneratorTarget(enum.Enum): # class DataType(enum.Enum): - b1 = enum.auto() - u4 = enum.auto() - u8 = enum.auto() - u16 = enum.auto() - u32 = enum.auto() - u64 = enum.auto() - s4 = enum.auto() - s8 = enum.auto() - s16 = enum.auto() - s32 = enum.auto() - s64 = enum.auto() - f16 = enum.auto() - f32 = enum.auto() - f64 = enum.auto() - cf16 = enum.auto() - cf32 = enum.auto() - cf64 = enum.auto() - cs4 = enum.auto() - cs8 = enum.auto() - cs16 = enum.auto() - cs32 = enum.auto() - cs64 = enum.auto() - cu4 = enum.auto() - cu8 = enum.auto() - cu16 = enum.auto() - cu32 = enum.auto() - cu64 = enum.auto() + b1 = enum_auto() + u4 = enum_auto() + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s4 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + tf32 = enum_auto() + f64 = enum_auto() + cf16 = enum_auto() + cbf16 = enum_auto() + cf32 = enum_auto() + ctf32 = enum_auto() + cf64 = enum_auto() + cs4 = enum_auto() + cs8 = enum_auto() + cs16 = enum_auto() + cs32 = enum_auto() + cs64 = enum_auto() + cu4 = enum_auto() + cu8 = enum_auto() + cu16 = enum_auto() + cu32 = enum_auto() + cu64 = enum_auto() # ShortDataTypeNames = { @@ -74,10 +96,14 @@ class DataType(enum.Enum): DataType.s32: "s32", DataType.s64: "s64", DataType.f16: "f16", + DataType.bf16: "bf16", DataType.f32: "f32", + DataType.tf32: "tf32", DataType.f64: "f64", DataType.cf16: "cf16", + DataType.cbf16: "cbf16", DataType.cf32: "cf32", + DataType.ctf32: "ctf32", DataType.cf64: "cf64", DataType.cu4: "cu4", DataType.cu8: "cu8", @@ -104,10 +130,14 @@ class DataType(enum.Enum): DataType.s32: "int32_t", DataType.s64: "int64_t", DataType.f16: "cutlass::half_t", + DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", + DataType.tf32: "cutlass::tfloat32_t", DataType.f64: "double", DataType.cf16: "cutlass::complex", + DataType.cbf16: "cutlass::complex", DataType.cf32: "cutlass::complex", + DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", @@ -134,10 +164,14 @@ class DataType(enum.Enum): DataType.s32: 32, DataType.s64: 64, DataType.f16: 16, + DataType.bf16: 16, DataType.f32: 32, + DataType.tf32: 32, DataType.f64: 64, DataType.cf16: 32, + DataType.cbf16: 32, DataType.cf32: 64, + DataType.ctf32: 32, DataType.cf64: 128, DataType.cu4: 8, DataType.cu8: 16, @@ -153,19 +187,88 @@ class DataType(enum.Enum): ################################################################################################### +# +class ComplexTransform(enum.Enum): + none = enum_auto() + conj = enum_auto() + +# +ComplexTransformTag = { + ComplexTransform.none: 'cutlass::ComplexTransform::kNone', + ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# +RealComplexBijection = [ + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), +] + +# +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +# +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + +# +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + +# +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + +################################################################################################### + +# +class MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + xor_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', + MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', + MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', +} + +################################################################################################### + # class LayoutType(enum.Enum): - ColumnMajor = enum.auto() - RowMajor = enum.auto() - ColumnMajorInterleaved32 = enum.auto() - RowMajorInterleaved32 = enum.auto() - ColumnMajorInterleaved64 = enum.auto() - RowMajorInterleaved64 = enum.auto() - TensorNHWC = enum.auto() - TensorNCHW = enum.auto() - TensorNGHWC = enum.auto() - TensorNCxHW32 = enum.auto() - TensorNCxHW64 = enum.auto() + ColumnMajor = enum_auto() + RowMajor = enum_auto() + ColumnMajorInterleaved32 = enum_auto() + RowMajorInterleaved32 = enum_auto() + ColumnMajorInterleaved64 = enum_auto() + RowMajorInterleaved64 = enum_auto() + TensorNHWC = enum_auto() + TensorNCHW = enum_auto() + TensorNGHWC = enum_auto() + TensorNCxHW32 = enum_auto() + TensorNCxHW64 = enum_auto() # LayoutTag = { @@ -182,6 +285,17 @@ class LayoutType(enum.Enum): LayoutType.TensorNCxHW64: 'cutlass::layout::TensorNCxHW64' } +# +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC +} + # ShortLayoutTypeNames = { LayoutType.ColumnMajor: 'n', @@ -197,13 +311,21 @@ class LayoutType(enum.Enum): LayoutType.TensorNCxHW64: 'ncxhw64' } +# +ShortComplexLayoutNames = { + (LayoutType.ColumnMajor, ComplexTransform.none): 'n', + (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', + (LayoutType.RowMajor, ComplexTransform.none): 't', + (LayoutType.RowMajor, ComplexTransform.conj): 'h' +} + ################################################################################################### # class OpcodeClass(enum.Enum): - Simt = enum.auto() - TensorOp = enum.auto() - WmmaTensorOp = enum.auto() + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() OpcodeClassNames = { OpcodeClass.Simt: 'simt', @@ -221,7 +343,7 @@ class OpcodeClass(enum.Enum): # class OperationKind(enum.Enum): - Gemm = enum.auto() + Gemm = enum_auto() # OperationKindNames = { OperationKind.Gemm: 'gemm' @@ -229,7 +351,7 @@ class OperationKind(enum.Enum): # class Target(enum.Enum): - library = enum.auto() + library = enum_auto() ArchitectureNames = { 50: 'maxwell', @@ -237,6 +359,7 @@ class Target(enum.Enum): 61: 'pascal', 70: 'volta', 75: 'turing', + 80: 'ampere', } ################################################################################################### @@ -244,40 +367,74 @@ class Target(enum.Enum): # def SubstituteTemplate(template, values): text = template - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - text = re.sub(regex, value, text) + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext return text ################################################################################################### # class GemmKind(enum.Enum): - Gemm = enum.auto() - Batched = enum.auto() - Array = enum.auto() - PlanarComplex = enum.auto() - PlanarComplexBatched = enum.auto() + Gemm = enum_auto() + Batched = enum_auto() + Array = enum_auto() + Universal = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Batched: "gemm_batched", GemmKind.Array: "gemm_array", + GemmKind.Universal: "gemm", GemmKind.PlanarComplex: "gemm_planar_complex", - GemmKind.PlanarComplexBatched: "gemm_planar_complex_batched", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", } +# +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationClamp = enum_auto() + +# +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', + EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', +} + +# +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + +# +SwizzlingFunctorTag = { + SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', + SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', +} ################################################################################################### # class MathInstruction: - def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class): + def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class + self.math_operation = math_operation # @@ -292,16 +449,14 @@ def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_ self.maximum_compute_capability = max_compute def procedural_name(self): - if self.stages == 2: - return "%dx%dx%d" % self.threadblock_shape - elif self.stages > 2: - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class TensorDescription: - def __init__(self, element, layout, alignment = 1): + def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.alignment = alignment + self.complex_transform = complex_transform ################################################################################################### diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 9ff69eb635..756ddc7263 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -114,10 +114,20 @@ def __init__(self, args): self.args = args self.compute_capabilities = [int(x) for x in args.architectures.split(';')] + if args.operations == 'all': + self.operations_enabled = [] + else: + + operations_list = [ + OperationKind.Gemm + ] + + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + if args.kernels == 'all': self.kernel_names = [] else: - self.kernel_names = args.kernels.split(',') + self.kernel_names = [x for x in args.kernels.split(',') if x != ''] self.operation_count = 0 self.operations_by_name = {} @@ -142,6 +152,16 @@ def __init__(self, args): } // namespace cutlass ''' + # + def _filter_string_matches(self, filter_string, haystack): + ''' Returns true if all substrings appear in the haystack in order''' + substrings = filter_string.split('*') + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub):] + return True # def filter(self, operation): @@ -159,6 +179,9 @@ def filter(self, operation): if not enabled: return False + if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: + return False + # eliminate duplicates if operation.procedural_name() in self.operations_by_name.keys(): return False @@ -168,11 +191,10 @@ def filter(self, operation): name = operation.procedural_name() enabled = False for name_substr in self.kernel_names: - if name_substr in name: + if self._filter_string_matches(name_substr, name): enabled = True break - # todo: filter based on operation kind # todo: filter based on compute data type return enabled # @@ -255,10 +277,11 @@ def emit(self, target = GeneratorTarget.Library): manifest_path = os.path.join(generated_path, "manifest.cmake") with open(manifest_path, "w") as manifest_file: - target_name = 'cutlass_lib' + target_name = 'cutlass_library_objs' target_text = SubstituteTemplate("""cutlass_target_sources( ${target_name} + BATCH_SOURCES ON PRIVATE """, { 'target_name': target_name}) diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index b00f8d2afd..23781b25ed 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,8 +29,14 @@ #pragma once #include "cutlass/cutlass.h" + #include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_complex.h" #include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" #include "cutlass/library/library.h" #include "library_internal.h" @@ -68,8 +74,10 @@ class GemmOperationBase : public Operation { GemmOperationBase(char const *name = "unknown_gemm") { description_.name = name; + description_.provider = Provider::kCUTLASS; description_.kind = OperationKind::kGemm; - + description_.gemm_kind = GemmKind::kGemm; + description_.tile_description.threadblock_shape = make_Coord( Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, @@ -93,22 +101,23 @@ class GemmOperationBase : public Operation { description_.tile_description.math_instruction.opcode_class = OpcodeClassMap::kId; + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + description_.tile_description.minimum_compute_capability = - ArchMap::kMin; + ArchMap::kMin; description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.gemm_kind = GemmKind::kGemm; + ArchMap::kMax; description_.A = make_TensorDescription(Operator::kAlignmentA); description_.B = make_TensorDescription(Operator::kAlignmentB); description_.C = make_TensorDescription(Operator::kAlignmentC); description_.element_epilogue = NumericTypeMap::kId; - description_.split_k_mode = Operator::kSplitKSerial ? SplitKMode::kSerial : SplitKMode::kNone; - description_.transform_A = ComplexTransform::kNone; - description_.transform_B = ComplexTransform::kNone; + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; } /// Returns the description of the GEMM operation @@ -294,8 +303,24 @@ class GemmOperation : public GemmOperationBase { return op->run(stream); } -}; + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,6 +385,7 @@ class GemmBatchedOperation : public GemmOperationBase { *static_cast(arguments->alpha), *static_cast(arguments->beta) ); + operator_args.epilogue = params; } else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ @@ -491,6 +517,788 @@ class GemmBatchedOperation : public GemmOperationBase { } }; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + description_.gemm_kind = GemmKind::kArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmArrayConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.batch_count = configuration->batch_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +public: + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmUniversalOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmUniversalConfiguration const *configuration) { + + operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = int(configuration->lda_real); + operator_args.lda_imag = int(configuration->lda_imag); + operator_args.ldb_real = int(configuration->ldb_real); + operator_args.ldb_imag = int(configuration->ldb_imag); + operator_args.ldc_real = int(configuration->ldc_real); + operator_args.ldc_imag = int(configuration->ldc_imag); + operator_args.ldd_real = int(configuration->ldd_real); + operator_args.ldd_imag = int(configuration->ldd_imag); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = int(configuration->lda_real); + operator_args.lda_imag = int(configuration->lda_imag); + operator_args.ldb_real = int(configuration->ldb_real); + operator_args.ldb_imag = int(configuration->ldb_imag); + operator_args.ldc_real = int(configuration->ldc_real); + operator_args.ldc_imag = int(configuration->ldc_imag); + operator_args.ldd_real = int(configuration->ldd_real); + operator_args.ldd_imag = int(configuration->ldd_imag); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu new file mode 100644 index 0000000000..bdddf2d7ca --- /dev/null +++ b/tools/library/src/handle.cu @@ -0,0 +1,1044 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Library handle. +*/ +#include +#include +#include + +#include "cutlass/library/handle.h" +#include "cutlass/library/singleton.h" +#include "cutlass/library/util.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructor +Handle::Handle( + cudaStream_t stream, + size_t workspace_size +): + provider_(Provider::kCUTLASS), + stream_(stream), + workspace_(nullptr), + workspace_size_(0), + scalar_pointer_mode_(ScalarPointerMode::kHost), + last_operation_(nullptr) { + + int device_idx = -1; + + cudaError_t error = cudaGetDevice(&device_idx); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + error = cudaGetDeviceProperties(&device_, device_idx); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + set_workspace_size(workspace_size); + + Singleton::get(); +} + +/// Destructor +Handle::~Handle() { + if (workspace_) { + + if (workspace_) { + cudaFree(workspace_); + } + + workspace_ = nullptr; + workspace_size_ = 0; + } +} + +/// Move constructor +Handle::Handle(Handle && handle) { + device_ = handle.device_; + workspace_size_ = handle.workspace_size_; + workspace_ = handle.workspace_; + stream_ = handle.stream_; + scalar_pointer_mode_ = handle.scalar_pointer_mode_; + + handle.workspace_ = nullptr; + handle.workspace_size_ = 0; +} + +/// Move assignment operator +Handle & Handle::operator=(Handle && handle) { + + provider_ = handle.provider_; + device_ = handle.device_; + workspace_size_ = handle.workspace_size_; + workspace_ = handle.workspace_; + stream_ = handle.stream_; + scalar_pointer_mode_ = handle.scalar_pointer_mode_; + + handle.workspace_ = nullptr; + handle.workspace_size_ = 0; + + return *this; +} + +int Handle::compute_capability() const { + return device_.major * 10 + device_.minor; +} + +/// Sets the current CUDA stream +void Handle::set_stream(cudaStream_t stream) { + stream_ = stream; +} + +/// Gets the current CUDA stream +cudaStream_t Handle::get_stream() const { + return stream_; +} + +/// Gets the current provider +Provider Handle::get_provider() const { + return provider_; +} + +/// Sets the provider of operations +void Handle::set_provider(Provider provider) { + provider_ = provider; +} + +/// Gets the device workspace size +size_t Handle::get_workspace_size() const { + return workspace_size_; +} + +/// Gets a pointer to the device workspace allocation in Global Memory +void *Handle::get_workspace() const { + return workspace_; +} + +/// Sets the size of device workspace, invalidating previous calls to get_device_workspace() +void Handle::set_workspace_size(size_t bytes) { + if (bytes != workspace_size_) { + + if (workspace_) { + cudaFree(workspace_); + } + + workspace_ = nullptr; + workspace_size_ = bytes; + + if (workspace_size_) { + + cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to allocate workspace"); + } + } + } + + if (workspace_) { + cudaError_t error = cudaMemset(workspace_, 0, workspace_size_); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to clear workspace"); + } + } +} + +/// Gets the scalar pointer mode +ScalarPointerMode Handle::get_scalar_pointer_mode() const { + return scalar_pointer_mode_; +} + +/// Sets the scalar pointer mode +void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) { + scalar_pointer_mode_ = mode; +} + +/// Gets the last operation +Operation const *Handle::get_last_operation() const { + return last_operation_; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the maximum required alignment for each operator +static int maximum_alignment_requirement(GemmDescription const &desc) { + return std::max( + std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment); +} + +/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a +/// given upper limit. +static int gemm_problem_alignment( + int M, + int N, + int K, + NumericTypeID element_A, + void const *ptr_A, + int lda, + int64_t batch_stride_A, + NumericTypeID element_B, + void const *ptr_B, + int ldb, + int64_t batch_stride_B, + NumericTypeID element_C, + void const * ptr_C, + int ldc, + int64_t batch_stride_C, + void const * ptr_D, + int ldd, + int64_t batch_stride_D, + int max_alignment_in_bytes = 16 +) { + + void const *pointers[] = { + ptr_A, ptr_B, ptr_C, ptr_D + }; + + int64_t extents[] = { + M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D + }; + + NumericTypeID elements[] = { + element_A, element_B, element_C + }; + + for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { + + bool satisfied = true; + + // Can pointers satisfy this? + for (void const *ptr : pointers) { + std::uintptr_t int_ptr = reinterpret_cast(ptr); + + if (int_ptr % max_alignment_in_bytes) { + satisfied = false; + break; + } + } + + if (!satisfied) { + continue; + } + + // Compute the maximum alignment based on element data types + int max_element_alignment = 0; + + for (NumericTypeID type_id : elements) { + int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); + max_element_alignment = std::max(max_element_alignment, element_alignment); + } + + // Can the problem size and leading dimensions satisfy this? + for (int64_t extent : extents) { + if (extent % max_element_alignment) { + satisfied = false; + break; + } + } + + if (!satisfied) { + continue; + } + + // Yes + return max_element_alignment; + } + + // No alignment satisfies this problem + return 0; +} + +/// Find the best kernel in descending order of preference. +static Operation const * find_gemm_operation( + GemmOperationFunctionalMap::const_iterator operators_it, + GemmPreferenceKey const preference_key) { + + auto cc_it = operators_it->second.upper_bound(preference_key); + + if (cc_it == operators_it->second.begin()) { + return nullptr; + } + + Operation const *operation = nullptr; + + // Search in descending order of compute capability + do { + --cc_it; + + // Search tile sizes in order, for now. + for (auto const * op : cc_it->second) { + + GemmDescription const &desc = static_cast(op->description()); + + int min_cc = desc.tile_description.minimum_compute_capability; + int max_cc = desc.tile_description.maximum_compute_capability; + + int op_alignment = maximum_alignment_requirement(desc); + + if ((min_cc <= preference_key.compute_capability) && + (preference_key.compute_capability <= max_cc) && + (op_alignment <= preference_key.alignment)) { + + operation = op; + break; + } + } + } while (!operation && cc_it != operators_it->second.begin()); + + return operation; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Executes a GEMM computation: D <= alpha * A*B + beta * C +Status Handle::gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int ldd /// Leading dimension of D matrix +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + provider_, + GemmKind::kGemm, + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); + + if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + int alignment = gemm_problem_alignment( + M, N, K, + element_A, ptr_A, lda, 0, + element_B, ptr_B, ldb, 0, + element_C, ptr_C, ldc, 0, + ptr_D, ldd, 0, kMaximumAlignmentSize + ); + + // + // Find the best kernel in descending order of preference. + // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmConfiguration configuration{ + {M, N, K}, + lda, + ldb, + ldc, + ldd, + 1 + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmArguments arguments{ + ptr_A, + ptr_B, + ptr_C, + ptr_D, + alpha, + beta, + scalar_pointer_mode_ + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Executes a GEMM computation: D <= alpha * A*B + beta * C. +// +// Supports batched-strided, batched array or split-K serial or split-K parallel. +// +Status Handle::gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int ldd, /// Leading dimension of D matrix + + int batch_count, /// Batch count or number of split-K slices + + int64_t batch_stride_A, /// Batch stride of A operand + int64_t batch_stride_B, /// Batch stride of B operand + int64_t batch_stride_C, /// Batch stride of C operand + int64_t batch_stride_D /// Batch stride of D operand +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + provider_, + GemmKind::kUniversal, + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); + + if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + void const *ptr_A_check = ptr_A; + void const *ptr_B_check = ptr_B; + void const *ptr_C_check = ptr_C; + void * ptr_D_check = ptr_D; + + // Ignore alignment of pointers to pointers. We can't check this from the host, + // as each batch index has its own pointer in device memory. + if (mode == GemmUniversalMode::kArray) { + ptr_A_check = nullptr; + ptr_B_check = nullptr; + ptr_C_check = nullptr; + ptr_D_check = nullptr; + } + + int alignment = gemm_problem_alignment( + M, N, K, + element_A, ptr_A_check, lda, 0, + element_B, ptr_B_check, ldb, 0, + element_C, ptr_C_check, ldc, 0, + ptr_D_check, ldd, 0, kMaximumAlignmentSize + ); + + // + // Find the best kernel in descending order of preference. + // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmUniversalConfiguration configuration{ + mode, + {M, N, K}, + batch_count, + lda, + ldb, + ldc, + ldd + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmUniversalArguments arguments{ + ptr_A, + ptr_B, + ptr_C, + ptr_D, + alpha, + beta, + scalar_pointer_mode_, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Planar complex GEMM +Status Handle::gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real, + int64_t batch_stride_A_imag, + + int64_t batch_stride_B_real, + int64_t batch_stride_B_imag, + + int64_t batch_stride_C_real, + int64_t batch_stride_C_imag, + + int64_t batch_stride_D_real, + int64_t batch_stride_D_imag +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + provider_, + GemmKind::kPlanarComplex, + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); + + if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + int alignment = std::max( + gemm_problem_alignment( + M, N, K, + element_A, ptr_A_real, lda_real, batch_stride_A_real, + element_B, ptr_B_real, ldb_real, batch_stride_B_real, + element_C, ptr_C_real, ldc_real, batch_stride_C_real, + ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize + ), + gemm_problem_alignment( + M, N, K, + element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, + element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, + element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, + ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize + ) + ); + + // + // Find the best kernel in descending order of preference. + // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmPlanarComplexConfiguration configuration{ + GemmUniversalMode::kBatched, + {M, N, K}, + batch_count, + lda_real, + lda_imag, + ldb_real, + ldb_imag, + ldc_real, + ldc_imag, + ldd_real, + ldd_imag + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmPlanarComplexArguments arguments{ + ptr_A_real, + ptr_A_imag, + ptr_B_real, + ptr_B_imag, + ptr_C_real, + ptr_C_imag, + ptr_D_real, + ptr_D_imag, + alpha, + beta, + scalar_pointer_mode_, + batch_stride_A_real, + batch_stride_A_imag, + batch_stride_B_real, + batch_stride_B_imag, + batch_stride_C_real, + batch_stride_C_imag, + batch_stride_D_real, + batch_stride_D_imag + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Planar complex batched GEMM loading pointers from arrays in global memory +Status Handle::gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices + + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices + + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag /// Leading dimension of imaginary part of D matrix +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + provider_, + GemmKind::kPlanarComplexArray, + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); + + if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + int alignment = std::max( + gemm_problem_alignment( + expected_M, expected_N, expected_K, + element_A, nullptr, lda_real, 0, + element_B, nullptr, ldb_real, 0, + element_C, nullptr, ldc_real, 0, + nullptr, ldd_real, 0, kMaximumAlignmentSize + ), + gemm_problem_alignment( + expected_M, expected_N, expected_K, + element_A, nullptr, lda_imag, 0, + element_B, nullptr, ldb_imag, 0, + element_C, nullptr, ldc_imag, 0, + nullptr, ldd_imag, 0, kMaximumAlignmentSize + ) + ); + + // + // Find the best kernel in descending order of preference. + // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmPlanarComplexArrayConfiguration configuration{ + {expected_M, expected_N, expected_K}, + batch_count, + lda_real, + lda_imag, + ldb_real, + ldb_imag, + ldc_real, + ldc_imag, + ldd_real, + ldd_imag + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmPlanarComplexArrayArguments arguments{ + M, N, K, + ptr_A_real, + ptr_A_imag, + ptr_B_real, + ptr_B_imag, + ptr_C_real, + ptr_C_imag, + ptr_D_real, + ptr_D_imag, + alpha, + beta, + scalar_pointer_mode_ + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 5feff5fb93..73847b117f 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -57,6 +57,10 @@ namespace library { template struct NumericTypeMap; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS4; }; @@ -121,6 +125,40 @@ template <> struct NumericTypeMap > { static NumericTypeID const kId = NumericTypeID::kCF64; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template struct LayoutMap; @@ -133,6 +171,34 @@ template <> struct LayoutMap { static LayoutTypeID const kId = LayoutTypeID::kRowMajor; }; +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template struct OpcodeClassMap; @@ -148,35 +214,58 @@ template <> struct OpcodeClassMap { template <> struct OpcodeClassMap { static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; }; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// -template struct ArchMap; +template struct ArchMap; -template <> struct ArchMap { +template <> struct ArchMap { static int const kMin = 50; static int const kMax = 1024; }; -template <> struct ArchMap { +template <> struct ArchMap { static int const kMin = 60; static int const kMax = 1024; }; -template <> struct ArchMap { +template <> struct ArchMap { static int const kMin = 61; static int const kMax = 1024; }; -template <> struct ArchMap { +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { static int const kMin = 70; static int const kMax = 75; }; -template <> struct ArchMap { +template struct ArchMap { static int const kMin = 75; static int const kMax = 1024; }; +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index 159bf3f09e..d4e8a884be 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -1,7 +1,5 @@ -/*! - -*//*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -37,11 +35,7 @@ namespace cutlass { namespace library { -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_all(Manifest &manifest); - -/////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////////////// /// Top-level initialization Status Manifest::initialize() { @@ -50,6 +44,7 @@ Status Manifest::initialize() { operations_.clear(); } + // initialize procedurally generated cutlass op in manifest object initialize_all(*this); return Status::kSuccess; diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu new file mode 100644 index 0000000000..64e4f264cf --- /dev/null +++ b/tools/library/src/operation_table.cu @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. +*/ + +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +void OperationTable::append(Manifest const &manifest) { + + // Insert operations into appropriate data structure + for (auto const & operation : manifest) { + + OperationDescription const &desc = operation->description(); + + // insert all gemm operation into operation table + if (desc.kind == OperationKind::kGemm) { + GemmDescription const &gemm_desc = static_cast(desc); + + + GemmFunctionalKey functional_key( + gemm_desc.provider, + gemm_desc.gemm_kind, + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + gemm_desc.C.element + ); + + Operation const *op = operation.get(); + + int cc = gemm_desc.tile_description.minimum_compute_capability; + + int alignment = std::max(std::max( + gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); + + GemmPreferenceKey preference_key(cc, alignment); + + gemm_operations[functional_key][preference_key].push_back(op); + } + + + } + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/singleton.cu b/tools/library/src/singleton.cu new file mode 100644 index 0000000000..642ac61a15 --- /dev/null +++ b/tools/library/src/singleton.cu @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" +#include "cutlass/library/singleton.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static std::unique_ptr instance; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +Singleton::Singleton() { + + manifest.initialize(); + + operation_table.append(manifest); +} + +Singleton const & Singleton::get() { + if (!instance.get()) { + instance.reset(new Singleton); + } + return *instance.get(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/library.cu b/tools/library/src/util.cu similarity index 71% rename from tools/library/src/library.cu rename to tools/library/src/util.cu index 92f87c6153..427f0a2c52 100644 --- a/tools/library/src/library.cu +++ b/tools/library/src/util.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -25,26 +25,108 @@ #include #include - #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/complex.h" -#include "cutlass/library/library.h" #include "cutlass/layout/matrix.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + namespace cutlass { namespace library { ///////////////////////////////////////////////////////////////////////////////////////////////// +static struct { + char const *text; + char const *pretty; + Provider enumerant; +} +Provider_enumerants[] = { + {"none", "None", Provider::kNone}, + {"cutlass", "CUTLASS", Provider::kCUTLASS}, + {"host", "reference_host", Provider::kReferenceHost}, + {"device", "reference_device", Provider::kReferenceDevice}, + {"cublas", "cuBLAS", Provider::kCUBLAS}, +}; + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty) { + + for (auto const & possible : Provider_enumerants) { + if (provider == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Parses a Provider enumerant from a string +template <> +Provider from_string(std::string const &str) { + + for (auto const & possible : Provider_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return Provider::kInvalid; +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const *text; + char const *pretty; + GemmKind enumerant; +} +GemmKind_enumerants[] = { + {"gemm", "", GemmKind::kGemm}, + {"batched", "", GemmKind::kBatched}, + {"array", "", GemmKind::kArray}, + {"universal", "", GemmKind::kUniversal}, + {"planar_complex", "", GemmKind::kPlanarComplex}, + {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, +}; + +/// Converts a ConvKind enumerant to a string +char const *to_string(GemmKind type, bool pretty) { + + for (auto const & possible : GemmKind_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + static struct { char const *text; char const *pretty; OperationKind enumerant; } OperationKind_enumerants[] = { - {"gemm", "Gemm", OperationKind::kGemm}, + {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, + {"gemm", "Gemm", OperationKind::kGemm}, }; /// Converts a Status enumerant to a string @@ -146,10 +228,14 @@ NumericTypeID_enumerants[] = { {"s32", "S32", NumericTypeID::kS32}, {"s64", "S64", NumericTypeID::kS64}, {"f16", "F16", NumericTypeID::kF16}, + {"bf16", "BF16", NumericTypeID::kBF16}, {"f32", "F32", NumericTypeID::kF32}, + {"tf32", "TF32", NumericTypeID::kTF32}, {"f64", "F64", NumericTypeID::kF64}, {"cf16", "CF16", NumericTypeID::kCF16}, + {"cbf16", "CBF16", NumericTypeID::kCBF16}, {"cf32", "CF32", NumericTypeID::kCF32}, + {"ctf32", "CTF32", NumericTypeID::kCTF32}, {"cf64", "CF64", NumericTypeID::kCF64}, {"cu4", "CU4", NumericTypeID::kCU4}, {"cu8", "CU8", NumericTypeID::kCU8}, @@ -201,8 +287,15 @@ NumericTypeID from_string(std::string const &str) { int sizeof_bits(NumericTypeID type) { switch (type) { case NumericTypeID::kF16: return 16; + case NumericTypeID::kBF16: return 16; + case NumericTypeID::kTF32: return 32; case NumericTypeID::kF32: return 32; case NumericTypeID::kF64: return 64; + case NumericTypeID::kCF16: return 32; + case NumericTypeID::kCBF16: return 32; + case NumericTypeID::kCF32: return 64; + case NumericTypeID::kCTF32: return 64; + case NumericTypeID::kCF64: return 128; case NumericTypeID::kS4: return 4; case NumericTypeID::kS8: return 8; case NumericTypeID::kS16: return 16; @@ -225,6 +318,8 @@ bool is_complex_type(NumericTypeID type) { case NumericTypeID::kCF16: return true; case NumericTypeID::kCF32: return true; case NumericTypeID::kCF64: return true; + case NumericTypeID::kCBF16: return true; + case NumericTypeID::kCTF32: return true; default: break; } return false; @@ -236,6 +331,8 @@ NumericTypeID get_real_type(NumericTypeID type) { case NumericTypeID::kCF16: return NumericTypeID::kF16; case NumericTypeID::kCF32: return NumericTypeID::kF32; case NumericTypeID::kCF64: return NumericTypeID::kF64; + case NumericTypeID::kCBF16: return NumericTypeID::kBF16; + case NumericTypeID::kCTF32: return NumericTypeID::kTF32; default: break; } return type; @@ -263,6 +360,8 @@ bool is_integer_type(NumericTypeID type) { bool is_signed_type(NumericTypeID type) { switch (type) { case NumericTypeID::kF16: return true; + case NumericTypeID::kBF16: return true; + case NumericTypeID::kTF32: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; case NumericTypeID::kS4: return true; @@ -289,8 +388,15 @@ bool is_unsigned_integer(NumericTypeID type) { bool is_float_type(NumericTypeID type) { switch (type) { case NumericTypeID::kF16: return true; + case NumericTypeID::kBF16: return true; + case NumericTypeID::kTF32: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; + case NumericTypeID::kCF16: return true; + case NumericTypeID::kCBF16: return true; + case NumericTypeID::kCTF32: return true; + case NumericTypeID::kCF32: return true; + case NumericTypeID::kCF64: return true; default: break; } return false; @@ -309,8 +415,18 @@ layout_aliases[] = { {LayoutTypeID::kColumnMajor, "column"}, {LayoutTypeID::kColumnMajor, "col"}, {LayoutTypeID::kColumnMajor, "n"}, + + {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, + {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, + + {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, + {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, + + {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, + {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, + {LayoutTypeID::kTensorNCHW, "nchw"}, - {LayoutTypeID::kTensorNHWC, "packed_nhwc"}, + {LayoutTypeID::kTensorNHWC, "nhwc"}, {LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kInvalid, nullptr} }; @@ -344,7 +460,12 @@ int get_layout_stride_rank(LayoutTypeID layout_id) { case LayoutTypeID::kColumnMajorInterleavedK4: case LayoutTypeID::kRowMajorInterleavedK4: case LayoutTypeID::kColumnMajorInterleavedK16: - case LayoutTypeID::kRowMajorInterleavedK16: return 1; + case LayoutTypeID::kRowMajorInterleavedK16: + case LayoutTypeID::kColumnMajorInterleavedK32: + case LayoutTypeID::kRowMajorInterleavedK32: + case LayoutTypeID::kColumnMajorInterleavedK64: + case LayoutTypeID::kRowMajorInterleavedK64: + return 1; case LayoutTypeID::kTensorNCHW: case LayoutTypeID::kTensorNHWC: return 3; default : throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); @@ -362,7 +483,7 @@ OpcodeClassID_enumerants[] = { {"simt", "", OpcodeClassID::kSimt}, {"tensorop", "", OpcodeClassID::kTensorOp}, {"wmmatensorop", "", OpcodeClassID::kWmmaTensorOp}, - {"wmma", "", OpcodeClassID::kWmmaTensorOp} + {"wmma", "", OpcodeClassID::kWmmaTensorOp}, }; /// Converts a OpcodeClassID enumerant to a string @@ -396,8 +517,92 @@ OpcodeClassID from_string(std::string const &str) { return OpcodeClassID::kInvalid; } -/////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const *text; + char const *pretty; + ComplexTransform enumerant; +} +ComplexTransform_enumerants[] = { + {"n", "none", ComplexTransform::kNone}, + {"c", "conj", ComplexTransform::kConjugate} +}; + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty) { + + for (auto const & possible : ComplexTransform_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str) { + + for (auto const & possible : ComplexTransform_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return ComplexTransform::kInvalid; +} + + +static struct { + char const *text; + char const *pretty; + SplitKMode enumerant; +} +SplitKMode_enumerants[] = { + {"serial", "", SplitKMode::kSerial}, + {"parallel", "", SplitKMode::kParallel}, +}; + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode type, bool pretty) { + for (auto const & possible : SplitKMode_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string(std::string const &str) { + + for (auto const & possible : SplitKMode_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return SplitKMode::kInvalid; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str) { int size_bytes = sizeof_bits(type) / 8; @@ -458,6 +663,20 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; + case NumericTypeID::kBF16: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kTF32: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; case NumericTypeID::kF32: { ss >> *reinterpret_cast(bytes.data()); @@ -477,11 +696,29 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c x->imag() = static_cast(std::imag(tmp)); } break; + case NumericTypeID::kCBF16: + { + std::complex tmp; + ss >> tmp; + cutlass::complex *x = reinterpret_cast *>(bytes.data()); + x->real() = static_cast(std::real(tmp)); + x->imag() = static_cast(std::imag(tmp)); + } + break; case NumericTypeID::kCF32: { ss >> *reinterpret_cast*>(bytes.data()); } break; + case NumericTypeID::kCTF32: + { + std::complex tmp; + ss >> tmp; + cutlass::complex *x = reinterpret_cast *>(bytes.data()); + x->real() = static_cast(std::real(tmp)); + x->imag() = static_cast(std::imag(tmp)); + } + break; case NumericTypeID::kCF64: { ss >> *reinterpret_cast*>(bytes.data()); @@ -562,6 +799,18 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { ss << tmp; } break; + case NumericTypeID::kBF16: + { + float tmp = *reinterpret_cast(bytes.data());; + ss << tmp; + } + break; + case NumericTypeID::kTF32: + { + float tmp = *reinterpret_cast(bytes.data());; + ss << tmp; + } + break; case NumericTypeID::kF32: { ss << *reinterpret_cast(bytes.data()); @@ -574,25 +823,59 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { break; case NumericTypeID::kCF16: { - std::complex tmp; - cutlass::complex const *x = reinterpret_cast const *>(bytes.data()); - tmp.real(x->real()); - tmp.imag(x->imag()); + ss << float(x->real()); - ss << tmp; + if (x->imag() != cutlass::half_t()) { + ss << "+i" << float(x->imag()); + } + } + break; + case NumericTypeID::kCBF16: + { + cutlass::complex const *x = + reinterpret_cast const *>(bytes.data()); + + ss << float(x->real()); + + if (x->imag() != cutlass::bfloat16_t()) { + ss << "+i" << float(x->imag()); + } } break; case NumericTypeID::kCF32: { - ss << *reinterpret_cast*>(bytes.data()); + cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); + + ss << x->real(); + + if (x->imag() != float()) { + ss << "+i" << x->imag(); + } + } + break; + case NumericTypeID::kCTF32: + { + cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); + + ss << float(x->real()); + + if (x->imag() != tfloat32_t()) { + ss << "+i" << float(x->imag()); + } } break; case NumericTypeID::kCF64: { - ss << *reinterpret_cast*>(bytes.data()); + cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); + + ss << x->real(); + + if (x->imag() != double()) { + ss << "+i" << x->imag(); + } } break; default: @@ -657,6 +940,16 @@ bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t sr *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; + case NumericTypeID::kBF16: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kTF32: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); @@ -747,6 +1040,16 @@ bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; + case NumericTypeID::kBF16: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kTF32: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); @@ -838,6 +1141,16 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; + case NumericTypeID::kBF16: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kTF32: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); @@ -855,11 +1168,23 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr x->imag() = static_cast(float(0)); } break; + case NumericTypeID::kCBF16: + { + cutlass::complex *x = reinterpret_cast *>(bytes.data()); + x->real() = static_cast(bfloat16_t(src)); + x->imag() = static_cast(bfloat16_t(0)); + } + break; case NumericTypeID::kCF32: { *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); } break; + case NumericTypeID::kCTF32: + { + *reinterpret_cast*>(bytes.data()) = std::complex(tfloat32_t(src), tfloat32_t(0)); + } + break; case NumericTypeID::kCF64: { *reinterpret_cast*>(bytes.data()) = std::complex(src, double(0)); diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index b3a06900ad..a47c831415 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -33,7 +33,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES src/gpu_timer.cpp src/device_allocation.cu src/device_context.cu - src/cublas_helpers.cpp + src/cublas_helpers.cpp src/problem_space.cpp src/operation_profiler.cu src/gemm_operation_profiler.cu @@ -54,11 +54,11 @@ set_target_properties(cutlass_profiler PROPERTIES EXPORT_NAME profiler) # Include paths # -target_include_directories(cutlass_profiler +target_include_directories( + cutlass_profiler PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src # Source directory - ../../tools/util/include -) + ) # # Library dependencies @@ -68,8 +68,8 @@ target_link_libraries( cutlass_profiler PRIVATE cutlass_lib - $<$:cublas> - gtest + cutlass_tools_util_includes + $<$:nvidia::cublas> cudart ) diff --git a/tools/profiler/src/cublas_helpers.cpp b/tools/profiler/src/cublas_helpers.cpp index 973dc44cd5..05262a22de 100644 --- a/tools/profiler/src/cublas_helpers.cpp +++ b/tools/profiler/src/cublas_helpers.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,28 +39,48 @@ namespace profiler { /// Converts a cuBLAS status to cutlass::Status Status get_cutlass_status(cublasStatus_t cublas) { - if (cublas == CUBLAS_STATUS_SUCCESS) { - return Status::kSuccess; - } - else if (cublas == CUBLAS_STATUS_INVALID_VALUE) { - return Status::kErrorInvalidProblem; - } - if (cublas == CUBLAS_STATUS_NOT_SUPPORTED) { - return Status::kErrorNotSupported; + switch (cublas) { + case CUBLAS_STATUS_SUCCESS: + return Status::kSuccess; + case CUBLAS_STATUS_INVALID_VALUE: + return Status::kErrorInvalidProblem; + case CUBLAS_STATUS_NOT_SUPPORTED: + return Status::kErrorNotSupported; + default: break; } return Status::kErrorInternal; } /// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -cublasOperation_t get_cublas_transpose_operation(library::LayoutTypeID layout) { +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform) { + switch (layout) { case library::LayoutTypeID::kColumnMajor: - return CUBLAS_OP_N; + if (transform == library::ComplexTransform::kNone) { + operation = CUBLAS_OP_N; + return true; + } + else { + return false; + } + break; case library::LayoutTypeID::kRowMajor: - return CUBLAS_OP_T; + if (transform == library::ComplexTransform::kNone) { + operation = CUBLAS_OP_T; + return true; + } + else if (transform == library::ComplexTransform::kConjugate) { + operation = CUBLAS_OP_C; + return true; + } + break; default: break; } - throw std::runtime_error("CUTLASS layout type does not correspond to cublas type"); + + return false; } /// Maps a CUTLASS numeric type to a cuBLAS data type enumeration @@ -114,6 +134,14 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele case library::NumericTypeID::kB1: break; + + case library::NumericTypeID::kCF32: + data_type = CUDA_C_32F; + return true; + + case library::NumericTypeID::kCF64: + data_type = CUDA_C_64F; + return true; case library::NumericTypeID::kInvalid: @@ -145,11 +173,116 @@ Status cublas_satisfies(library::GemmDescription const &desc) { return Status::kErrorNotSupported; } + // output type S4 and S8 not supported in cuBLAS + if (desc.C.element == library::NumericTypeID::kS4 || + desc.C.element == library::NumericTypeID::kS8) { + + return Status::kErrorNotSupported; + } + return Status::kSuccess; } ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +cublasGemmExDispatcher::cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm +): + configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { + + bool good = true; + + good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); + good = (good && get_cublas_transpose_operation(trans_B, op_desc.B.layout, op_desc.transform_B)); + good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); + good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); + good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); + + good = (good && get_cublas_datatype( + compute_data_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe + // internal numerical data types used in the computation. +#if (__CUDA_VER_MAJOR__ >= 11) + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + if (good && + op_desc.A.element == library::NumericTypeID::kF32 && + op_desc.B.element == library::NumericTypeID::kF32 && + opcode_class == library::OpcodeClassID::kTensorOp) { + + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + } + else if (good) { + bool const isPedantic = false; + switch (compute_data_type) { + case CUDA_R_32F: + case CUDA_C_32F: + compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; + break; + case CUDA_R_64F: + case CUDA_C_64F: + compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; + break; + case CUDA_R_16F: + compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; + break; + case CUDA_R_32I: + compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; + break; + default: + good = false; + break; + } + } +#endif // __CUDA_VER_MAJOR__ >= 11 + + if (!good) { + status = Status::kErrorNotSupported; + } +} + +/// Executes GEMM using these arguments +cublasStatus_t cublasGemmExDispatcher::operator()(cublasHandle_t handle) { + + return cublasGemmEx( + handle, + trans_A, + trans_B, + configuration.problem_size.m(), + configuration.problem_size.n(), + configuration.problem_size.k(), + arguments.alpha, + arguments.A, + data_type_A, + int(configuration.lda), + arguments.B, + data_type_B, + int(configuration.ldb), + arguments.beta, + arguments.D, + data_type_C, + int(configuration.ldc), +#if (__CUDA_VER_MAJOR__ >= 11) + compute_type, +#else + compute_data_type, +#endif + algo + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace detail + } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/cublas_helpers.h b/tools/profiler/src/cublas_helpers.h index 6bb2f4e94b..9c8078466a 100644 --- a/tools/profiler/src/cublas_helpers.h +++ b/tools/profiler/src/cublas_helpers.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,6 +33,9 @@ #include "cutlass/cutlass.h" #include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +#include "options.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -45,7 +48,10 @@ namespace profiler { Status get_cutlass_status(cublasStatus_t cublas); /// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -cublasOperation_t get_cublas_transpose_operation(library::LayoutTypeID layout); +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform = library::ComplexTransform::kNone); /// Maps a CUTLASS numeric type to a cuBLAS data type enumeration bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); @@ -86,6 +92,125 @@ class CublasCreate { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. + + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? + CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublass-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDA_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index e1664de7cc..90f4a95970 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,14 +29,9 @@ #include #include -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - // Profiler includes #include "cutlass_profiler.h" #include "gemm_operation_profiler.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -49,7 +44,8 @@ CutlassProfiler::CutlassProfiler( ): options_(options) { - operation_profilers_.emplace_back(new GemmOperationProfiler); + operation_profilers_.emplace_back(new GemmOperationProfiler(options)); + } CutlassProfiler::~CutlassProfiler() { @@ -112,13 +108,6 @@ void CutlassProfiler::enumerate_() { /// Profiles all operations int CutlassProfiler::profile_() { - library::Manifest manifest; - Status status = manifest.initialize(); - - if (status != Status::kSuccess) { - return -1; - } - int result = 0; DeviceContext device_context; @@ -128,7 +117,7 @@ int CutlassProfiler::profile_() { if (options_.operation_kind == library::OperationKind::kInvalid || options_.operation_kind == profiler->kind()) { - result = profiler->profile_all(options_, manifest, device_context); + result = profiler->profile_all(options_, library::Singleton::get().manifest, device_context); if (result) { return result; @@ -165,7 +154,8 @@ void CutlassProfiler::print_usage_(std::ostream &out) { } out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n" - << " $ cutlass_profiler --operation=Gemm --help\n\n"; + << " $ cutlass_profiler --operation=Gemm --help\n\n" + ; } /// Prints usage diff --git a/tools/profiler/src/cutlass_profiler.h b/tools/profiler/src/cutlass_profiler.h index 09d2401120..d3b592a4ea 100644 --- a/tools/profiler/src/cutlass_profiler.h +++ b/tools/profiler/src/cutlass_profiler.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -27,6 +27,10 @@ */ #pragma once +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" #include "options.h" #include "operation_profiler.h" diff --git a/tools/profiler/src/debug.h b/tools/profiler/src/debug.h index 3aaf3bb40c..aed11ca188 100644 --- a/tools/profiler/src/debug.h +++ b/tools/profiler/src/debug.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,11 +30,11 @@ #include -#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } +//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } //#define report(x) {} // Enable/Disble Profiler debug prints -#define DEBUG_PROFILER +//#define DEBUG_PROFILER //RED 31m // profiler prints debug messages in red //YELLOW 33m // ir prints debug messages in yellow @@ -43,7 +43,7 @@ #define debugprof(...) #else #define debugprof(...) do { \ - printf("\033[31m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ + printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ printf(__VA_ARGS__); \ printf("\033[0m\n"); \ } while (0) diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 2e04a1e8e6..4045abfeec 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -34,12 +34,12 @@ #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/device/tensor_fill.h" - #include "cutlass/util/reference/host/tensor_fill.h" - #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/library/util.h" + #include "device_allocation.h" namespace cutlass { @@ -106,6 +106,18 @@ std::vector DeviceAllocation::get_packed_layout( case library::LayoutTypeID::kRowMajorInterleavedK16: stride = get_packed_layout_stride>(extent); break; + case library::LayoutTypeID::kColumnMajorInterleavedK32: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kRowMajorInterleavedK32: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK64: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kRowMajorInterleavedK64: + stride = get_packed_layout_stride>(extent); + break; case library::LayoutTypeID::kTensorNCHW: stride = get_packed_layout_stride(extent); break; @@ -200,6 +212,18 @@ size_t DeviceAllocation::construct_layout( case library::LayoutTypeID::kRowMajorInterleavedK16: return construct_layout_>(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kColumnMajorInterleavedK32: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kRowMajorInterleavedK32: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kColumnMajorInterleavedK64: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kRowMajorInterleavedK64: + return construct_layout_>(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kTensorNCHW: return construct_layout_(bytes, layout_id, extent, stride); @@ -407,6 +431,14 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF32: + cutlass::reference::device::BlockFillRandom>( + reinterpret_cast *>(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF64: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -415,6 +447,14 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF64: + cutlass::reference::device::BlockFillRandom>( + reinterpret_cast *>(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -508,6 +548,22 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF16: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kCF32: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF64: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -516,6 +572,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF64: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -608,12 +672,30 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_B), capacity); + case library::NumericTypeID::kCF32: + return reference::device::BlockCompareEqual >( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + + case library::NumericTypeID::kCF16: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kF64: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); + case library::NumericTypeID::kCF64: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kS8: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -765,6 +847,23 @@ bool DeviceAllocation::block_compare_relatively_equal( static_cast(epsilon), static_cast(nonzero_floor)); + // No relatively equal comparison for complex numbers. + // + // As a simplification, we can require bitwise equality. This avoids false positives. + // (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.) + // + case library::NumericTypeID::kCF32: + return reference::device::BlockCompareEqual >( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + + case library::NumericTypeID::kCF64: + return reference::device::BlockCompareEqual >( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + default: throw std::runtime_error("Unsupported numeric type"); } @@ -910,6 +1009,14 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kU64: write_tensor_csv_static_type(out, *this); break; + + case library::NumericTypeID::kCF32: + write_tensor_csv_static_type >(out, *this); + break; + + case library::NumericTypeID::kCF64: + write_tensor_csv_static_type >(out, *this); + break; default: throw std::runtime_error("Unsupported numeric type"); diff --git a/tools/profiler/src/device_allocation.h b/tools/profiler/src/device_allocation.h index be69f03733..f57cda1431 100644 --- a/tools/profiler/src/device_allocation.h +++ b/tools/profiler/src/device_allocation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index f695a9ed4a..f9cfe9ab58 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -74,16 +74,34 @@ DeviceAllocation *DeviceContext::allocate_tensor( allocate_tensor(name, type, layout_id, extent, stride); if (options.initialization.enabled) { + Distribution data_distribution = options.initialization.data_distribution; + + // check if data distribution is allowed to change + if(!options.initialization.fix_data_distribution) { + // change data distribution based on bit width + switch(type) { + case library::NumericTypeID::kB1: + data_distribution.set_uniform(0, 2, 0); + break; + case library::NumericTypeID::kS8: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kU8: + data_distribution.set_uniform(0, 4, 0); + break; + default: break; + } + } - if (options.initialization.provider == Provider::kReferenceDevice) { + if (options.initialization.provider == library::Provider::kReferenceDevice) { allocation->initialize_random_device( options.initialization.seed, - options.initialization.data_distribution); + data_distribution); } - else if (options.initialization.provider == Provider::kReferenceHost) { + else if (options.initialization.provider == library::Provider::kReferenceHost) { allocation->initialize_random_host( options.initialization.seed, - options.initialization.data_distribution); + data_distribution); } } diff --git a/tools/profiler/src/device_context.h b/tools/profiler/src/device_context.h index 7be0349adc..aea872eff8 100644 --- a/tools/profiler/src/device_context.h +++ b/tools/profiler/src/device_context.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/profiler/src/enumerated_types.cpp b/tools/profiler/src/enumerated_types.cpp index 7ca41789b0..29be6f8baf 100644 --- a/tools/profiler/src/enumerated_types.cpp +++ b/tools/profiler/src/enumerated_types.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -123,53 +123,6 @@ AlgorithmMode from_string(std::string const &str) { ///////////////////////////////////////////////////////////////////////////////////////////////// - -static struct { - char const *text; - char const *pretty; - Provider enumerant; -} -Provider_enumerants[] = { - {"cutlass", "CUTLASS", Provider::kCUTLASS}, - {"host", "reference_host", Provider::kReferenceHost}, - {"device", "reference_device", Provider::kReferenceDevice}, - {"cublas", "cuBLAS", Provider::kCUBLAS}, -}; - -/// Converts a Provider enumerant to a string -char const *to_string(Provider provider, bool pretty) { - - for (auto const & possible : Provider_enumerants) { - if (provider == possible.enumerant) { - if (pretty) { - return possible.pretty; - } - else { - return possible.text; - } - } - } - - return pretty ? "Invalid" : "invalid"; -} - -/// Parses a Provider enumerant from a string -template <> -Provider from_string(std::string const &str) { - - for (auto const & possible : Provider_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { - return possible.enumerant; - } - } - - return Provider::kInvalid; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - - static struct { char const *text; char const *pretty; @@ -180,6 +133,7 @@ Disposition_enumerants[] = { {"failed", "Failed", Disposition::kFailed}, {"not_run", "Not run", Disposition::kNotRun}, {"not_verified", "Not verified", Disposition::kNotVerified}, + {"invalid_problem", "Invalid problem", Disposition::kInvalidProblem}, {"not_supported", "Not supported", Disposition::kNotSupported}, {"incorrect", "Incorrect", Disposition::kIncorrect} }; diff --git a/tools/profiler/src/enumerated_types.h b/tools/profiler/src/enumerated_types.h index f9b19423e2..e7e713bdbf 100644 --- a/tools/profiler/src/enumerated_types.h +++ b/tools/profiler/src/enumerated_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,7 +30,9 @@ #include #include +#include #include +#include "cutlass/library/library.h" #define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } @@ -48,7 +50,7 @@ T from_string(std::string const &); enum class ExecutionMode { kProfile, ///< regular verification and profiling kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched - kEnumerate, ///< no kernels launched or workspaces allocated; lists all function types and functions + kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations kTrace, ///< executes a single device-side computation with no other kernel launches kInvalid }; @@ -79,26 +81,6 @@ AlgorithmMode from_string(std::string const &str); ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Providers -enum class Provider { - kCUTLASS, - kReferenceHost, - kReferenceDevice, - kCUBLAS, - kInvalid -}; - -using ProviderVector = std::vector; - -/// Converts a Provider enumerant to a string -char const *to_string(Provider provider, bool pretty = false); - -/// Parses a Provider enumerant from a string -template <> -Provider from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Outcome of a performance test enum class Disposition { kPassed, @@ -106,12 +88,13 @@ enum class Disposition { kNotRun, kIncorrect, kNotVerified, + kInvalidProblem, kNotSupported, kInvalid }; /// Converts a Disposition enumerant to a string -char const *to_string(Disposition provider, bool pretty = false); +char const *to_string(Disposition disposition, bool pretty = false); /// Parses a Disposition enumerant from a string template <> @@ -159,6 +142,21 @@ char const *to_string(ArgumentTypeID type, bool pretty = false); template <> ArgumentTypeID from_string(std::string const &str); +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for(int i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i+1 != v.size() ? "," : ""); + } + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 4c8cb86a3b..f494eeee9f 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,6 +31,8 @@ #include #include +#include "cutlass/core_io.h" + #include "cublas_helpers.h" #include "gemm_operation_profiler.h" #include "gpu_timer.h" @@ -44,22 +46,27 @@ namespace profiler { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Ctor -GemmOperationProfiler::GemmOperationProfiler(): - OperationProfiler(library::OperationKind::kGemm,{ - {ArgumentTypeID::kEnumerated, {"Gemm_kind"}, "Variant of GEMM (e.g. gemm, planar complex, batched, ...)"}, - {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, - {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, - {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, - {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, - {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, - {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, - {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, - {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, - {ArgumentTypeID::kInteger, {"split_k_slices"}, "Number of partitions of K dimension"}, - {ArgumentTypeID::kInteger, {"batch_count"}, "Number of GEMMs computed in one batch"}, - }) { - - description_ = "General matrix-matrix product. D = alpha * A*B + beta * C"; +GemmOperationProfiler::GemmOperationProfiler(Options const &options): + OperationProfiler( + options, + library::OperationKind::kGemm, + { + {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"}, + {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, + {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, + {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, + {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, + {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, + {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, + {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, + {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, + }, + { library::Provider::kCUBLAS} + ) { + + description_ = " General matrix-matrix product. D = alpha * A*B + beta * C"; } /// Destructor @@ -107,6 +114,8 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const { << " --providers=cutlass --output=functional-test.csv\n\n"; } +///////////////////////////////////////////////////////////////////////////////////////////////// + #if 0 // used this for debugging static std::string byte_string(std::vector const &bytes) { @@ -122,48 +131,34 @@ static std::string byte_string(std::vector const &bytes) { } #endif -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Extracts the problem dimensions -Status GemmOperationProfiler::initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, +Status GemmOperationProfiler::GemmProblem::parse( + library::GemmDescription const &operation_desc, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - - library::GemmDescription const &operation_desc = - static_cast(operation->description()); - - if (operation_desc.gemm_kind != library::GemmKind::kGemm) { - return Status::kErrorInvalidProblem; - } - - - if (!arg_as_int(problem_.m, "m", problem_space, problem)) { + + if (!arg_as_int(this->m, "m", problem_space, problem)) { // default value - problem_.m = 1024; + this->m = 1024; } - if (!arg_as_int(problem_.n, "n", problem_space, problem)) { + if (!arg_as_int(this->n, "n", problem_space, problem)) { // default value - problem_.n = 1024; + this->n = 1024; } - if (!arg_as_int(problem_.k, "k", problem_space, problem)) { + if (!arg_as_int(this->k, "k", problem_space, problem)) { // default value - problem_.k = 1024; + this->k = 1024; } - if (!arg_as_int(problem_.split_k_slices, "split_k_slices", problem_space, problem)) { + if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { // default value - problem_.split_k_slices = 1; + this->split_k_slices = 1; } - if (!arg_as_int(problem_.batch_count, "batch_count", problem_space, problem)) { + if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { // default value - problem_.batch_count = 1; + this->batch_count = 1; } if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { @@ -179,37 +174,97 @@ Status GemmOperationProfiler::initialize_configuration( } if (!arg_as_scalar( - problem_.alpha, + this->alpha, operation_desc.element_epilogue, "alpha", problem_space, problem)) { - if (!cast_from_double(problem_.alpha, operation_desc.element_epilogue, 1)) { + if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { return Status::kErrorInternal; } } if (!arg_as_scalar( - problem_.beta, + this->beta, operation_desc.element_epilogue, "beta", problem_space, problem)) { - if (!cast_from_double(problem_.beta, operation_desc.element_epilogue, 0)) { + if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { return Status::kErrorInternal; } } + + this->lda = DeviceAllocation::get_packed_layout( + operation_desc.A.layout, {int(this->m), int(this->k)}).front(); + + this->ldb = DeviceAllocation::get_packed_layout( + operation_desc.B.layout, {int(this->k), int(this->n)}).front(); + + this->ldc = DeviceAllocation::get_packed_layout( + operation_desc.C.layout, {int(this->m), int(this->n)}).front(); + + return Status::kSuccess; +} + +/// Initializes a performance result +void GemmOperationProfiler::GemmProblem::initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space) { + + result.arguments.resize(problem_space.rank()); + + set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); + + set_argument(result, "A", problem_space, + std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); + + set_argument(result, "B", problem_space, + std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); + + set_argument(result, "C", problem_space, + std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); - problem_.lda = DeviceAllocation::get_packed_layout( - operation_desc.A.layout, {int(problem_.m), int(problem_.k)}).front(); + set_argument(result, "m", problem_space, m); + set_argument(result, "n", problem_space, n); + set_argument(result, "k", problem_space, k); - problem_.ldb = DeviceAllocation::get_packed_layout( - operation_desc.B.layout, {int(problem_.k), int(problem_.n)}).front(); + set_argument(result, "split_k_slices", problem_space, split_k_slices); + set_argument(result, "batch_count", problem_space, batch_count); - problem_.ldc = DeviceAllocation::get_packed_layout( - operation_desc.C.layout, {int(problem_.m), int(problem_.n)}).front(); + set_argument(result, "alpha", problem_space, + library::lexical_cast(alpha, operation_desc.element_epilogue)); + + set_argument(result, "beta", problem_space, + library::lexical_cast(beta, operation_desc.element_epilogue)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Extracts the problem dimensions +Status GemmOperationProfiler::initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + library::GemmDescription const &operation_desc = + static_cast(operation->description()); + + if (operation_desc.gemm_kind != library::GemmKind::kUniversal) { + return Status::kErrorInvalidProblem; + } + + Status status = problem_.parse(operation_desc, problem_space, problem); + + if (status != Status::kSuccess) { + return status; + } gemm_workspace_.configuration.problem_size.m() = int(problem_.m); gemm_workspace_.configuration.problem_size.n() = int(problem_.n); @@ -218,7 +273,8 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_.configuration.ldb = problem_.ldb; gemm_workspace_.configuration.ldc = problem_.ldc; gemm_workspace_.configuration.ldd = problem_.ldc; - gemm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); + //gemm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); + gemm_workspace_.configuration.batch_count = int(problem_.split_k_slices); gemm_workspace_.arguments.A = nullptr; gemm_workspace_.arguments.B = nullptr; @@ -240,46 +296,41 @@ void GemmOperationProfiler::initialize_result_( library::GemmDescription const &operation_desc, ProblemSpace const &problem_space) { - result.provider = Provider::kCUTLASS; + result.provider = library::Provider::kCUTLASS; result.disposition = Disposition::kNotRun; result.status = Status::kSuccess; result.operation_name = operation_desc.name; - - result.arguments.resize(problem_space.rank()); - - set_argument_(result, "A", problem_space, - std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); - - set_argument_(result, "B", problem_space, - std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); - - set_argument_(result, "C", problem_space, - std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); - - set_argument_(result, "m", problem_space, problem_.m); - set_argument_(result, "n", problem_space, problem_.n); - set_argument_(result, "k", problem_space, problem_.k); - - set_argument_(result, "split_k_slices", problem_space, problem_.split_k_slices); - set_argument_(result, "batch_count", problem_space, problem_.batch_count); - - set_argument_(result, "alpha", problem_space, - library::lexical_cast(problem_.alpha, operation_desc.element_epilogue)); - - set_argument_(result, "beta", problem_space, - library::lexical_cast(problem_.beta, operation_desc.element_epilogue)); + + problem_.initialize_result(result, operation_desc, problem_space); OperationProfiler::initialize_result_(result, operation_desc, problem_space); + // Input bytes read and Output bytes written for the gemm problem result.bytes = int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * problem_.k + int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.n / 8) * problem_.k + - int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n * 2; + int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; - result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); + // Set is_beta_zero true if beta is zero + bool is_beta_zero = std::all_of(problem_.beta.begin(), problem_.beta.end(), [](uint8_t i) { return i==0; }); + + // Output bytes read for the gemm problem for non-zero beta values + if (!is_beta_zero) { + result.bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; + } + result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); result.runtime = 0; + // complex-valued support + switch (operation_desc.tile_description.math_instruction.math_operation) { + case library::MathOperationID::kMultiplyAddComplex: + result.flops *= 4; + break; + + default: break; + } + } /// Initializes workspace @@ -290,7 +341,7 @@ Status GemmOperationProfiler::initialize_workspace( library::Operation const *operation, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - + library::GemmDescription const &operation_desc = static_cast(operation->description()); @@ -348,7 +399,7 @@ Status GemmOperationProfiler::initialize_workspace( // Status status = Status::kSuccess; - if (options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { if (options.execution_mode != ExecutionMode::kDryRun) { @@ -368,8 +419,13 @@ Status GemmOperationProfiler::initialize_workspace( // If CUTLASS is enabled, generate a result for it // results_.push_back(model_result_); - results_.back().provider = Provider::kCUTLASS; + results_.back().provider = library::Provider::kCUTLASS; + results_.back().op_kind = library::OperationKind::kGemm; results_.back().disposition = Disposition::kNotRun; + + for(auto provider : verification_providers_) { + results_.back().verification_map[provider] = Disposition::kNotRun; + } } return status; @@ -386,7 +442,7 @@ bool GemmOperationProfiler::verify_cutlass( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (!options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { return true; } @@ -423,197 +479,61 @@ bool GemmOperationProfiler::verify_cutlass( return false; } + // CUTLASS op ran the but not yet verified against any verification provider results_.back().disposition = Disposition::kNotVerified; + // + // Run verification providers + // + if (options.verification.enabled) { #if CUTLASS_ENABLE_CUBLAS - if (options.verification.provider_enabled(Provider::kCUBLAS)) { + if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { // Guard against unsupported cases auto const & gemm_desc = static_cast(operation->description()); - if (cublas_satisfies(gemm_desc) != Status::kSuccess) { - return true; - } + if (cublas_satisfies(gemm_desc) == Status::kSuccess) { - return verify_with_cublas_( - options, - report, - device_context, - operation, - problem_space, - problem); + // call cublas verification if supported + verify_with_cublas_( + options, + report, + device_context, + operation, + problem_space, + problem); + } + + else { + // set verification map for cublas to not supported + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; + } } #endif // #if CUTLASS_ENABLE_CUBLAS - } - - return true; -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if CUTLASS_ENABLE_CUBLAS - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Selects one or more cuBLAS algorithms. -static void select_cublas_algorithms( - std::vector &algorithms, - Options const &options, - library::GemmDescription const &op_desc) { - - library::OpcodeClassID const & opcode_class = - op_desc.tile_description.math_instruction.opcode_class; - - switch (options.library.algorithm_mode) { - case AlgorithmMode::kMatching: - { - algorithms.push_back(get_cublas_gemm_algo( - op_desc.tile_description.threadblock_shape.m(), - op_desc.tile_description.threadblock_shape.n(), - op_desc.tile_description.threadblock_shape.k(), - opcode_class)); - break; - } - - case AlgorithmMode::kBest: - { - // Choose first enumerated mode. If none are enumerated, choose based on opcode class - // and evaluate all of them. - - if (options.library.algorithms.empty()) { - // Enumerate all algorithms - if (opcode_class == library::OpcodeClassID::kSimt) { - - for (int algo = CUBLAS_GEMM_DEFAULT; - algo <= CUBLAS_GEMM_ALGO23; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } - else { - - for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } + // Update disposition to worst case verification outcome among all + // verification providers which are supported + bool is_any_verification_run_passed = false; + for(auto &m : results_.back().verification_map) { + if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + results_.back().disposition = m.second; + return true; } - else { - // Use the listed algorithms - algorithms.reserve(options.library.algorithms.size()); - - for (int algo : options.library.algorithms) { - algorithms.push_back(reinterpret_cast(algo)); - } + if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { + is_any_verification_run_passed = true; } - - break; } - case AlgorithmMode::kDefault: - { - - // Use the library's default algorithm - algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? - CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - break; - } - default: - { - break; + if(is_any_verification_run_passed) { + results_.back().disposition = Disposition::kPassed; } } -} - -/// Dispatcher to cublasGemmEx() -struct cublasGemmExDispatcher { - - // - // Data members - // - library::GemmConfiguration configuration; - library::GemmArguments arguments; - - cublasOperation_t trans_A; - cublasOperation_t trans_B; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_C; - cudaDataType_t compute_type; - cublasGemmAlgo_t algo; - Status status; - - // - // Methods - // - - cublasGemmExDispatcher( - library::GemmDescription const &op_desc, - library::GemmConfiguration configuration_, - library::GemmArguments arguments_, - cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT - ): - configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { - - trans_A = get_cublas_transpose_operation(op_desc.A.layout); - trans_B = get_cublas_transpose_operation(op_desc.B.layout); - - bool good = true; - good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); - good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); - good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); - - good = (good && get_cublas_datatype( - compute_type, - op_desc.tile_description.math_instruction.element_accumulator)); - if (!good) { - status = Status::kErrorNotSupported; - } - } - - /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasHandle_t handle) { - - return cublasGemmEx( - handle, - trans_A, - trans_B, - configuration.problem_size.m(), - configuration.problem_size.n(), - configuration.problem_size.k(), - arguments.alpha, - arguments.A, - data_type_A, - int(configuration.lda), - arguments.B, - data_type_B, - int(configuration.ldb), - arguments.beta, - arguments.D, - data_type_C, - int(configuration.ldc), - compute_type, - algo - ); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -#endif // CUTLASS_ENABLE_CUBLAS + // Return true means continue profiling + return true; +} /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -632,14 +552,16 @@ bool GemmOperationProfiler::verify_with_cublas_( library::GemmDescription const &gemm_desc = static_cast(operation->description()); + // + // Construct cuBLAS operators + // + CublasCreate handle; cublasStatus_t status = handle.get_cublas_create_status(); if (status != CUBLAS_STATUS_SUCCESS) { - results_.back().status = get_cutlass_status(status); - results_.back().disposition = Disposition::kFailed; - + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; return true; } @@ -682,7 +604,7 @@ bool GemmOperationProfiler::verify_with_cublas_( ); if (gemm_op.status != Status::kSuccess) { - results_.back().disposition = Disposition::kNotVerified; + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; return true; } @@ -692,8 +614,8 @@ bool GemmOperationProfiler::verify_with_cublas_( // Handle errors if (status != CUBLAS_STATUS_SUCCESS) { - results_.back().status = get_cutlass_status(status); - results_.back().disposition = Disposition::kNotVerified; + + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; return true; } @@ -701,7 +623,7 @@ bool GemmOperationProfiler::verify_with_cublas_( // Verify results // - results_.back().disposition = compare_tensors( + results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( options, *gemm_workspace_.Computed, *gemm_workspace_.Reference @@ -709,19 +631,18 @@ bool GemmOperationProfiler::verify_with_cublas_( // Save workspace if incorrect if (options.verification.save_workspace == SaveWorkspace::kIncorrect && - results_.back().disposition == Disposition::kIncorrect) { + results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { save_workspace( device_context, options, gemm_desc, - Provider::kCUTLASS, - Provider::kCUBLAS); + library::Provider::kCUTLASS, + library::Provider::kCUBLAS); } } catch (...) { - results_.back().disposition = Disposition::kFailed; - results_.back().status = Status::kErrorNotSupported; + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; } #endif @@ -741,7 +662,7 @@ bool GemmOperationProfiler::profile( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { // Initialize structure containing GEMM arguments gemm_workspace_.arguments.A = gemm_workspace_.A->data(); diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h index 37401229ef..e4d23212e1 100644 --- a/tools/profiler/src/gemm_operation_profiler.h +++ b/tools/profiler/src/gemm_operation_profiler.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,10 +31,12 @@ #include #include #include +#include #include // CUTLASS Library includes #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "cutlass/library/manifest.h" // Profiler includes @@ -74,6 +76,18 @@ class GemmOperationProfiler : public OperationProfiler { GemmProblem(): m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); }; /// Workspace used @@ -85,8 +99,8 @@ class GemmOperationProfiler : public OperationProfiler { DeviceAllocation *Computed; DeviceAllocation *Reference; - library::GemmConfiguration configuration; - library::GemmArguments arguments; + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; /// Buffer used for the operation's host workspace std::vector host_workspace; @@ -121,7 +135,7 @@ class GemmOperationProfiler : public OperationProfiler { // /// Ctor - GemmOperationProfiler(); + GemmOperationProfiler(Options const &options); /// Destructor virtual ~GemmOperationProfiler(); diff --git a/tools/profiler/src/gpu_timer.cpp b/tools/profiler/src/gpu_timer.cpp index 218e09d31f..eb3a841150 100644 --- a/tools/profiler/src/gpu_timer.cpp +++ b/tools/profiler/src/gpu_timer.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/profiler/src/gpu_timer.h b/tools/profiler/src/gpu_timer.h index ca00ad7aa2..5cd4b0037f 100644 --- a/tools/profiler/src/gpu_timer.h +++ b/tools/profiler/src/gpu_timer.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/profiler/src/main.cpp b/tools/profiler/src/main.cpp index a76fcf9ac8..a1e523111d 100644 --- a/tools/profiler/src/main.cpp +++ b/tools/profiler/src/main.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index df227a5894..754118a738 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,6 +31,7 @@ #include #include #include +#include #ifdef __unix__ #include @@ -55,30 +56,41 @@ OperationProfiler::OperationProfiler(): kind_(library::OperationKind::kInvalid) /// Ctor OperationProfiler::OperationProfiler( + Options const &options, library::OperationKind kind, ArgumentDescriptionVector const &arguments, - ProviderVector const & reference_providers + ProviderVector const & verification_providers ): - kind_(kind), arguments_(arguments), reference_providers_(reference_providers) { + kind_(kind), arguments_(arguments) { ArgumentDescriptionVector tile_description_arguments{ - {ArgumentTypeID::kEnumerated, {"op_class", "opcode-class"}, "Class of math instruction (SIMT or TensorOp)."}, - {ArgumentTypeID::kEnumerated, {"accum", "accumulator-type"}, "Math instruction accumulator data type."}, - {ArgumentTypeID::kInteger, {"cta_m", "threadblock-shape::m"}, "Threadblock shape in the M dimension."}, - {ArgumentTypeID::kInteger, {"cta_n", "threadblock-shape::n"}, "Threadblock shape in the N dimension."}, - {ArgumentTypeID::kInteger, {"cta_k", "threadblock-shape::k"}, "Threadblock shape in the K dimension."}, - {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply."}, - {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension."}, - {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension."}, - {ArgumentTypeID::kInteger, {"warps_k", "warp-count::k"}, "Number of warps within threadblock along the K dimension."}, - {ArgumentTypeID::kInteger, {"inst_m", "instruction-shape::m"}, "Math instruction shape in the M dimension."}, - {ArgumentTypeID::kInteger, {"inst_n", "instruction-shape::n"}, "Math instruction shape in the N dimension."}, - {ArgumentTypeID::kInteger, {"inst_k", "instruction-shape::k"}, "Math instruction shape in the K dimension."}, - {ArgumentTypeID::kInteger, {"min_cc", "minimum-compute-capability"}, "Minimum device compute capability."}, - {ArgumentTypeID::kInteger, {"max_cc", "maximum-compute-capability"}, "Maximum device compute capability."} + {ArgumentTypeID::kEnumerated, {"op_class", "opcode-class"}, "Class of math instruction (simt, tensorop, wmmatensorop, wmma)"}, + {ArgumentTypeID::kEnumerated, {"accum", "accumulator-type"}, "Math instruction accumulator data type"}, + {ArgumentTypeID::kInteger, {"cta_m", "threadblock-shape::m"}, "Threadblock shape in the M dimension"}, + {ArgumentTypeID::kInteger, {"cta_n", "threadblock-shape::n"}, "Threadblock shape in the N dimension"}, + {ArgumentTypeID::kInteger, {"cta_k", "threadblock-shape::k"}, "Threadblock shape in the K dimension"}, + {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"}, + {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"}, + {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"}, + {ArgumentTypeID::kInteger, {"warps_k", "warp-count::k"}, "Number of warps within threadblock along the K dimension"}, + {ArgumentTypeID::kInteger, {"inst_m", "instruction-shape::m"}, "Math instruction shape in the M dimension"}, + {ArgumentTypeID::kInteger, {"inst_n", "instruction-shape::n"}, "Math instruction shape in the N dimension"}, + {ArgumentTypeID::kInteger, {"inst_k", "instruction-shape::k"}, "Math instruction shape in the K dimension"}, + {ArgumentTypeID::kInteger, {"min_cc", "minimum-compute-capability"}, "Minimum device compute capability"}, + {ArgumentTypeID::kInteger, {"max_cc", "maximum-compute-capability"}, "Maximum device compute capability"} }; arguments_.insert(arguments_.end(), tile_description_arguments.begin(), tile_description_arguments.end()); + + for (auto provider : verification_providers) { + if (std::find( + options.verification.providers.begin(), + options.verification.providers.end(), + provider) != options.verification.providers.end()) { + + verification_providers_.push_back(provider); + } + } } /// Destructor @@ -225,7 +237,7 @@ int OperationProfiler::profile_all( ProblemSpace problem_space(arguments_, options.cmdline); // 1. Construct performance report - PerformanceReport report(options, problem_space.argument_names()); + PerformanceReport report(options, problem_space.argument_names(), kind_); // 2. For each problem in problem space ProblemSpace::Iterator problem_it = problem_space.begin(); @@ -248,8 +260,9 @@ int OperationProfiler::profile_all( auto min_cc = operation->description().tile_description.minimum_compute_capability; auto max_cc = operation->description().tile_description.maximum_compute_capability; - // Execute compatible operations if they satisfy the current device's compute capability + // Execute compatible cutlass operations if they satisfy the current device's compute capability if (operation->description().kind == kind_ && + operation->description().provider == library::Provider::kCUTLASS && options.device.compute_capability() >= min_cc && options.device.compute_capability() <= max_cc) { @@ -259,7 +272,7 @@ int OperationProfiler::profile_all( if (!filtered_by_name) { for (auto const & op_name : options.operation_names) { - if (operation_name.find(op_name) !=std::string::npos) { + if (find_string_matches_(op_name, operation_name)) { filtered_by_name = true; break; } @@ -269,7 +282,7 @@ int OperationProfiler::profile_all( if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { continue; } - + // A. Initialize configuration Status status = this->initialize_configuration( options, @@ -341,7 +354,7 @@ int OperationProfiler::profile_all( device_context, options, operation->description(), - Provider::kCUTLASS); + library::Provider::kCUTLASS); } // @@ -434,8 +447,8 @@ void OperationProfiler::save_workspace( DeviceContext &device_context, Options const &options, library::OperationDescription const &desc, - Provider provider, - Provider verification_provider) { + library::Provider provider, + library::Provider verification_provider) { for (auto const & named_allocation : device_context) { @@ -443,10 +456,10 @@ void OperationProfiler::save_workspace( std::stringstream filename; - filename << desc.name << "_" << to_string(provider) << "_"; + filename << desc.name << "_" << library::to_string(provider) << "_"; - if (verification_provider != Provider::kInvalid) { - filename << "verified_by_" << to_string(verification_provider) << "_"; + if (verification_provider != library::Provider::kInvalid) { + filename << "verified_by_" << library::to_string(verification_provider) << "_"; } filename << named_allocation.first + ".mat"; @@ -454,6 +467,7 @@ void OperationProfiler::save_workspace( std::ofstream out(filename.str()); allocation->write_tensor_csv(out); + out << "\n"; if (options.report.verbose) { std::cout << "wrote '" << filename.str() << "'" << std::endl; @@ -547,29 +561,28 @@ void OperationProfiler::initialize_result_( library::OperationDescription const &operation_desc, ProblemSpace const &problem_space) { - set_argument_(result, "op_class", problem_space, + set_argument(result, "op_class", problem_space, library::to_string(operation_desc.tile_description.math_instruction.opcode_class)); - set_argument_(result, "accum", problem_space, + set_argument(result, "accum", problem_space, library::to_string(operation_desc.tile_description.math_instruction.element_accumulator)); - set_argument_(result, "cta_m", problem_space, operation_desc.tile_description.threadblock_shape.m()); - set_argument_(result, "cta_n", problem_space, operation_desc.tile_description.threadblock_shape.n()); - set_argument_(result, "cta_k", problem_space, operation_desc.tile_description.threadblock_shape.k()); - set_argument_(result, "stages", problem_space, operation_desc.tile_description.threadblock_stages); - set_argument_(result, "warps_m", problem_space, operation_desc.tile_description.warp_count.m()); - set_argument_(result, "warps_n", problem_space, operation_desc.tile_description.warp_count.n()); - set_argument_(result, "warps_k", problem_space, operation_desc.tile_description.warp_count.k()); - set_argument_(result, "inst_m", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.m()); - set_argument_(result, "inst_n", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.n()); - set_argument_(result, "inst_k", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.k()); - set_argument_(result, "min_cc", problem_space, operation_desc.tile_description.minimum_compute_capability); - set_argument_(result, "max_cc", problem_space, operation_desc.tile_description.maximum_compute_capability); + set_argument(result, "cta_m", problem_space, operation_desc.tile_description.threadblock_shape.m()); + set_argument(result, "cta_n", problem_space, operation_desc.tile_description.threadblock_shape.n()); + set_argument(result, "cta_k", problem_space, operation_desc.tile_description.threadblock_shape.k()); + set_argument(result, "stages", problem_space, operation_desc.tile_description.threadblock_stages); + set_argument(result, "warps_m", problem_space, operation_desc.tile_description.warp_count.m()); + set_argument(result, "warps_n", problem_space, operation_desc.tile_description.warp_count.n()); + set_argument(result, "warps_k", problem_space, operation_desc.tile_description.warp_count.k()); + set_argument(result, "inst_m", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.m()); + set_argument(result, "inst_n", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.n()); + set_argument(result, "inst_k", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.k()); + set_argument(result, "min_cc", problem_space, operation_desc.tile_description.minimum_compute_capability); + set_argument(result, "max_cc", problem_space, operation_desc.tile_description.maximum_compute_capability); } - /// Helper -void OperationProfiler::set_argument_( +void OperationProfiler::set_argument( PerformanceResult &result, char const *name, ProblemSpace const &problem_space, @@ -578,7 +591,7 @@ void OperationProfiler::set_argument_( result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), value); } -void OperationProfiler::set_argument_( +void OperationProfiler::set_argument( PerformanceResult &result, char const *name, ProblemSpace const &problem_space, @@ -587,6 +600,39 @@ void OperationProfiler::set_argument_( result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), library::lexical_cast(value)); } + +/// finds string matches filter_string in operation_name +bool OperationProfiler::find_string_matches_( + std::string const &filter_string, + std::string const &operation_name) { + // Returns true if all substrings appear in the operation_name in order + + // Split filter_string of the format "gemm*f32*nt" to tokens ["gemm", "f32", "nt"] + std::string item; + std::istringstream iss(filter_string); + std::vector filter_tokens; + while (std::getline(iss, item, '*')) { + filter_tokens.push_back(item); + } + + // Search filter_tokens in operation_name in order + size_t start = 0, idx = 0; + for(auto & token : filter_tokens) { + // Check if characters left to be parsed in operation_name + if (start < operation_name.length()) { + // Find token in operation_name[start:] + idx = operation_name.substr(start).find(token); + if (idx == std::string::npos) { + return false; + } + } + start += (idx + token.length()); + } + + // All tokens in filter_string found in operation_name + return true; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler diff --git a/tools/profiler/src/operation_profiler.h b/tools/profiler/src/operation_profiler.h index 3019f3bdc2..c7e20f36f7 100644 --- a/tools/profiler/src/operation_profiler.h +++ b/tools/profiler/src/operation_profiler.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -35,6 +35,7 @@ // CUTLASS Library includes #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "cutlass/library/manifest.h" // Profiler includes @@ -43,6 +44,7 @@ #include "performance_result.h" #include "performance_report.h" #include "problem_space.h" +#include "debug.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -71,7 +73,7 @@ class OperationProfiler { ArgumentDescriptionVector arguments_; /// List of providers used to verify and compare each result - ProviderVector reference_providers_; + ProviderVector verification_providers_; /// Model performance result initailized by the operation profiler with workload statistics /// and reasonable default state. @@ -90,9 +92,10 @@ class OperationProfiler { OperationProfiler(); OperationProfiler( + Options const &options, library::OperationKind kind, ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), - ProviderVector const & reference_providers = ProviderVector()); + ProviderVector const & verification_providers = ProviderVector()); /// Destructor virtual ~OperationProfiler(); @@ -192,31 +195,31 @@ class OperationProfiler { DeviceContext &device_context, Options const &options, library::OperationDescription const &desc, - Provider provider, - Provider verification_provider = Provider::kInvalid); - -protected: - - /// Sets operation description - static void initialize_result_( - PerformanceResult &result, - library::OperationDescription const &operation_desc, - ProblemSpace const &problem_space); - + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); + /// Helper to set a performance result member - static void set_argument_( + static void set_argument( PerformanceResult &result, char const *name, ProblemSpace const &problem_space, std::string const &value); /// Helper to set a performance result member - static void set_argument_( + static void set_argument( PerformanceResult &result, char const *name, ProblemSpace const &problem_space, int64_t value); +protected: + + /// Sets operation description + static void initialize_result_( + PerformanceResult &result, + library::OperationDescription const &operation_desc, + ProblemSpace const &problem_space); + /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( double &runtime, @@ -225,6 +228,12 @@ class OperationProfiler { void const *arguments, void *host_workspace, void *device_workspace); + +private: + /// finds string matches filter_string in operation_name + bool find_string_matches_( + std::string const &filter_string, + std::string const &operation_name); }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 367dc7a272..5f62a81e73 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,6 +31,8 @@ #include "cutlass/cutlass.h" #include "cutlass/version.h" +#include "cutlass/library/util.h" + #include "options.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -74,7 +76,7 @@ Options::Device::Device(cutlass::CommandLine const &cmdline) { void Options::Device::print_usage(std::ostream &out) const { out << "Device:\n" - << " --device= " + << " --device= " << " CUDA Device ID\n\n"; int device_count = 0; @@ -104,7 +106,7 @@ void Options::Device::print_usage(std::ostream &out) const { } out - << " --compute-capability= " + << " --compute-capability= " << " Override the compute capability.\n\n"; } @@ -161,24 +163,30 @@ Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) { if (cmdline.check_cmd_line_flag("initialization-provider")) { std::string str; cmdline.get_cmd_line_argument("initialization-provider", str); - provider = from_string(str); - if (provider == Provider::kInvalid) { + provider = library::from_string(str); + if (provider == library::Provider::kInvalid) { enabled = false; } - else if (provider != Provider::kReferenceHost && provider != Provider::kReferenceDevice) { + else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) { throw std::runtime_error("Unsupported intialization provider specified."); } } else { - provider = Provider::kReferenceDevice; + provider = library::Provider::kReferenceDevice; } cmdline.get_cmd_line_argument("seed", seed, 2019); if (cmdline.check_cmd_line_flag("dist")) { + // user has set the data distribution (fix data distribution once set) + fix_data_distribution = true; + // set user provided data distribution get_distribution(cmdline, "dist", data_distribution); } else { + // profiler choosen data distribution (allowed to change based on numeric types) + fix_data_distribution = false; + // set uniform data distribution with range [-4, 4] data_distribution.set_uniform(-4, 4, 0); } @@ -247,12 +255,6 @@ void Options::Initialization::get_distribution( continue; // next token } - // Casts as integer without scaling - if (it->first.compare("integer") == 0) { - dist.int_scale = 0; - continue; // next token - } - // initialize other members for (int m = 0; members[m].label; ++m) { if (it->first == members[m].label && !it->second.empty()) { @@ -268,19 +270,23 @@ void Options::Initialization::print_usage(std::ostream &out) const { out << "Initialization:\n" - << " --initialization= " + << " --initialization= " << " Enables initialization (default: true). If false, device memory is" << end_of_line - << "not initialized after allocation.\n\n" + << " not initialized after allocation.\n\n" - << " --initialization-provider= " - << " Selects 'device' or 'host' initialization.\n\n" + << " --initialization-provider= " + << " Selects initialization provider {host, device*}. (default: '*')\n\n" - << " --dist= " - << " Data distribution of input tensors\n\n" + << " --dist= " + << " Data distribution of input tensors {uniform*, gaussian, identity, sequential}" << end_of_line + << " --dist=uniform,min:,max:,scale:" << end_of_line + << " --dist=gaussian,mean:,stddev:,scale:" << end_of_line + << " --dist=sequential,start:,delta:,scale:" << end_of_line + << " --dist=identity\n\n" - << " --seed= " + << " --seed= " << " Random number generator seed. Used to enforce deterministic" << end_of_line - << "initialization.\n\n"; + << " initialization.\n\n"; } @@ -331,12 +337,12 @@ void Options::Library::print_usage(std::ostream &out) const { out << "Library:\n" - << " --library-algo-mode= " + << " --library-algo-mode= " << " Indicates algorithm mode used to call libraries such as cuBLAS and cuDNN.\n" - << " " + << " " << " mode={default*,matching,best}\n\n" - << " --library-algos= " + << " --library-algos= " << " If --algorithm-mode=best, permits specifying a selection of algorithms.\n\n"; } @@ -372,12 +378,12 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { providers.clear(); for (auto const &token : tokens) { - providers.push_back(from_string(token)); + providers.push_back(library::from_string(token)); } } else { - providers.push_back(Provider::kCUTLASS); - providers.push_back(Provider::kCUBLAS); + providers.push_back(library::Provider::kCUTLASS); + providers.push_back(library::Provider::kCUBLAS); } } @@ -385,21 +391,25 @@ void Options::Profiling::print_usage(std::ostream &out) const { out << "Profiling:\n" - << " --profiling-iterations= " + << " --profiling-iterations= " << " Number of iterations to profile each kernel. If zero, kernels" << end_of_line - << "are launched up to the profiling duration.\n\n" + << " are launched up to the profiling duration.\n\n" - << " --warmup-iterations= " + << " --warmup-iterations= " << " Number of iterations to execute each kernel prior to profiling.\n\n" - << " --sleep-duration= " - << " Number of ms to sleep between profiling periods (ms)\n\n" + << " --sleep-duration= " + << " Number of ms to sleep between profiling periods (ms).\n\n" - << " --profiling-enabled= " + << " --profiling-enabled= " << " If true, profiling is actually conducted.\n\n" - << " --providers= " - << " List of providers to be profiled for performance\n\n"; + << " --providers= " + << " List of providers to be profiled for performance. (default: '*')" << end_of_line + << " Gemm providers {cutlass*" + << "}" << end_of_line + << "\n\n"; + } void Options::Profiling::print_options(std::ostream &out, int indent) const { @@ -412,18 +422,18 @@ void Options::Profiling::print_options(std::ostream &out, int indent) const { int j = 0; for (auto const & provider : providers) { - out << (j++ ? ", " : "") << to_string(provider); + out << (j++ ? ", " : "") << library::to_string(provider); } out << "]\n"; } /// Returns true if a provider is enabled -bool Options::Profiling::provider_enabled(Provider provider) const { +bool Options::Profiling::provider_enabled(library::Provider provider) const { return std::find(providers.begin(), providers.end(), provider) != providers.end(); } /// Returns the index of a provider if its enabled -size_t Options::Profiling::index(Provider provider) const { +size_t Options::Profiling::index(library::Provider provider) const { size_t idx = 0; for (auto const & x : providers) { if (x == provider) { @@ -461,14 +471,15 @@ Options::Verification::Verification(cutlass::CommandLine const &cmdline) { providers.clear(); for (auto const &token : tokens) { - Provider provider = from_string(token); - if (provider != Provider::kInvalid) { + library::Provider provider = library::from_string(token); + if (provider != library::Provider::kInvalid) { providers.push_back(provider); } } } else { - providers.push_back(Provider::kCUBLAS); + providers.push_back(library::Provider::kCUBLAS); + providers.push_back(library::Provider::kReferenceDevice); } } @@ -476,22 +487,27 @@ void Options::Verification::print_usage(std::ostream &out) const { out << "Verification:\n" - << " --verification-enabled= " + << " --verification-enabled= " << " Whether to perform verification checks.\n\n" - << " --epsilon= " + << " --epsilon= " << " Error threshold. Setting to zero (default) requires" << end_of_line - << "bit-level equivalence.\n\n" + << " bit-level equivalence.\n\n" - << " --nonzero-floor= " + << " --nonzero-floor= " << " Results whose absolute value is less than this quantity" << end_of_line - << "are treated as zero for comparisons.\n\n" + << " are treated as zero for comparisons.\n\n" - << " --save-workspace={*never,incorrect,always}" - << " Specifies when to save the GEMM inputs and results to the filesystem.\n\n" + << " --save-workspace= " + << " Specifies when to save the GEMM inputs and results to the filesystem." << end_of_line + << " --save-workspace=never never save workspace (default)" << end_of_line + << " --save-workspace=incorrect save workspace for incorrect results" << end_of_line + << " --save-workspace=always always save workspace\n\n" - << " --verification-providers= " - << " List of providers used to verify result. (default: device)\n\n"; + << " --verification-providers= " + << " List of providers used to verify result. (default: '*')" << end_of_line + << " Gemm verification-providers {cublas*}" << end_of_line + << "\n\n"; } void Options::Verification::print_options(std::ostream &out, int indent) const { @@ -504,18 +520,18 @@ void Options::Verification::print_options(std::ostream &out, int indent) const { int j = 0; for (auto const & provider : providers) { - out << (j++ ? ", " : "") << to_string(provider); + out << (j++ ? ", " : "") << library::to_string(provider); } out << "]\n"; } /// Returns true if a provider is enabled -bool Options::Verification::provider_enabled(Provider provider) const { +bool Options::Verification::provider_enabled(library::Provider provider) const { return std::find(providers.begin(), providers.end(), provider) != providers.end(); } /// Returns the index of a provider if its enabled -size_t Options::Verification::index(Provider provider) const { +size_t Options::Verification::index(library::Provider provider) const { size_t idx = 0; for (auto const & x : providers) { if (x == provider) { @@ -546,22 +562,22 @@ void Options::Report::print_usage(std::ostream &out) const { out << "Report:\n" - << " --append= " + << " --append= " << " If true, result is appended to possibly existing file. Otherwise, " << end_of_line - << "any existing file is overwritten.\n\n" + << " any existing file is overwritten.\n\n" - << " --output= " - << " Path to output file for machine readable results.\n\n" + << " --output= " + << " Path to output file for machine readable results. Operation kind and '.csv' is appended.\n\n" - << " --report-not-run= " + << " --report-not-run= " << " If true, reports the status of all kernels including those that" << end_of_line - << "do not satisfy the given arguments.\n\n" + << " do not satisfy the given arguments.\n\n" - << " --tags= " + << " --tags= " << " Inserts leading columns in output table and uniform values for each" << end_of_line - << "column. Useful for generating pivot tables.\n\n" + << " column. Useful for generating pivot tables.\n\n" - << " --verbose= " + << " --verbose= " << " Prints human-readable text to stdout. If false, nothing is written to stdout.\n\n"; } @@ -592,7 +608,7 @@ Options::About::About(cutlass::CommandLine const &cmdline) { void Options::About::print_usage(std::ostream &out) const { out << "About:\n" - << " --version "; + << " --version "; print_version(out); @@ -658,7 +674,7 @@ Options::Options(cutlass::CommandLine const &cmdline): // Prevent launches on the device for anything other than CUTLASS operation if (execution_mode == ExecutionMode::kTrace) { - initialization.provider = Provider::kReferenceHost; + initialization.provider = library::Provider::kReferenceHost; verification.enabled = false; profiling.enabled = false; } @@ -667,22 +683,29 @@ Options::Options(cutlass::CommandLine const &cmdline): void Options::print_usage(std::ostream &out) const { out - << "CUTLASS Performance Tool\n" + << "CUTLASS Profiler\n" << "usage:\n\n" << " cutlass_profiler [options]\n\n" << " --help\n\n" - << " --mode={profile*,single,dry,trace,enumerate} " - << " Regular profiling, single kernel mode only, or no profiling.\n\n" + << " --mode= " + << " Cutlass profiler execution mode." << end_of_line + << " --mode=profile regular verification and profiling (default)" << end_of_line + << " --mode=dry_run no kernels are launched or workspaces allocated" << end_of_line + << " --mode=enumerate lists all operation kind and operations" << end_of_line + << " --mode=trace executes a single device-side computation with" << end_of_line + << " no other kernel launches\n\n" - << " --device-info " + << " --device-info " << " Prints information on all GPUs present in the system\n\n" - << " --operation= " + << " --operation= " << " CUTLASS operation to profile.\n\n" - << " --kernels= " - << " List of substrings to filter operations by name.\n\n" + << " --kernels= " + << " Filter operations by kernel names. For example, call all kernels with" << end_of_line + << " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line + << " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n" ; // @@ -747,4 +770,3 @@ std::string Options::indent_str(int indent) { } // namespace profiler } // namespace cutlass - diff --git a/tools/profiler/src/options.h b/tools/profiler/src/options.h index 26cb93c7f9..f4b5f0a130 100644 --- a/tools/profiler/src/options.h +++ b/tools/profiler/src/options.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -105,11 +105,15 @@ class Options { /// allocating tensors. bool enabled; + /// If true, data distribution is set by the user and is not allowed to change + /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + /// Data distribution for input tensors Distribution data_distribution; /// Source of random tensor elements - Provider provider; + library::Provider provider; /// Random number generator seed. int seed; @@ -162,10 +166,10 @@ class Options { void print_options(std::ostream &out, int indent = 0) const; /// Returns true if a provider is enabled - bool provider_enabled(Provider provider) const; + bool provider_enabled(library::Provider provider) const; /// Returns the index of a provider if its enabled - size_t index(Provider provider) const; + size_t index(library::Provider provider) const; }; /// Options related to profiling @@ -196,10 +200,10 @@ class Options { void print_options(std::ostream &out, int indent = 0) const; /// Returns true if a provider is enabled - bool provider_enabled(Provider provider) const; + bool provider_enabled(library::Provider provider) const; /// Returns the index of a provider if its enabled - size_t index(Provider provider) const; + size_t index(library::Provider provider) const; }; /// Options related to reporting diff --git a/tools/profiler/src/performance_report.cpp b/tools/profiler/src/performance_report.cpp index fd05155fe9..0ab7044929 100644 --- a/tools/profiler/src/performance_report.cpp +++ b/tools/profiler/src/performance_report.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,9 +29,15 @@ #include #include #include +#include +#include -#include "performance_report.h" +#include "cutlass/library/util.h" + +#include "cutlass/library/util.h" +#include "performance_report.h" +#include "debug.h" namespace cutlass { namespace profiler { @@ -57,12 +63,19 @@ namespace profiler { PerformanceReport::PerformanceReport( Options const &options, - std::vector const &argument_names + std::vector const &argument_names, + library::OperationKind const &op_kind ): - options_(options), argument_names_(argument_names), problem_index_(0), good_(true) { + options_(options), argument_names_(argument_names), problem_index_(0), good_(true), op_kind_(op_kind) { + + // Strip '.csv' if present + std::string base_path = options_.report.output_path.substr( + 0, options_.report.output_path.rfind(".csv")); + + op_file_name_ = base_path + "." + to_string(op_kind_) + ".csv"; // - // Open output file + // Open output file for operation of PerformanceReport::op_kind // if (!options_.report.output_path.empty()) { @@ -70,17 +83,17 @@ PerformanceReport::PerformanceReport( if (options_.report.append) { - std::ifstream test_output_file(options_.report.output_path.c_str()); + std::ifstream test_output_file(op_file_name_); if (test_output_file.is_open()) { print_header = false; test_output_file.close(); } - output_file_.open(options_.report.output_path.c_str(), std::ios::app); + output_file_.open(op_file_name_, std::ios::app); } else { - output_file_.open(options_.report.output_path.c_str()); + output_file_.open(op_file_name_); } if (!output_file_.good()) { @@ -148,13 +161,14 @@ void PerformanceReport::close() { } } else if (output_file_.is_open() && options_.report.verbose) { - std::cout << "\n\nWrote results to '" << options_.report.output_path << "'" << std::endl; + std::cout << "\n\nWrote results to '" << op_file_name_ << "'" << std::endl; } } static const char *disposition_status_color(Disposition disposition) { switch (disposition) { case Disposition::kPassed: return SHELL_COLOR_GREEN(); + case Disposition::kIncorrect: return SHELL_COLOR_RED(); case Disposition::kFailed: return SHELL_COLOR_RED(); default: break; @@ -184,21 +198,33 @@ std::ostream & PerformanceReport::print_result_pretty_( out << "\n" - << " Provider: " << SHELL_COLOR_BRIGHT() << to_string(result.provider, true) << SHELL_COLOR_END() << "\n" - << " Operation: " << result.operation_name << "\n\n" - << " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n" - << " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n"; + << " Provider: " << SHELL_COLOR_BRIGHT() << library::to_string(result.provider, true) << SHELL_COLOR_END() << "\n" + << " OperationKind: " << SHELL_COLOR_BRIGHT() << library::to_string(result.op_kind) << SHELL_COLOR_END() << "\n" + << " Operation: " << result.operation_name << "\n\n" + << " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n" + << " Verification: " << SHELL_COLOR_BRIGHT() << (options_.verification.enabled ? "ON":"OFF") << SHELL_COLOR_END() << "\n" + << " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n\n"; + + // Display individual verification results for each verification-provider + if (options_.verification.enabled) { + + static int const indent_spaces = 16; + + for(auto & m : result.verification_map) { + out << std::right << std::setw(indent_spaces) << library::to_string(m.first, true) << ": " << to_string(m.second, true) << "\n"; + } + } out - << "\n Arguments: "; + << "\n Arguments:"; int column_idx = 0; for (auto const &arg : result.arguments) { if (!arg.second.empty()) { out << " --" << arg.first << "=" << arg.second; - column_idx += 4 + arg.first.size() + arg.second.size(); - if (column_idx > 90) { - out << " \\\n "; + column_idx += int(4 + arg.first.size() + arg.second.size()); + if (column_idx > 98) { + out << " \\\n "; column_idx = 0; } } @@ -206,15 +232,15 @@ std::ostream & PerformanceReport::print_result_pretty_( out << "\n\n"; out - << " Bytes: " << result.bytes << " bytes\n" - << " FLOPs: " << result.flops << " flops\n\n"; + << " Bytes: " << result.bytes << " bytes\n" + << " FLOPs: " << result.flops << " flops\n\n"; if (result.good()) { out - << " Runtime: " << result.runtime << " ms\n" - << " Memory: " << result.gbytes_per_sec() << " GiB/s\n" - << "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n"; + << " Runtime: " << result.runtime << " ms\n" + << " Memory: " << result.gbytes_per_sec() << " GiB/s\n" + << "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n"; } @@ -234,7 +260,7 @@ std::ostream & PerformanceReport::print_csv_header_( out << (column_idx ? "," : "") << "Problem,Provider" - << ",Operation,Disposition,Status"; + << ",OperationKind,Operation,Disposition,Status"; for (auto const &arg_name : argument_names_) { out << "," << arg_name; @@ -267,6 +293,7 @@ std::ostream & PerformanceReport::print_result_csv_( << (column_idx ? "," : "") << result.problem_index << "," << to_string(result.provider, true) + << "," << to_string(result.op_kind) << "," << result.operation_name << "," << to_string(result.disposition) << "," << library::to_string(result.status); diff --git a/tools/profiler/src/performance_report.h b/tools/profiler/src/performance_report.h index 1022efac5c..1c086e6185 100644 --- a/tools/profiler/src/performance_report.h +++ b/tools/profiler/src/performance_report.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -31,10 +31,14 @@ #include #include +// CUTLASS Profiler includes #include "options.h" #include "enumerated_types.h" #include "performance_result.h" +// CUTLASS Library includes +#include "cutlass/library/library.h" + namespace cutlass { namespace profiler { @@ -46,6 +50,12 @@ class PerformanceReport { /// Reference to options Options const &options_; + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + /// Output file containing results std::ofstream output_file_; @@ -63,7 +73,7 @@ class PerformanceReport { public: - PerformanceReport(Options const &options, std::vector const &argument_names); + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); bool good() const { return good_; } diff --git a/tools/profiler/src/performance_result.cu b/tools/profiler/src/performance_result.cu new file mode 100644 index 0000000000..86cabfb753 --- /dev/null +++ b/tools/profiler/src/performance_result.cu @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +// CUTLASS Profiler includes +#include "enumerated_types.h" +#include "performance_result.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/profiler/src/performance_result.h b/tools/profiler/src/performance_result.h index b710099d71..9e3ebeb5ce 100644 --- a/tools/profiler/src/performance_result.h +++ b/tools/profiler/src/performance_result.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -32,8 +32,12 @@ #include "cutlass/cutlass.h" +// CUTLASS Profiler includes #include "enumerated_types.h" +// CUTLASS Library includes +#include "cutlass/library/library.h" + namespace cutlass { namespace profiler { @@ -45,16 +49,23 @@ struct PerformanceResult { /// Index of problem size_t problem_index; - /// Provider - Provider provider; + /// library::Provider + library::Provider provider; - /// Outcome of test - Disposition disposition; + /// Operation kind + library::OperationKind op_kind; - /// CUTLASS status result from kernels + /// CUTLASS status result from kernels (success or failure) + // Status does information on verification Status status; - /// Operation object + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + + /// Operation name std::string operation_name; /// Stringified vector of argument values @@ -76,7 +87,8 @@ struct PerformanceResult { /// Ctor PerformanceResult(): problem_index(0), - provider(Provider::kInvalid), + op_kind(library::OperationKind::kInvalid), + provider(library::Provider::kInvalid), disposition(Disposition::kNotRun), status(Status::kInvalid), bytes(0), @@ -107,3 +119,4 @@ using PerformanceResultVector = std::vector; } // namespace profiler } // namespace cutlass + diff --git a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index 33656beff8..adede0ea1f 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -27,10 +27,11 @@ */ #include -#include #include #include +#include "cutlass/library/util.h" + #include "problem_space.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -848,18 +849,58 @@ bool arg_as_OpcodeClassID( return arg_as_OpcodeClassID(opcode_class, value_ptr); } -///////////////////////////////////////////////////////////////////////////////////////////////// +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + KernelArgument::Value const *value_ptr) { + + if (value_ptr->not_null) { + if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { + + split_k_mode = library::from_string( + static_cast(value_ptr)->element); + + if (split_k_mode == library::SplitKMode::kInvalid) { + throw std::runtime_error( + "arg_as_SplitKModeID() - illegal cast."); + } + } + else { + + throw std::runtime_error( + "arg_as_SplitKModeID() - illegal cast."); + } + return true; + } + return false; +} + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + size_t idx = problem_space.argument_index(name); + KernelArgument::Value const *value_ptr = problem.at(idx).get(); + + return arg_as_SplitKModeID(split_k_mode, value_ptr); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. bool arg_as_scalar( std::vector &bytes, library::NumericTypeID numeric_type, KernelArgument::Value const *value_ptr) { - + if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { int64_t int_value = static_cast(value_ptr)->value; - + // TODO - convert int64_t => destination type } else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { @@ -939,7 +980,6 @@ bool tensor_description_satisfies( } ///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/problem_space.h b/tools/profiler/src/problem_space.h index 8dfd216cfb..77a79ca2a6 100644 --- a/tools/profiler/src/problem_space.h +++ b/tools/profiler/src/problem_space.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -811,6 +811,17 @@ bool arg_as_OpcodeClassID( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + /// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. bool arg_as_scalar( std::vector &bytes, diff --git a/tools/util/CMakeLists.txt b/tools/util/CMakeLists.txt index 6cda38ac3a..0d2f86fb99 100644 --- a/tools/util/CMakeLists.txt +++ b/tools/util/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -31,6 +31,12 @@ target_include_directories( $ ) +target_link_libraries( + cutlass_tools_util_includes + INTERFACE + $<$:cublas> + ) + install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ @@ -40,3 +46,4 @@ install( TARGETS cutlass_tools_util_includes EXPORT NvidiaCutlass ) + diff --git a/tools/util/include/cutlass/util/command_line.h b/tools/util/include/cutlass/util/command_line.h index 008d0e7360..c158ef9768 100644 --- a/tools/util/include/cutlass/util/command_line.h +++ b/tools/util/include/cutlass/util/command_line.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. @@ -119,6 +119,16 @@ struct CommandLine { val = !(value == "0" || value == "false"); } } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } /** * Obtains the value specified for a given commandline parameter --= @@ -126,7 +136,7 @@ struct CommandLine { template void get_cmd_line_argument(const char* arg_name, value_t& val, - value_t const& _default = value_t()) const { + value_t const& _default) const { using namespace std; val = _default; diff --git a/tools/util/include/cutlass/util/debug.h b/tools/util/include/cutlass/util/debug.h index 065a94e42d..3ebbd4d843 100644 --- a/tools/util/include/cutlass/util/debug.h +++ b/tools/util/include/cutlass/util/debug.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/device_dump.h b/tools/util/include/cutlass/util/device_dump.h index 2dd67c8905..dac6029c41 100644 --- a/tools/util/include/cutlass/util/device_dump.h +++ b/tools/util/include/cutlass/util/device_dump.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/device_memory.h b/tools/util/include/cutlass/util/device_memory.h index e8f13d3b34..79b123687a 100644 --- a/tools/util/include/cutlass/util/device_memory.h +++ b/tools/util/include/cutlass/util/device_memory.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. @@ -40,10 +40,14 @@ namespace device_memory { /// Allocate a buffer of \p count elements of type \p T on the current CUDA device template T* allocate(size_t count = 1) { + T* ptr = 0; - size_t bytes = sizeof(T) * count; + size_t bytes = 0; + + bytes = count * sizeof(T); cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + if (cuda_error != cudaSuccess) { throw cuda_exception("Failed to allocate memory", cuda_error); } @@ -111,13 +115,16 @@ void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { copy_to_device(device_begin, &*begin, elements); } -/****************************************************************************** - * "Smart" device memory allocation - ******************************************************************************/ +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// -/// Device allocation abstraction that tracks size and capacity template -struct allocation { +class DeviceAllocation { +public: + /// Delete functor for CUDA device memory struct deleter { void operator()(T* ptr) { @@ -130,6 +137,7 @@ struct allocation { } }; +public: // // Data members // @@ -140,23 +148,55 @@ struct allocation { /// Smart pointer platform::unique_ptr smart_ptr; +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + // // Methods // /// Constructor: allocates no memory - allocation() : capacity(0) {} + DeviceAllocation() : capacity(0) {} /// Constructor: allocates \p capacity elements on the current CUDA device - allocation(size_t _capacity) : smart_ptr(allocate(_capacity)), capacity(_capacity) {} + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} /// Copy constructor - allocation(allocation const &p): smart_ptr(allocate(p.capacity)), capacity(p.capacity) { - copy_device_to_device(smart_ptr.get(), p.get(), capacity); + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); } /// Destructor - ~allocation() { reset(); } + ~DeviceAllocation() { reset(); } /// Returns a pointer to the managed object T* get() const { return smart_ptr.get(); } @@ -173,12 +213,41 @@ struct allocation { smart_ptr.reset(); } + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity void reset(T* _ptr, size_t _capacity) { smart_ptr.reset(_ptr); capacity = _capacity; } + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. + void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + /// Returns a pointer to the object owned by *this T* operator->() const { return smart_ptr.get(); } @@ -189,15 +258,69 @@ struct allocation { const deleter& get_deleter() const { return smart_ptr.get_deleter(); } /// Copies a device-side memory allocation - allocation & operator=(allocation const &p) { + DeviceAllocation & operator=(DeviceAllocation const &p) { if (capacity != p.capacity) { - smart_ptr.reset(allocate(p.capacity)); + smart_ptr.reset(device_memory::allocate(p.capacity)); capacity = p.capacity; } copy_device_to_device(smart_ptr.get(), p.get(), capacity); return *this; } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. + void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + } // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/distribution.h b/tools/util/include/cutlass/util/distribution.h index d9b61ca558..0337737747 100644 --- a/tools/util/include/cutlass/util/distribution.h +++ b/tools/util/include/cutlass/util/distribution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/exceptions.h b/tools/util/include/cutlass/util/exceptions.h index ab5623bfc6..b6cf2fcd8e 100644 --- a/tools/util/include/cutlass/util/exceptions.h +++ b/tools/util/include/cutlass/util/exceptions.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. diff --git a/tools/util/include/cutlass/util/host_reorder.h b/tools/util/include/cutlass/util/host_reorder.h index bb9ed621bc..d46d45946f 100644 --- a/tools/util/include/cutlass/util/host_reorder.h +++ b/tools/util/include/cutlass/util/host_reorder.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 0a08b6e0eb..c734a5f5eb 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -99,7 +99,7 @@ class HostTensor { using ConstReference = typename ConstTensorRef::Reference; /// Used to handle packing of subbyte elements - static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? sizeof(Element) * 8 / sizeof_bits::value : 1); + static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? (8 / sizeof_bits::value) : 1); private: @@ -232,7 +232,7 @@ class HostTensor { /// Returns the logical capacity based on extent and layout. May differ from size(). LongIndex capacity() const { - return layout_.capacity(extent_) * kElementsPerStoredItem; + return layout_.capacity(extent_); } /// Gets pointer to host data diff --git a/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/tools/util/include/cutlass/util/host_tensor_planar_complex.h new file mode 100644 index 0000000000..3a31e29a43 --- /dev/null +++ b/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -0,0 +1,586 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/matrix_traits.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). + LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Gets a pointer to the device data imaginary part + Element * device_data_imag() { return device_.get() + imaginary_stride(); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + host_data(), ptr_device_real, count); + + device_memory::copy_to_host( + host_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + device_data(), ptr_device_real, count); + + device_memory::copy_device_to_device( + device_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + device_data(), ptr_host_real, count); + + device_memory::copy_to_device( + device_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + host_data(), ptr_host_real, count); + + device_memory::copy_host_to_host( + host_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host_real, ///< source device memory + Element * ptr_host_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + ptr_host_real, device_data(), count); + + device_memory::copy_to_host( + ptr_host_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + ptr_device_real, device_data(), count); + + device_memory::copy_device_to_device( + ptr_device_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + ptr_device_real, host_data(), count); + + device_memory::copy_to_device( + ptr_device_imag, host_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host_real, ///< source host memory + Element * ptr_host_imag, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + ptr_host_real, host_data(), count); + + device_memory::copy_host_to_host( + ptr_host_imag, host_data_imag(), count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/detail/inner_product.h b/tools/util/include/cutlass/util/reference/detail/inner_product.h index 77a3076ed5..f75f8b8884 100644 --- a/tools/util/include/cutlass/util/reference/detail/inner_product.h +++ b/tools/util/include/cutlass/util/reference/detail/inner_product.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/gemm.h b/tools/util/include/cutlass/util/reference/device/gemm.h index 9dc66cca25..5aef19ff23 100644 --- a/tools/util/include/cutlass/util/reference/device/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000..b3003409bb --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/matrix_traits.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag());; + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h index 6e38910299..4c8e361ecb 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h index cf47c9a4ea..4d9de5156e 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h index b7c2f073a5..64cb37bea2 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index dca50c2f72..3323bed51e 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 0c8e1ac46f..962ded0940 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -43,12 +43,12 @@ #endif // CUDA includes -#include #include // Cutlass includes #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/complex.h" #include "cutlass/tensor_view.h" #include "cutlass/util/reference/device/tensor_foreach.h" @@ -169,6 +169,95 @@ struct RandomGaussianFunc { } }; + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); + rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + + result = { + Real(rnd_r / FloatType(IntType(1) << params.int_scale)), + Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -269,12 +358,12 @@ template ///< Element type void BlockFillRandomGaussian( Element *ptr, size_t capacity, - uint64_t seed, ///< seed for RNG - Element mean = Element(0), ///< Gaussian distribution's mean - Element stddev = Element(1), ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. using RandomFunc = detail::RandomGaussianFunc; @@ -383,6 +472,111 @@ struct RandomUniformFunc { } }; +/// Computes a random Gaussian distribution +template ///< Layout function +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); + rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + + result = { + Real(rnd_r / FloatType(IntType(1) << params.int_scale)), + Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -489,8 +683,8 @@ void BlockFillRandomUniform( Element *ptr, size_t capacity, uint64_t seed, ///< seed for RNG - Element max = Element(1), ///< upper bound of distribution - Element min = Element(0), ///< lower bound for distribution + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution int bits = -1) { ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of /// data. @@ -976,13 +1170,15 @@ void BlockFillRandom( uint64_t seed, Distribution dist) { + using Real = typename RealType::Type; + if (dist.kind == Distribution::Gaussian) { BlockFillRandomGaussian( ptr, capacity, seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), dist.int_scale); } else if (dist.kind == Distribution::Uniform) { @@ -990,8 +1186,8 @@ void BlockFillRandom( ptr, capacity, seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), dist.int_scale); } } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index aa6610e1d4..d03080b2a0 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/tools/util/include/cutlass/util/reference/device/tensor_relu.h new file mode 100644 index 0000000000..d78e19533e --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element threshold; + + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/tools/util/include/cutlass/util/reference/device/thread/gemm.h index fefc4131df..11485a91de 100644 --- a/tools/util/include/cutlass/util/reference/device/thread/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/host/gemm.h b/tools/util/include/cutlass/util/reference/host/gemm.h index 13dbd5cf06..3e38886dd8 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm.h +++ b/tools/util/include/cutlass/util/reference/host/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -37,11 +37,41 @@ #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" #include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" namespace cutlass { namespace reference { namespace host { +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef @@ -107,7 +137,10 @@ void compute_gemm( ElementA a = tensor_a.at(MatrixCoord(row, k_block)); ElementB b = tensor_b.at(MatrixCoord(k_block, col)); - accum[i][j] = inner_product_op(ComputeType(a), ComputeType(b), accum[i][j]); + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); } } } diff --git a/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_complex.h index 964a69c489..27f368200d 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,6 +72,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, + TensorRef tensor_d, ComputeType initial_accum) { static_assert( @@ -138,7 +139,7 @@ void GemmComplex( if (row < M && col < N) { - tensor_c.at(coord) = convert_op( + tensor_d.at(coord) = convert_op( alpha * ScalarType(accum[i][j]) + beta * ScalarType(tensor_c.at(coord))); } @@ -171,9 +172,10 @@ void GemmComplex( TensorRef tensor_b, ComplexTransform transform_b, ScalarType beta, - TensorRef tensor_c) { + TensorRef tensor_c, + TensorRef tensor_d) { - GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0)); + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000..2a23fd2720 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/matrix_traits.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag());; + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/tools/util/include/cutlass/util/reference/host/tensor_compare.h index 3c7d95ff1f..2d7545e907 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,6 +33,9 @@ // Cutlass includes #include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + #include "cutlass/util/distribution.h" //#include "cutlass/util/type_traits.h" #include "tensor_foreach.h" @@ -112,6 +115,46 @@ bool TensorEquals( return bool(func); } +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -137,6 +180,17 @@ bool TensorNotEquals( return !bool(func); } +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/tools/util/include/cutlass/util/reference/host/tensor_copy.h index 737119e814..a81f021127 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_copy.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_copy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h index 73eb328d34..88bbb39f45 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 37096f730e..87c14d61c6 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -38,6 +38,8 @@ #include "cutlass/complex.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" #include "cutlass/util/distribution.h" #include "tensor_foreach.h" @@ -101,6 +103,18 @@ void TensorFill( ); } +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -268,6 +282,23 @@ void TensorFillRandomGaussian( ); } +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Fills a tensor with random values with a Gaussian distribution. @@ -461,6 +492,23 @@ void TensorFillRandomUniform( ); } +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Fills a tensor with random values with a uniform random distribution. @@ -774,6 +822,27 @@ void BlockFillSequential( } } +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/tools/util/include/cutlass/util/reference/host/tensor_foreach.h index 23ee9f93da..feb439d724 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_foreach.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/tools/util/include/cutlass/util/reference/host/tensor_norm.h index 6c73d91fa6..1d494b9f45 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_norm.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/include/cutlass/util/tensor_view_io.h b/tools/util/include/cutlass/util/tensor_view_io.h index c764c61c45..0043d745c2 100644 --- a/tools/util/include/cutlass/util/tensor_view_io.h +++ b/tools/util/include/cutlass/util/tensor_view_io.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -26,6 +26,8 @@ #include "cutlass/core_io.h" #include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" namespace cutlass { @@ -80,6 +82,76 @@ inline std::ostream & TensorView_WriteRank( return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); } + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + out << (idx ? ",\n" : ""); + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + out << (idx ? ",\n\n" : ""); + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + // Otherwise, write a sequence of rows and newlines for (int idx = 0; idx < view.extent(rank); ++idx) { @@ -89,12 +161,12 @@ inline std::ostream & TensorView_WriteRank( if (rank + 2 == Layout::kRank) { // Write least significant ranks asa matrix with rows delimited by ";\n" out << (idx ? ";\n" : ""); - TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); } else { // Higher ranks are separated by newlines out << (idx ? "\n" : ""); - TensorView_WriteRank(out, view, coord, rank + 1, width); + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); } } @@ -143,4 +215,42 @@ inline std::ostream& operator<<( /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass diff --git a/tools/util/include/cutlass/util/type_traits.h b/tools/util/include/cutlass/util/type_traits.h index 059a23ab4b..d97af0a421 100644 --- a/tools/util/include/cutlass/util/type_traits.h +++ b/tools/util/include/cutlass/util/type_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: