Skip to content

Commit

Permalink
[Feature](mlu-ops): add mluAdamW.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrelYy committed Dec 23, 2024
1 parent f70bd53 commit c952491
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 34 deletions.
33 changes: 16 additions & 17 deletions bangc_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,22 @@ typedef enum {
} bangcKernelsStatus_t;

template <typename T>
bangcKernelsStatus_t BANGC_KERNELS_WIN_API
mluApplyAdamW(const cnrtQueue_t queue,
const float lr,
const float beta1,
const float beta2,
const float bias1,
const float bias2,
const float epsilon,
const float weight_decay,
const float scale,
const bool use_nesterov,
const size_t size,
T *param_h,
T *grad,
void *param,
void *momentum,
void *velocity);
// Enqueues one AdamW optimizer step over `size` elements on `queue`.
// Declaration only: `template <typename T>` is on the preceding line and
// the definition lives in kernels/adam_w/adam_w_union1.mlu.
//
// T is the element type of the low-precision buffers `param_h` and
// `grad`; `param`, `momentum` and `velocity` are passed as opaque
// void* device buffers (the kernel launch casts `param` to float*;
// presumably these are the fp32 master weights and Adam moments —
// TODO confirm against the kernel implementation).
//
// NOTE(review): `bias1`/`bias2` read as pre-computed bias corrections
// and `scale` as a gradient scaling factor — inferred from the names
// only; verify against unionApplyAdamW before relying on this.
bangcKernelsStatus_t BANGC_KERNELS_WIN_API mluAdamW(const cnrtQueue_t queue,
const float lr,
const float beta1,
const float beta2,
const float bias1,
const float bias2,
const float epsilon,
const float weight_decay,
const float scale,
const bool use_nesterov,
const size_t size,
T *param_h,
T *grad,
void *param,
void *momentum,
void *velocity);

#ifndef NAMESPACE_BANGC_KERNELS_END
#define NAMESPACE_BANGC_KERNELS_END }
Expand Down
11 changes: 5 additions & 6 deletions kernels/adam_w/adam_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,11 @@ mluOpAdamW(mluOpHandle_t handle, const mluOpAdamWDescriptor_t adamw_desc,
<< ", " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
CHECK_RETURN(
"[mluOpAdamW]",
KernelApplyAdamW(k_dim, k_type, handle->queue, (void *)param,
(void *)param_h, (void *)grad, (void *)momentum,
(void *)velocity, lr, beta1, beta2, bias1, bias2,
epsilon, adamw_desc->weight_decay,
adamw_desc->grad_scale, adamw_desc->use_nesterov,
size, k_data_type));
KernelApplyAdamW(
k_dim, k_type, handle->queue, (void *)param, (void *)param_h,
(void *)grad, (void *)momentum, (void *)velocity, lr, beta1,
beta2, bias1, bias2, epsilon, adamw_desc->weight_decay,
adamw_desc->grad_scale, adamw_desc->use_nesterov, size));
}
}
GEN_CASE_END();
Expand Down
2 changes: 1 addition & 1 deletion kernels/adam_w/adam_w.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ mluOpStatus_t MLUOP_WIN_API KernelApplyAdamW(
const cnrtQueue_t queue, void *param, void *param_h, void *grad,
void *momentum, void *velocity, float lr, float beta1, float beta2,
float bias1, float bias2, float epsilon, float weight_decay, float scale,
bool use_nesterov, size_t size, mluOpDataType_t k_data_type);
bool use_nesterov, size_t size);

#endif // KERNELS_ADAMW_ADAMW_H_
14 changes: 7 additions & 7 deletions kernels/adam_w/adam_w_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ mluOpStatus_t MLUOP_WIN_API KernelApplyAdamW(
const cnrtQueue_t queue, void *param, void *param_h, void *grad,
void *momentum, void *velocity, float lr, float beta1, float beta2,
float bias1, float bias2, float epsilon, float weight_decay, float scale,
bool use_nesterov, size_t size, mluOpDataType_t k_data_type) {
bool use_nesterov, size_t size) {
// launch kernel
unionApplyAdamW<bfloat16_t><<<k_dim, k_type, queue>>>(
(bfloat16_t *)param_h, (bfloat16_t *)grad, (float *)param,
Expand All @@ -241,11 +241,11 @@ NAMESPACE_BANGC_KERNELS_GEGIN

template <typename T>
bangcKernelsStatus_t BANGC_KERNELS_WIN_API
mluApplyAdamW(const cnrtQueue_t queue, const float lr, const float beta1,
const float beta2, const float bias1, const float bias2,
const float epsilon, const float weight_decay, const float scale,
const bool use_nesterov, size_t size, T *param_h, T *grad,
void *param, void *momentum, void *velocity) {
mluAdamW(const cnrtQueue_t queue, const float lr, const float beta1,
const float beta2, const float bias1, const float bias2,
const float epsilon, const float weight_decay, const float scale,
const bool use_nesterov, size_t size, T *param_h, T *grad, void *param,
void *momentum, void *velocity) {
// set job type
int ordinal = -1;
int cluster_num;
Expand All @@ -266,7 +266,7 @@ mluApplyAdamW(const cnrtQueue_t queue, const float lr, const float beta1,
}

#define IMPL_MLU_APPLY_ADAMW_KERNEL(DType) \
template bangcKernelsStatus_t BANGC_KERNELS_WIN_API mluApplyAdamW( \
template bangcKernelsStatus_t BANGC_KERNELS_WIN_API mluAdamW( \
const cnrtQueue_t, const float, const float, const float, const float, \
const float, const float, const float, const float, const bool, \
const size_t, DType *, DType *, void *, void *, void *)
Expand Down
2 changes: 1 addition & 1 deletion scripts/gen_symbol_visibility_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def get_mluops(input_file):
ops_str=""
pattern = re.compile(r'(?P<api>mluOp\w+) *\(')
pattern_lite = re.compile(r'(?P<api>mluApply\w+) *\(')
pattern_lite = re.compile(r'(?P<api>mlu\w+) *\(')
with open(input_file,'r', encoding='utf8') as f:
for line in f:
match = pattern.search(line)
Expand Down
4 changes: 2 additions & 2 deletions test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ void AdamWExecutor::compute() {
interface_timer_.stop();
MLUOP_CHECK(mluOpDestroyAdamWDescriptor(adamw_desc));
} else {
VLOG(4) << "call mluApplyAdamW. ";
VLOG(4) << "call mluAdamW. ";
const int size = mluOpGetTensorElementNum(desc_momentum) * sizeof(float);
interface_timer_.start();
const auto adamw_status = bangc_kernels::mluApplyAdamW(
const auto adamw_status = bangc_kernels::mluAdamW(
handle_->queue, fp32_lr, fp32_beta1, fp32_beta2, fp32_bias1, fp32_bias2,
fp32_epsilon, fp32_weight_decay, fp32_scale, use_nesterov, size,
BANG_WRAP_T((Eigen::bfloat16 *)dev_paramh),
Expand Down

0 comments on commit c952491

Please sign in to comment.