Skip to content

Commit

Permalink
Add all changes under service
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Aug 30, 2022
1 parent b66aedf commit b66746b
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 14 deletions.
24 changes: 21 additions & 3 deletions tensorflow/compiler/xla/service/all_reduce_combiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

// Added by Alpa
#include "tensorflow/compiler/xla/service/spmd/grad_acc_rewrite.h"

namespace xla {
namespace {

Expand Down Expand Up @@ -87,6 +90,10 @@ Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
Cast<HloAllReduceInstruction>(to_combine.front())
->use_global_device_ids()));

if (to_combine.front()->metadata().op_name() == spmd::kSkippableAllReduce) {
combined->set_metadata_op_name(spmd::kSkippableAllReduce);
}

// We have to propagate the sharding manually because Domain instructions are
// not guaranteed to preserve it for side effecting instructions.
if (to_combine.front()->has_sharding()) {
Expand All @@ -111,6 +118,11 @@ AllReduceCombiner::AllReduceCombiner(int64_t combine_threshold_in_bytes,
: combine_threshold_in_bytes_(combine_threshold_in_bytes),
combine_threshold_count_(combine_threshold_count) {}

// Add a new boolean field to the original AllReduceKey.
// This field indicates whether the all-reduce is a skippable
// all-reduce for gradient accumulation.
using AllReduceKeyWithSkip = std::tuple<AllReduceKey, bool>;

StatusOr<bool> AllReduceCombiner::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand All @@ -135,16 +147,22 @@ StatusOr<bool> AllReduceCombiner::Run(

auto key_fn =
[&domain_map](
const HloInstruction* instruction) -> std::optional<AllReduceKey> {
const HloInstruction* instruction) -> std::optional<AllReduceKeyWithSkip> {
if (instruction->opcode() != HloOpcode::kAllReduce) {
return std::nullopt;
}
return GetAllReduceKey(instruction, domain_map.get());
auto old_key = GetAllReduceKey(instruction, domain_map.get());
if (!old_key.has_value()) {
return absl::nullopt;
}
return AllReduceKeyWithSkip{
*old_key,
instruction->metadata().op_name() == spmd::kSkippableAllReduce};
};

TF_ASSIGN_OR_RETURN(
bool computation_changed,
CombineInstructionsByKey<AllReduceKey>(
CombineInstructionsByKey<AllReduceKeyWithSkip>(
computation, key_fn, &CombineAllReduces,
combine_threshold_in_bytes_, combine_threshold_count_));
changed |= computation_changed;
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/compiler/xla/service/all_reduce_contiguous.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace {

Status ReplaceWithContiguousAllReduce(HloAllReduceInstruction* all_reduce) {
TF_RET_CHECK(all_reduce);
TF_RET_CHECK(!all_reduce->has_sharding());
// Added by Alpa. We need to disable this check after a recent rebase.
//TF_RET_CHECK(!all_reduce->has_sharding());

HloComputation& computation = *all_reduce->parent(); // never null
PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
Expand Down Expand Up @@ -59,6 +60,7 @@ Status ReplaceWithContiguousAllReduce(HloAllReduceInstruction* all_reduce) {
all_reduce->replica_groups(),
/*constrain_layout=*/false, all_reduce->channel_id(),
all_reduce->use_global_device_ids()));
new_all_reduce->set_metadata_op_name(all_reduce->metadata().op_name());

// Slice from all-reduce result and bitcast back to the original shapes.
std::vector<HloInstruction*> outputs;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/xla/service/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ class Executable {
// Does not include the size of used libraries (e.g. cuDNN, Eigen, etc.).
virtual int64_t SizeOfGeneratedCodeInBytes() const;

// Return the total size of allocated buffers in bytes. This is GPU only.
virtual int64_t TotalAllocationSize() const { return -1; }

// Dumping helpers.
void set_hlo_proto(std::unique_ptr<xla::HloProto> hlo_proto) {
hlo_proto_ = std::move(hlo_proto);
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1245,4 +1245,18 @@ void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
}

double CountFlopDotConvOnly(const HloComputation& computation) {
auto analysis = absl::make_unique<HloCostAnalysis>([](const Shape&) { return 0; });
computation.Accept(analysis.get());

double ret = 0.0;
for (const HloInstruction* instruction : computation.instructions()) {
if (instruction->opcode() == HloOpcode::kDot ||
instruction->opcode() == HloOpcode::kConvolution) {
ret += analysis->flop_count(*instruction);
}
}
return ret;
}

} // namespace xla
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/hlo_cost_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
HloCostAnalysis& operator=(const HloCostAnalysis&) = delete;
};

// Count the number of floating point operations for
// dot and convolution in a HLO module.
double CountFlopDotConvOnly(const HloComputation& computation);

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
13 changes: 13 additions & 0 deletions tensorflow/compiler/xla/service/hlo_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,19 @@ class HloModule {

Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

void infer_spmd_shardings() {
std::vector<HloSharding> entry_params_shardings;
for (int64_t i = 0; i < entry_computation()->num_parameters(); ++i) {
auto param = entry_computation()->parameter_instruction(i);
CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
entry_params_shardings.push_back(param->sharding());
}
set_spmd_parameters_shardings(entry_params_shardings);
auto entry_root = entry_computation()->root_instruction();
CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
set_spmd_output_sharding(entry_root->sharding());
}

// Checks if this config has a list of entry parameters' HLO shardings for
// SPMD.
bool has_spmd_parameters_shardings() const {
Expand Down
16 changes: 9 additions & 7 deletions tensorflow/compiler/xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kAllGather:
case HloOpcode::kAllReduce:
case HloOpcode::kAllReduceStart:
case HloOpcode::kAllReduceDone:
Expand Down Expand Up @@ -2171,13 +2172,14 @@ Status VerifyChannels(const HloModule& module) {
Status CheckFusionInstruction(HloInstruction* fusion) {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
return InternalError(
"Instruction of fused computation does not match expected "
"instruction "
"%s.",
fusion->ToString());
}
// Temporarly disable this check due to our pass CommonComputationElimination.
//if (fusion != fused_computation->FusionInstruction()) {
// return InternalError(
// "Instruction of fused computation does not match expected "
// "instruction "
// "%s.",
// fusion->ToString());
//}

// Fused root instruction and fused parameters must all be owned by the
// fusion computation.
Expand Down
12 changes: 9 additions & 3 deletions tensorflow/compiler/xla/service/reduce_scatter_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size,
return IsPerIdOffset(offset->operand(1 - const_operand),
shard_size / *multiplier, map_id, group_size, ar);
}

// Added by Alpa
if (offset->opcode() == HloOpcode::kSubtract) {
// TODO(lmzheng): make the condition stronger.
return IsTableLookup(offset->operand(0)) && IsTableLookup(offset->operand(1));
}

if (shard_size == 1 && iota_group) {
bool id_mapping_is_identity = true;
for (int64_t id = 0; id < group_size; ++id) {
Expand Down Expand Up @@ -270,9 +277,7 @@ std::optional<ReduceScatterSpec> MatchReduceScatter(
int64_t num_replicas, bool allow_multiple_split_dims,
bool allow_intervening_reshape, int64_t min_rank,
HloPredicate match_partition_id, HloPredicate match_replica_id) {
if (!ar->shape().IsArray() || ar->constrain_layout() ||
(ar->IsCrossModuleAllReduce() &&
!ar->GetModule()->config().use_spmd_partitioning())) {
if (!ar->shape().IsArray() || ar->constrain_layout()) {
VLOG(2) << "Unsupported all-reduce: " << ar->ToString();
return std::nullopt;
}
Expand Down Expand Up @@ -407,6 +412,7 @@ std::optional<ReduceScatterSpec> MatchReduceScatter(
std::vector<int64_t> split_dims;
// First find a single dimension where the input and output of dynamic slice
// differ.
CHECK_EQ(ar->shape().rank(), user->shape().rank());
int num_dims = 0;
for (int64_t dim = 0; dim < ar->shape().rank(); ++dim) {
if (ar->shape().dimensions(dim) == user->shape().dimensions(dim)) {
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ bool IsPassthroughCustomOps(const HloInstruction* hlo) {
if (hlo->IsCustomCall("X64Combine")) {
return true;
}
// Added by Alpa
if (hlo->IsCustomCall("pipeline_marker") || hlo->IsCustomCall("identity")) {
return true;
}
if (hlo->operand_count() != 1 || !hlo->shape().IsArray() ||
!hlo->operand(0)->shape().IsArray() ||
hlo->operand(0)->shape().rank() != hlo->shape().rank()) {
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"

namespace xla {

StatusOr<bool> ZeroSizedHloElimination::Run(
Expand All @@ -34,6 +37,21 @@ StatusOr<bool> ZeroSizedHloElimination::Run(
for (HloComputation* comp :
module->MakeNonfusionComputations(execution_threads)) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
// Added by Alpa.
if (instruction->IsCustomCall("pipeline_marker")) {
// Set the input and output as alias
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
aliasing;

ShapeUtil::VisitorFunction visitor = [&](const Shape& shape,
const ShapeIndex& idx) {
aliasing.push_back(std::make_pair(idx, std::make_pair(0, idx)));
};
ShapeUtil::ForEachSubshape(instruction->shape(), visitor);
HloCustomCallInstruction* call = Cast<HloCustomCallInstruction>(instruction);
call->set_output_to_operand_aliasing(aliasing);
}

if (instruction->HasSideEffect() || !instruction->shape().IsArray() ||
instruction->opcode() == HloOpcode::kConstant) {
continue;
Expand Down

0 comments on commit b66746b

Please sign in to comment.