Skip to content

Commit

Permalink
minor clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Sep 9, 2022
1 parent 0208afa commit 77f831b
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 284 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,8 +1405,8 @@ void IterDomain::parallelize(ParallelType t) {
// TORCH_CHECK(
// t == ParallelType::Vectorize || t == ParallelType::TIDx ||
// t == ParallelType::Serial || t == ParallelType::Mma,
// "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids",
// t);
// "Parallel type other than serial, tidx, vectorize not allowed for mma
// swizzled ids", t);
}

parallel_type_ = t;
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ void validateDoubleBufferedTensor(const TensorView* tv) {
const auto c_mem_type = tv->getMemoryType();
// TORCH_INTERNAL_ASSERT(
// (p_mem_type == MemoryType::Global &&
// (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) ||
// (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local))
// ||
// (c_mem_type == MemoryType::Local),
// "Invalid tensor to double-buffer: ",
// tv->toString(),
Expand Down Expand Up @@ -146,9 +147,8 @@ class DoubleBufferFusionInspector : private IterVisitor {
// Returns true if any expression in the given list writes from a tensor whose
// first input resides in Shared or Local memory. Used to decide whether a
// double-buffered loop needs a separate epilogue stage.
//
// NOTE(review): this span contained interleaved pre/post diff lines
// (the Shared/Local comparison appeared twice); reconstructed to the
// post-commit formatting with a single well-formed predicate.
bool requireEpilogue(const std::vector<Expr*>& exprs) {
  return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
    // Memory type of the producer tensor (input 0) of this expression.
    const auto mem_type = expr->input(0)->as<TensorView>()->getMemoryType();
    return mem_type == MemoryType::Shared || mem_type == MemoryType::Local;
  });
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ BasicAllocInfo getAllocInformation(
outer_alloc_found = true;
}

if(tv->getMemoryType()==MemoryType::Shared && !fl_id->isThread()){
if (tv->getMemoryType() == MemoryType::Shared && !fl_id->isThread()) {
outer_alloc_found = true;
}

Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ void scheduleMatmul(
.propagateToBoundary());

c_smem->computeAt(c, 3);
c->reorder({{-1,-2}, {-2,-1}});
c->reorder({{-1, -2}, {-2, -1}});
// 16 x 128, with half of the warps:

// Output vectorize by 4:
Expand All @@ -566,7 +566,6 @@ void scheduleMatmul(
c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
c_smem->doubleBuffer();


if (params.index_lift_options.lift_gmem_read_address) {
a->liftReadAddress();
b->liftReadAddress();
Expand Down
Loading

0 comments on commit 77f831b

Please sign in to comment.