Add all changes to the xla_compiler and client's python interface
Co-authored-by: Hexu Zhao <[email protected]>
Co-authored-by: Yonghao Zhuang <[email protected]>
3 people committed Aug 30, 2022
1 parent b66746b commit 9c6f230
Showing 9 changed files with 280 additions and 35 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/executable_build_options.cc
@@ -136,6 +136,7 @@ ExecutionOptions CreateExecutionOptions(
*execution_options.mutable_shape_with_output_layout() =
result_shape.ToProto();
}
execution_options.set_seed(build_options.seed());
execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions());
execution_options.set_use_spmd_partitioning(
6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/client/executable_build_options.h
@@ -77,6 +77,10 @@ class ExecutableBuildOptions {
// debugging.
std::string ToString() const;

// The random seed for compilation.
int seed() const { return seed_; }
void set_seed(int seed) { seed_ = seed; }

// The number of replicas of this computation that are to be executed.
// Defaults to 1.
int num_replicas() const { return num_replicas_; }
@@ -189,6 +193,8 @@ class ExecutableBuildOptions {
bool run_backend_only_ = false;
bool allow_spmd_sharding_propagation_to_output_ = false;
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
// Added by Alpa
int seed_ = 42;
};

// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
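Taken together, the two hunks above thread a compilation seed from ExecutableBuildOptions into the ExecutionOptions proto. A minimal sketch of how a client might set it, assuming the usual SPMD options are configured alongside it (the MakeBuildOptions helper is illustrative, not part of this commit):

#include "tensorflow/compiler/xla/client/executable_build_options.h"

// Illustrative only: build options carrying the Alpa-added compilation seed.
xla::ExecutableBuildOptions MakeBuildOptions(int num_partitions) {
  xla::ExecutableBuildOptions options;
  options.set_seed(123);  // forwarded to ExecutionOptions; defaults to 42
  options.set_num_replicas(1);
  options.set_num_partitions(num_partitions);
  options.set_use_spmd_partitioning(num_partitions > 1);
  return options;
}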
37 changes: 29 additions & 8 deletions tensorflow/compiler/xla/client/xla_builder.cc
@@ -2935,7 +2935,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout) {
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2992,6 +2993,9 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}
if (use_global_device_ids.has_value()) {
instr.set_use_global_device_ids(use_global_device_ids.value());
}

AddCalledComputation(computation, &instr);

@@ -3071,20 +3075,24 @@ XlaOp XlaBuilder::ReduceScatter(
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout) {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids) {
// Array all_to_all may need to violate layout constraint to be legal so use
// the tuple version.
if (layout.has_value()) {
return AllToAllTuple(operand, split_dimension, concat_dimension,
split_count, replica_groups, layout);
}
return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
replica_groups);
replica_groups, channel_id, use_global_device_ids);
}

XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups) {
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<bool> use_global_device_ids) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(
@@ -3103,6 +3111,14 @@ XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
*instr.add_replica_groups() = group;
}
}

if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}
if (use_global_device_ids.has_value()) {
instr.set_use_global_device_ids(use_global_device_ids.value());
}

instr.add_dimensions(split_dimension);
TF_ASSIGN_OR_RETURN(
XlaOp all_to_all,
@@ -4663,9 +4679,11 @@ XlaOp CrossReplicaSum(const XlaOp operand,
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout) {
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
channel_id, shape_with_layout);
channel_id, shape_with_layout,
use_global_device_ids);
}

XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
@@ -4682,9 +4700,12 @@ XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout) {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids) {
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
split_count, replica_groups, layout);
split_count, replica_groups, channel_id, layout,
use_global_device_ids);
}

XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
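For orientation (not part of the diff), a sketch of how the extended free function could be called once these signatures land; the reducer computation `add` is assumed to be built elsewhere, and the shape is made up:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative only: a cross-device all-reduce using the new arguments.
xla::XlaOp GlobalSum(xla::XlaBuilder* builder, const xla::XlaComputation& add) {
  xla::XlaOp x = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "x");
  xla::ChannelHandle channel;
  channel.set_handle(1);
  channel.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  // Empty replica_groups puts every participant into a single group.
  return xla::AllReduce(x, add, /*replica_groups=*/{},
                        /*channel_id=*/channel,
                        /*shape_with_layout=*/std::nullopt,
                        /*use_global_device_ids=*/true);
}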
25 changes: 18 additions & 7 deletions tensorflow/compiler/xla/client/xla_builder.h
@@ -749,7 +749,8 @@ class XlaBuilder {
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Shape>& shape_with_layout = std::nullopt);
const std::optional<Shape>& shape_with_layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation,
@@ -762,7 +763,9 @@ class XlaBuilder {
XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
absl::Span<const ReplicaGroup> replica_groups,
@@ -1362,7 +1365,8 @@ class XlaBuilder {
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout);
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids);
friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
int64_t scatter_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
@@ -1373,7 +1377,9 @@ class XlaBuilder {
friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout);
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids);
friend XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout);
@@ -1517,7 +1523,9 @@ class XlaBuilder {

XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups);
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id=std::nullopt,
const std::optional<bool> use_global_device_ids=std::nullopt);

// Creates an op with the given opcode and the output shape.
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
@@ -2343,7 +2351,8 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Shape>& shape_with_layout = std::nullopt);
const std::optional<Shape>& shape_with_layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
@@ -2359,7 +2368,9 @@ XlaOp ReduceScatter(
XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<Layout>& layout = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp AllToAllTuple(absl::Span<const XlaOp> operand,
absl::Span<const ReplicaGroup> replica_groups = {},
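The header defaults the new AllToAll parameters to std::nullopt, so existing call sites keep compiling. A hedged sketch of a call that exercises them (the operand x and the four-device grouping are assumptions, not taken from this commit):

#include <vector>
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Illustrative only: all-to-all across four devices, two groups of two.
xla::XlaOp ShuffleAcrossDevices(xla::XlaOp x) {
  std::vector<xla::ReplicaGroup> groups(2);
  groups[0].add_replica_ids(0);
  groups[0].add_replica_ids(1);
  groups[1].add_replica_ids(2);
  groups[1].add_replica_ids(3);
  xla::ChannelHandle channel;
  channel.set_handle(2);
  channel.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  // With use_global_device_ids=true the ids above are flattened global
  // device ids rather than replica ids; dimension 0 of x must be divisible
  // by split_count.
  return xla::AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
                       /*split_count=*/2, groups,
                       /*channel_id=*/channel, /*layout=*/std::nullopt,
                       /*use_global_device_ids=*/true);
}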
7 changes: 7 additions & 0 deletions tensorflow/compiler/xla/python/BUILD
@@ -651,6 +651,11 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@pybind11",
# Added by Alpa
"//tensorflow/compiler/xla/service:pass_context",
"//tensorflow/compiler/xla/service/gpu:gpu_cost_model",
"//tensorflow/compiler/xla/service/spmd:alpa_compiler",
"//tensorflow/compiler/xla/service/spmd:grad_acc_rewrite",
],
)

@@ -746,6 +751,8 @@ pybind_extension(
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/core/distributed_runtime/preemption:preemption_sync_manager",
"//tensorflow/python:bfloat16_lib",
# Added by Alpa
"//tensorflow/compiler/xla/service/gpu:alpa_nccl_wrapper",
] + select({
":gpu_enabled": [
"//tensorflow/compiler/xla/pjrt:gpu_device",
5 changes: 5 additions & 0 deletions tensorflow/compiler/xla/python/dlpack.cc
@@ -203,6 +203,11 @@ StatusOr<std::vector<int64_t>> StridesToLayout(
if (strides[a] > strides[b]) {
return false;
}
// FIXME(yonghao): This is only a workaround.
// Should support isConsistent([1,1]{1,0}, [1,1]{0,1}) in type check
if (dims[a] == dims[b]) {
return a > b;
}
return dims[a] == 1 && dims[b] != 1;
});
int64_t stride = 1;
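The comparator change only matters for degenerate dimensions of extent 1. A standalone sketch of the ordering it produces, assuming dims and strides describe a DLPack tensor (this mirrors, rather than reproduces, the patched StridesToLayout):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns the minor-to-major dimension order implied by `strides`.
std::vector<int64_t> MinorToMajor(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& strides) {
  std::vector<int64_t> order(dims.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int64_t a, int64_t b) {
    if (strides[a] < strides[b]) return true;
    if (strides[a] > strides[b]) return false;
    // Equal strides and equal extents (e.g. a [1, 1] tensor): break the tie
    // by dimension index so the resulting layout is deterministic.
    if (dims[a] == dims[b]) return a > b;
    return dims[a] == 1 && dims[b] != 1;
  });
  return order;
}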
18 changes: 14 additions & 4 deletions tensorflow/compiler/xla/python/ops.cc
@@ -36,6 +36,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Added by Alpa for TorchIndexSelect
#include "tensorflow/compiler/xla/client/lib/slicing.h"

namespace xla {

namespace py = pybind11;
@@ -80,12 +83,13 @@ void BuildOpsSubmodule(py::module* m) {
"AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const std::optional<ChannelHandle>&, const std::optional<Shape>&)>(
&AllReduce),
const std::optional<ChannelHandle>&, const std::optional<Shape>&,
const std::optional<bool>)>(&AllReduce),
py::arg("operand"), py::arg("computation"),
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = std::nullopt,
py::arg("shape_with_layout") = std::nullopt);
py::arg("shape_with_layout") = std::nullopt,
py::arg("use_global_device_ids") = std::nullopt);
ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
py::arg("computation"), py::arg("scatter_dimension"),
py::arg("shard_count"), py::arg("replica_groups") = py::list(),
@@ -95,7 +99,9 @@ void BuildOpsSubmodule(py::module* m) {
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list(),
py::arg("layout") = std::nullopt);
py::arg("channel_id") = std::nullopt,
py::arg("layout") = std::nullopt,
py::arg("use_global_device_ids") = std::nullopt);
ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
py::arg("comparator"), py::arg("recall_target") = 0.9,
@@ -429,6 +435,10 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("b"), py::arg("x"));
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));

// Added by Alpa
ops.def("IndexSelect", &TorchIndexSelect, py::arg("input"), py::arg("index"),
py::arg("dim"), py::arg("batch_dims") = 0);

#define BINARY_OP(op) \
ops.def( \
#op, \
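The new "IndexSelect" binding re-exports TorchIndexSelect from client/lib/slicing.h (hence the added include above). For reference, a sketch of the underlying C++ op; shapes and names are illustrative:

#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative only: select 4 of 16 rows; the result has shape f32[4, 64].
xla::XlaOp SelectRows(xla::XlaBuilder* builder) {
  xla::XlaOp input = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {16, 64}), "input");
  xla::XlaOp index = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {4}), "index");
  return xla::TorchIndexSelect(input, index, /*dim=*/0, /*batch_dims=*/0);
}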