Merge pull request #187 from KernelTuner/refactor_interface
Refactor interfaces
isazi authored Mar 30, 2023
2 parents a60e060 + 881042a commit 7c6f709
Showing 61 changed files with 1,608 additions and 1,029 deletions.
20 changes: 10 additions & 10 deletions doc/source/design.rst
@@ -98,33 +98,33 @@ kernel_tuner.core.DeviceInterface
    :special-members: __init__
    :members:

-kernel_tuner.pycuda.PyCudaFunctions
+kernel_tuner.backends.pycuda.PyCudaFunctions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: kernel_tuner.pycuda.PyCudaFunctions
+.. autoclass:: kernel_tuner.backends.pycuda.PyCudaFunctions
    :special-members: __init__
    :members:

-kernel_tuner.cupy.CupyFunctions
+kernel_tuner.backends.cupy.CupyFunctions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: kernel_tuner.cupy.CupyFunctions
+.. autoclass:: kernel_tuner.backends.cupy.CupyFunctions
    :special-members: __init__
    :members:

-kernel_tuner.nvcuda.CudaFunctions
+kernel_tuner.backends.nvcuda.CudaFunctions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: kernel_tuner.nvcuda.CudaFunctions
+.. autoclass:: kernel_tuner.backends.nvcuda.CudaFunctions
    :special-members: __init__
    :members:

-kernel_tuner.opencl.OpenCLFunctions
+kernel_tuner.backends.opencl.OpenCLFunctions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: kernel_tuner.opencl.OpenCLFunctions
+.. autoclass:: kernel_tuner.backends.opencl.OpenCLFunctions
    :special-members: __init__
    :members:

-kernel_tuner.c.CFunctions
+kernel_tuner.backends.c.CFunctions
 ~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: kernel_tuner.c.CFunctions
+.. autoclass:: kernel_tuner.backends.c.CFunctions
    :special-members: __init__
    :members:

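For downstream code, the visible effect of this move is a new import path for each backend class. A minimal before/after sketch (assuming only the module locations changed, which is all this diff shows):

    # Before this refactor, backends sat at the package top level:
    # from kernel_tuner.pycuda import PyCudaFunctions
    # from kernel_tuner.opencl import OpenCLFunctions

    # After #187, they live in the kernel_tuner.backends subpackage:
    from kernel_tuner.backends.pycuda import PyCudaFunctions
    from kernel_tuner.backends.opencl import OpenCLFunctions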
76 changes: 49 additions & 27 deletions examples/cuda/convolution_correct.py
@@ -26,66 +26,88 @@
 import kernel_tuner
 from collections import OrderedDict
 
+
 def tune():
-    with open('convolution.cu', 'r') as f:
+    with open("convolution.cu", "r") as f:
         kernel_string = f.read()
 
     filter_size = (17, 17)
     problem_size = (4096, 4096)
     size = numpy.prod(problem_size)
-    border_size = (filter_size[0]//2*2, filter_size[1]//2*2)
-    input_size = ((problem_size[0]+border_size[0]) * (problem_size[1]+border_size[1]))
+    border_size = (filter_size[0] // 2 * 2, filter_size[1] // 2 * 2)
+    input_size = (problem_size[0] + border_size[0]) * (problem_size[1] + border_size[1])
 
     output = numpy.zeros(size).astype(numpy.float32)
     input = numpy.random.randn(input_size).astype(numpy.float32)
 
-    filter = numpy.random.randn(filter_size[0]*filter_size[1]).astype(numpy.float32)
-    cmem_args= {'d_filter': filter }
+    filter = numpy.random.randn(filter_size[0] * filter_size[1]).astype(numpy.float32)
+    cmem_args = {"d_filter": filter}
 
     args = [output, input, filter]
     tune_params = OrderedDict()
     tune_params["filter_width"] = [filter_size[0]]
     tune_params["filter_height"] = [filter_size[1]]
 
-    #tune_params["block_size_x"] = [16*i for i in range(1,3)]
-    tune_params["block_size_x"] = [16*i for i in range(1,9)]
-    #tune_params["block_size_y"] = [2**i for i in range(1,5)]
-    tune_params["block_size_y"] = [2**i for i in range(1,6)]
+    # tune_params["block_size_x"] = [16*i for i in range(1,3)]
+    tune_params["block_size_x"] = [16 * i for i in range(1, 9)]
+    # tune_params["block_size_y"] = [2**i for i in range(1,5)]
+    tune_params["block_size_y"] = [2**i for i in range(1, 6)]
 
     tune_params["tile_size_x"] = [2**i for i in range(3)]
     tune_params["tile_size_y"] = [2**i for i in range(3)]
 
-    tune_params["use_padding"] = [0,1] #toggle the insertion of padding in shared memory
-    tune_params["read_only"] = [0,1] #toggle using the read-only cache
+    tune_params["use_padding"] = [
+        0,
+        1,
+    ]  # toggle the insertion of padding in shared memory
+    tune_params["read_only"] = [0, 1]  # toggle using the read-only cache
 
     grid_div_x = ["block_size_x", "tile_size_x"]
     grid_div_y = ["block_size_y", "tile_size_y"]
 
-    #compute the answer using a naive kernel
-    params = { "block_size_x": 16, "block_size_y": 16}
+    # compute the answer using a naive kernel
+    params = {"block_size_x": 16, "block_size_y": 16}
     tune_params["filter_width"] = [filter_size[0]]
     tune_params["filter_height"] = [filter_size[1]]
-    results = kernel_tuner.run_kernel("convolution_naive", kernel_string,
-        problem_size, args, params,
-        grid_div_y=["block_size_y"], grid_div_x=["block_size_x"], lang='cupy')
 
-    #set non-output fields to None
+    results = kernel_tuner.run_kernel(
+        "convolution_naive",
+        kernel_string,
+        problem_size,
+        args,
+        params,
+        grid_div_y=["block_size_y"],
+        grid_div_x=["block_size_x"],
+        lang="cupy",
+    )
+
+    # set non-output fields to None
     answer = [results[0], None, None]
 
-    #start kernel tuning with correctness verification
-    return kernel_tuner.tune_kernel("convolution_kernel", kernel_string,
-        problem_size, args, tune_params,
-        grid_div_y=grid_div_y, grid_div_x=grid_div_x, verbose=True, cmem_args=cmem_args, answer=answer, lang='cupy')
+    # start kernel tuning with correctness verification
+    return kernel_tuner.tune_kernel(
+        "convolution_kernel",
+        kernel_string,
+        problem_size,
+        args,
+        tune_params,
+        grid_div_y=grid_div_y,
+        grid_div_x=grid_div_x,
+        verbose=True,
+        cmem_args=cmem_args,
+        answer=answer,
+        lang="cupy",
+    )
+
 
 if __name__ == "__main__":
     import time
-    s1 = time.time()*1000
+
+    s1 = time.time() * 1000
     results = tune()
-
-    e1 = time.time()*1000
-    print("\n Actualy time used:", e1-s1)
+    e1 = time.time() * 1000
+    print("\n Actual time used:", e1 - s1)
     import json
-    with open("convolution_RTX_2070.json", 'w') as fp:
-        json.dump(results, fp)
+
+    with open("convolution_RTX_2070.json", "w") as fp:
+        json.dump(results, fp)
2 changes: 1 addition & 1 deletion examples/cuda/vector_add_observers.py
@@ -6,7 +6,7 @@
 
 import numpy
 from kernel_tuner import tune_kernel
-from kernel_tuner.nvml import NVMLObserver
+from kernel_tuner.observers.nvml import NVMLObserver
 
 def tune():
 
2 changes: 1 addition & 1 deletion examples/opencl/vector_add_observers.py
@@ -6,7 +6,7 @@
 
 import numpy
 from kernel_tuner import tune_kernel
-from kernel_tuner.nvml import NVMLObserver
+from kernel_tuner.observers.nvml import NVMLObserver
 
 def tune():
 
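Both observer examples change only this import line. As a usage sketch, an observer built from the new path plugs into tune_kernel via the observers argument; the kernel and the observable names below are illustrative assumptions, not taken from this diff:

    import numpy
    from kernel_tuner import tune_kernel
    from kernel_tuner.observers.nvml import NVMLObserver

    kernel_string = """
    __global__ void vector_add(float *c, float *a, float *b, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            c[i] = a[i] + b[i];
        }
    }
    """

    size = 10000000
    a = numpy.random.randn(size).astype(numpy.float32)
    b = numpy.random.randn(size).astype(numpy.float32)
    c = numpy.zeros_like(a)
    n = numpy.int32(size)

    # "nvml_energy" and "temperature" are commonly documented NVML observables;
    # treat them as an assumption and check the observer docs for your version.
    observer = NVMLObserver(["nvml_energy", "temperature"])

    tune_kernel("vector_add", kernel_string, size, [c, a, b, n],
                {"block_size_x": [128 + 64 * i for i in range(15)]},
                observers=[observer])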
Empty file.
89 changes: 89 additions & 0 deletions kernel_tuner/backends/backend.py
@@ -0,0 +1,89 @@
"""This module contains the interface of all kernel_tuner backends"""
from __future__ import print_function

from abc import ABC, abstractmethod


class Backend(ABC):
"""Base class for kernel_tuner backends"""

@abstractmethod
def ready_argument_list(self, arguments):
"""This method must implement the allocation of the arguments on device memory."""
pass

@abstractmethod
def compile(self, kernel_instance):
"""This method must implement the compilation of a kernel into a callable function."""
pass

@abstractmethod
def start_event(self):
"""This method must implement the recording of the start of a measurement."""
pass

@abstractmethod
def stop_event(self):
"""This method must implement the recording of the end of a measurement."""
pass

@abstractmethod
def kernel_finished(self):
"""This method must implement a check that returns True if the kernel has finished, False otherwise."""
pass

@abstractmethod
def synchronize(self):
"""This method must implement a barrier that halts execution until device has finished its tasks."""
pass

@abstractmethod
def run_kernel(self, func, gpu_args, threads, grid, stream):
"""This method must implement the execution of the kernel on the device."""
pass

@abstractmethod
def memset(self, allocation, value, size):
"""This method must implement setting the memory to a value on the device."""
pass

@abstractmethod
def memcpy_dtoh(self, dest, src):
"""This method must implement a device to host copy."""
pass

@abstractmethod
def memcpy_htod(self, dest, src):
"""This method must implement a host to device copy."""
pass


class GPUBackend(Backend):
"""Base class for GPU backends"""

@abstractmethod
def __init__(self, device, iterations, compiler_options, observers):
pass

@abstractmethod
def copy_constant_memory_args(self, cmem_args):
"""This method must implement the allocation and copy of constant memory to the GPU."""
pass

@abstractmethod
def copy_shared_memory_args(self, smem_args):
"""This method must implement the dynamic allocation of shared memory on the GPU."""
pass

@abstractmethod
def copy_texture_memory_args(self, texmem_args):
"""This method must implement the allocation and copy of texture memory to the GPU."""
pass


class CompilerBackend(Backend):
"""Base class for compiler backends"""

@abstractmethod
def __init__(self, iterations, compiler_options, compiler):
pass
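These abstract base classes make the required backend surface explicit: a subclass can only be instantiated once it implements every abstract method. A minimal sketch of such a subclass; the EchoBackend class is hypothetical and only mimics device behaviour with host lists and wall-clock timing:

    import time

    from kernel_tuner.backends.backend import CompilerBackend


    class EchoBackend(CompilerBackend):
        """Hypothetical backend that satisfies the ABC with host-only stand-ins."""

        def __init__(self, iterations=7, compiler_options=None, compiler=None):
            self.iterations = iterations
            self.compiler_options = compiler_options
            self.compiler = compiler
            self._start = self._stop = 0.0

        def ready_argument_list(self, arguments):
            # no device: "device memory" is just a copy of the host arguments
            return list(arguments)

        def compile(self, kernel_instance):
            # pretend compilation yields a callable kernel
            return lambda *args: None

        def start_event(self):
            self._start = time.perf_counter()

        def stop_event(self):
            self._stop = time.perf_counter()

        def kernel_finished(self):
            return True  # everything here runs synchronously

        def synchronize(self):
            pass

        def run_kernel(self, func, gpu_args, threads, grid, stream):
            func(*gpu_args)  # threads, grid, and stream are meaningless on the host

        def memset(self, allocation, value, size):
            allocation[:size] = [value] * size

        def memcpy_dtoh(self, dest, src):
            dest[:] = src

        def memcpy_htod(self, dest, src):
            dest[:] = src


    backend = EchoBackend()  # works because every abstract method is implemented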