diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 0cce00c18a5..66bff554925 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -136,6 +136,7 @@ ExecutionOptions CreateExecutionOptions(
     *execution_options.mutable_shape_with_output_layout() =
         result_shape.ToProto();
   }
+  execution_options.set_seed(build_options.seed());
   execution_options.set_num_replicas(build_options.num_replicas());
   execution_options.set_num_partitions(build_options.num_partitions());
   execution_options.set_use_spmd_partitioning(
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 34f95b35e0f..d952c7c2b30 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -77,6 +77,10 @@ class ExecutableBuildOptions {
   // debugging.
   std::string ToString() const;
 
+  // The random seed for compilation.
+  int seed() const { return seed_; }
+  void set_seed(int seed) { seed_ = seed; }
+
   // The number of replicas of this computation that are to be executed.
   // Defaults to 1.
   int num_replicas() const { return num_replicas_; }
@@ -189,6 +193,8 @@ class ExecutableBuildOptions {
   bool run_backend_only_ = false;
   bool allow_spmd_sharding_propagation_to_output_ = false;
   tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
+  // Added by Alpa
+  int seed_ = 42;
 };
 
 // Creates an ExecutionOptions based on a given ExecutableBuildOptions and
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 1af7f34135a..49cb9dd05f9 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2935,7 +2935,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                             absl::Span<const ReplicaGroup> replica_groups,
                             const std::optional<ChannelHandle>& channel_id,
-                            const std::optional<Shape>& shape_with_layout) {
+                            const std::optional<Shape>& shape_with_layout,
+                            const std::optional<bool> use_global_device_ids) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2992,6 +2993,9 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
     if (channel_id.has_value()) {
       instr.set_channel_id(channel_id->handle());
     }
+    if (use_global_device_ids.has_value()) {
+      instr.set_use_global_device_ids(use_global_device_ids.value());
+    }
 
     AddCalledComputation(computation, &instr);
 
@@ -3071,7 +3075,9 @@ XlaOp XlaBuilder::ReduceScatter(
 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
                            int64_t concat_dimension, int64_t split_count,
                            absl::Span<const ReplicaGroup> replica_groups,
-                           const std::optional<Layout>& layout) {
+                           const std::optional<ChannelHandle>& channel_id,
+                           const std::optional<Layout>& layout,
+                           const std::optional<bool> use_global_device_ids) {
   // Array all_to_all may need to violate layout constraint to be legal so use
   // the tuple version.
   if (layout.has_value()) {
@@ -3079,12 +3085,14 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
                          split_count, replica_groups, layout);
   }
   return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
-                       replica_groups);
+                       replica_groups, channel_id, use_global_device_ids);
 }
 
 XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
                                 int64_t concat_dimension, int64_t split_count,
-                                absl::Span<const ReplicaGroup> replica_groups) {
+                                absl::Span<const ReplicaGroup> replica_groups,
+                                const std::optional<ChannelHandle>& channel_id,
+                                const std::optional<bool> use_global_device_ids) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(
@@ -3103,6 +3111,14 @@ XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
         *instr.add_replica_groups() = group;
       }
     }
+
+    if (channel_id.has_value()) {
+      instr.set_channel_id(channel_id->handle());
+    }
+    if (use_global_device_ids.has_value()) {
+      instr.set_use_global_device_ids(use_global_device_ids.value());
+    }
+
     instr.add_dimensions(split_dimension);
     TF_ASSIGN_OR_RETURN(
         XlaOp all_to_all,
@@ -4663,9 +4679,11 @@ XlaOp CrossReplicaSum(const XlaOp operand,
 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const std::optional<ChannelHandle>& channel_id,
-                const std::optional<Shape>& shape_with_layout) {
+                const std::optional<Shape>& shape_with_layout,
+                const std::optional<bool> use_global_device_ids) {
   return operand.builder()->AllReduce(operand, computation, replica_groups,
-                                      channel_id, shape_with_layout);
+                                      channel_id, shape_with_layout,
+                                      use_global_device_ids);
 }
 
 XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
@@ -4682,9 +4700,12 @@ XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
 XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
                int64_t concat_dimension, int64_t split_count,
                absl::Span<const ReplicaGroup> replica_groups,
-               const std::optional<Layout>& layout) {
+               const std::optional<ChannelHandle>& channel_id,
+               const std::optional<Layout>& layout,
+               const std::optional<bool> use_global_device_ids) {
   return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
-                                     split_count, replica_groups, layout);
+                                     split_count, replica_groups, channel_id, layout,
+                                     use_global_device_ids);
 }
 
 XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 865a62cac40..448d52dafe5 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -749,7 +749,8 @@ class XlaBuilder {
   XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                   absl::Span<const ReplicaGroup> replica_groups = {},
                   const std::optional<ChannelHandle>& channel_id = std::nullopt,
-                  const std::optional<Shape>& shape_with_layout = std::nullopt);
+                  const std::optional<Shape>& shape_with_layout = std::nullopt,
+                  const std::optional<bool> use_global_device_ids = std::nullopt);
 
   XlaOp ReduceScatter(
       XlaOp operand, const XlaComputation& computation,
@@ -762,7 +763,9 @@
   XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                  int64_t concat_dimension, int64_t split_count,
                  absl::Span<const ReplicaGroup> replica_groups,
-                 const std::optional<Layout>& layout = std::nullopt);
+                 const std::optional<ChannelHandle>& channel_id = std::nullopt,
+                 const std::optional<Layout>& layout = std::nullopt,
+                 const std::optional<bool> use_global_device_ids = std::nullopt);
 
   XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                       absl::Span<const ReplicaGroup> replica_groups,
@@ -1362,7 +1365,8 @@ class XlaBuilder {
   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                          absl::Span<const ReplicaGroup> replica_groups,
                          const std::optional<ChannelHandle>& channel_id,
-                         const std::optional<Shape>& shape_with_layout);
+                         const std::optional<Shape>& shape_with_layout,
+                         const std::optional<bool> use_global_device_ids);
   friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
                              int64_t scatter_dimension, int64_t shard_count,
                              absl::Span<const ReplicaGroup> replica_groups,
@@ -1373,7 +1377,9 @@ class XlaBuilder {
   friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                         int64_t concat_dimension, int64_t split_count,
                         absl::Span<const ReplicaGroup> replica_groups,
-                        const std::optional<Layout>& layout);
+                        const std::optional<ChannelHandle>& channel_id,
+                        const std::optional<Layout>& layout,
+                        const std::optional<bool> use_global_device_ids);
   friend XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                              absl::Span<const ReplicaGroup> replica_groups,
                              const std::optional<Layout>& layout);
@@ -1517,7 +1523,9 @@ class XlaBuilder {
   XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
                       int64_t concat_dimension, int64_t split_count,
-                      absl::Span<const ReplicaGroup> replica_groups);
+                      absl::Span<const ReplicaGroup> replica_groups,
+                      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+                      const std::optional<bool> use_global_device_ids = std::nullopt);
 
   // Creates an op with the given opcode and the output shape.
   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
@@ -2343,7 +2351,8 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                 absl::Span<const ReplicaGroup> replica_groups = {},
                 const std::optional<ChannelHandle>& channel_id = std::nullopt,
-                const std::optional<Shape>& shape_with_layout = std::nullopt);
+                const std::optional<Shape>& shape_with_layout = std::nullopt,
+                const std::optional<bool> use_global_device_ids = std::nullopt);
 
 XlaOp ReduceScatter(
     XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
@@ -2359,7 +2368,9 @@ XlaOp ReduceScatter(
 XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
                int64_t split_count,
                absl::Span<const ReplicaGroup> replica_groups = {},
-               const std::optional<Layout>& layout = std::nullopt);
+               const std::optional<ChannelHandle>& channel_id = std::nullopt,
+               const std::optional<Layout>& layout = std::nullopt,
+               const std::optional<bool> use_global_device_ids = std::nullopt);
 
 XlaOp AllToAllTuple(absl::Span<const XlaOp> operand,
                     absl::Span<const ReplicaGroup> replica_groups = {},
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 533c9ed540b..e264024f261 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -651,6 +651,11 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@pybind11",
+        # Added by Alpa
+        "//tensorflow/compiler/xla/service:pass_context",
+        "//tensorflow/compiler/xla/service/gpu:gpu_cost_model",
+        "//tensorflow/compiler/xla/service/spmd:alpa_compiler",
+        "//tensorflow/compiler/xla/service/spmd:grad_acc_rewrite",
     ],
 )
 
@@ -746,6 +751,8 @@ pybind_extension(
         "//tensorflow/core:lib_internal_impl",  # buildcleaner: keep
        "//tensorflow/core/distributed_runtime/preemption:preemption_sync_manager",
         "//tensorflow/python:bfloat16_lib",
+        # Added by Alpa
+        "//tensorflow/compiler/xla/service/gpu:alpa_nccl_wrapper",
     ] + select({
         ":gpu_enabled": [
             "//tensorflow/compiler/xla/pjrt:gpu_device",
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
index 8d00c5136c9..2f74072ee81 100644
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -203,6 +203,11 @@ StatusOr<std::vector<int64_t>> StridesToLayout(
     if (strides[a] > strides[b]) {
       return false;
     }
+    // FIXME(yonghao): This is only a workaround.
+    // Should support isConsistent([1,1]{1,0}, [1,1]{0,1}) in type check
+    if (dims[a] == dims[b]) {
+      return a > b;
+    }
     return dims[a] == 1 && dims[b] != 1;
   });
   int64_t stride = 1;
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 62025b4ad55..d92dde85c0f 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -36,6 +36,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
+// Added by Alpa for TorchIndexSelect
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
+
 namespace xla {
 
 namespace py = pybind11;
@@ -80,12 +83,13 @@ void BuildOpsSubmodule(py::module* m) {
       "AllReduce",
       static_cast<XlaOp (*)(
          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
-          const std::optional<ChannelHandle>&, const std::optional<Shape>&)>(
-          &AllReduce),
+          const std::optional<ChannelHandle>&, const std::optional<Shape>&,
+          const std::optional<bool>)>(&AllReduce),
      py::arg("operand"), py::arg("computation"),
      py::arg("replica_groups") = py::list(),
      py::arg("channel_id") = std::nullopt,
-      py::arg("shape_with_layout") = std::nullopt);
+      py::arg("shape_with_layout") = std::nullopt,
+      py::arg("use_global_device_ids") = std::nullopt);
   ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
           py::arg("computation"), py::arg("scatter_dimension"),
           py::arg("shard_count"), py::arg("replica_groups") = py::list(),
@@ -95,7 +99,9 @@ void BuildOpsSubmodule(py::module* m) {
   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
           py::arg("concat_dimension"), py::arg("split_count"),
           py::arg("replica_groups") = py::list(),
-          py::arg("layout") = std::nullopt);
+          py::arg("channel_id") = std::nullopt,
+          py::arg("layout") = std::nullopt,
+          py::arg("use_global_device_ids") = std::nullopt);
   ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
           py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
           py::arg("comparator"), py::arg("recall_target") = 0.9,
@@ -429,6 +435,10 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("b"), py::arg("x"));
   ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
 
+  // Added by Alpa
+  ops.def("IndexSelect", &TorchIndexSelect, py::arg("input"), py::arg("index"),
+          py::arg("dim"), py::arg("batch_dims") = 0);
+
 #define BINARY_OP(op) \
   ops.def(            \
       #op,            \
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 6b9c8b7f930..3e18179397a 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -66,6 +66,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/python/lib/core/bfloat16.h"
 
+// Added by Alpa
+#include "tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
+
+PYBIND11_MAKE_OPAQUE(std::vector<ncclComm_t>);
+
 // TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
 
 namespace xla {
@@ -143,6 +149,72 @@ PYBIND11_MODULE(xla_extension, m) {
       .def_property_readonly(
           "client",
           [](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
+      // Added by Alpa
+      .def("set_seed", [](const PjRtDevice& device, int seed) {
+        xla::PjRtClient* client = device.client();
+        xla::PjRtStreamExecutorClient* stream_client =
+            dynamic_cast<xla::PjRtStreamExecutorClient*>(client);
+        CHECK(stream_client != nullptr);
+        stream_client->device_state(device.local_hardware_id()).SetPrngSeed(seed);
+        return Status::OK();
+      })
+      .def("memory_allocated", [](const PjRtDevice& device) {
+        const int64_t invalid = -1;
+
+        xla::PjRtClient* client = device.client();
+        if (client->platform_name() != "gpu") {
+          return invalid;
+        }
+        xla::PjRtStreamExecutorClient* gpu_client =
+            dynamic_cast<xla::PjRtStreamExecutorClient*>(client);
+        CHECK(gpu_client != nullptr);
+        return gpu_client->allocator()->bytes_used(device.local_hardware_id());
+      })
+      .def("max_memory_allocated", [](const PjRtDevice& device) {
+        const int64_t invalid = -1;
+
+        xla::PjRtClient* client = device.client();
+        if (client->platform_name() != "gpu") {
+          return invalid;
+        }
+        xla::PjRtStreamExecutorClient* gpu_client =
+            dynamic_cast<xla::PjRtStreamExecutorClient*>(client);
+        CHECK(gpu_client != nullptr);
+        return gpu_client->allocator()->bytes_peak_in_use(device.local_hardware_id());
+      })
+      .def("available_memory", [](const PjRtDevice& device) {
+        const int64_t invalid = -1;
+
+        xla::PjRtClient* client = device.client();
+        if (client->platform_name() != "gpu") {
+          return invalid;
+        }
+        xla::PjRtStreamExecutorClient* gpu_client =
+            dynamic_cast<xla::PjRtStreamExecutorClient*>(client);
+        CHECK(gpu_client != nullptr);
+        return gpu_client->allocator()->bytes_available(device.local_hardware_id());
+      })
+      .def("clear_memory_stats", [](const PjRtDevice& device) {
+        const bool invalid = false;
+
+        xla::PjRtClient* client = device.client();
+        if (client->platform_name() != "gpu") {
+          return invalid;
+        }
+        xla::PjRtStreamExecutorClient* gpu_client =
+            dynamic_cast<xla::PjRtStreamExecutorClient*>(client);
+        CHECK(gpu_client != nullptr);
+        return gpu_client->allocator()->ClearStats(device.local_hardware_id());
+      })
+      .def("synchronize_all_activity", [](PjRtDevice& device) {
+        PjRtStreamExecutorDevice* stream_device =
+            dynamic_cast<PjRtStreamExecutorDevice*>(&device);
+        CHECK_NE(stream_device, nullptr);
+        TF_ASSIGN_OR_RETURN(LocalDeviceState* local_device,
+                            stream_device->GetLocalDeviceState());
+        local_device->SynchronizeAllActivity();
+        return Status::OK();
+      })
       .def("__str__", &PjRtDevice::DebugString)
       .def("__repr__", &PjRtDevice::ToString)
       .def("transfer_to_infeed",
@@ -377,6 +449,15 @@
            &PyExecutable::ExecuteShardedOnLocalDevices, py::arg("arguments"))
       .def("hlo_modules", &PyExecutable::HloModules)
       .def("keep_alive", &PyExecutable::KeepAlive)
+      // Added by Alpa
+      .def("total_allocation_size", [](PyExecutable* exec) {
+        const PjRtLoadedExecutable* pjrt_executable = &exec->pjrt_executable();
+        const PjRtStreamExecutorExecutable* stream_executable =
+            dynamic_cast<const PjRtStreamExecutorExecutable*>(pjrt_executable);
+        absl::Span<const std::shared_ptr<LocalExecutable>> local_executables =
+            stream_executable->executables();
+        Executable* executable = local_executables[0]->executable();
+        return executable->TotalAllocationSize();
+      })
       .def_property_readonly("traceback", &PyExecutable::traceback)
       .def_property_readonly("fingerprint",
                              [](PyExecutable* exec) -> py::object {
@@ -554,6 +635,29 @@
   m.def("pprof_profile_to_json", &PprofProfileToJson,
         "Decodes an uncompressed pprof Profile protocol buffer into a JSON "
         "representation");
+
+  py::class_<std::vector<ncclComm_t>>
+      nccl_comm_storage(m, "nccl_comm_storage");
+  m.def("nccl_init_communicator",
+        &gpu::NcclInitCommunicator,
+        "Initialize single thread communicators");
+  m.def("nccl_local_all_gather", &gpu::NcclLocalAllGather, "nccl local allgather");
+  m.def("nccl_destroy_comms", &gpu::NcclDestroyComms, "destroy comms");
+  m.def("nccl_get_unique_id", &gpu::NcclGetUniqueId, "get unique nccl id");
+  m.def("nccl_get_version", &gpu::NcclGetVersion, "get nccl version");
+  m.def("nccl_broadcast_partial_gpus", &gpu::NcclBroadcastPartialGPUs,
+        "nccl broadcast where only a subset of the gpus in the host are involved");
+  m.def("nccl_create_communicators", &gpu::NcclCreateCommunicators,
+        "nccl create communicators for the multi-thread case");
+  m.def("nccl_create_communicators_no_stream",
+        &gpu::NcclCreateCommunicatorsNoStream,
+        "nccl create pure communicators");
+  m.def("get_buffer_device_id", &gpu::GetBufferDeviceId,
+        "get the local device id for one pybuffer");
+  m.def("nccl_recv", &gpu::NcclRecv, "nccl recv data");
+  m.def("nccl_send", &gpu::NcclSend, "nccl send data");
+  m.def("set_cross_mesh_communicator", &gpu::SetCrossMeshCommunicators,
+        "set nccl communicators for cross mesh collective communication");
 }  // NOLINT(readability/fn_size)
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/python/xla_compiler.cc b/tensorflow/compiler/xla/python/xla_compiler.cc
index 108db0193dd..7c8fd5af4bc 100644
--- a/tensorflow/compiler/xla/python/xla_compiler.cc
+++ b/tensorflow/compiler/xla/python/xla_compiler.cc
@@ -55,6 +55,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 
+// Added by Alpa
+#include "tensorflow/compiler/xla/service/pass_context.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_cost_model.h"
+#include "tensorflow/compiler/xla/service/spmd/alpa_compiler.h"
+#include "tensorflow/compiler/xla/service/spmd/grad_acc_rewrite.h"
+
 namespace xla {
 namespace {
 
@@ -406,7 +412,22 @@ void BuildXlaCompilerSubmodule(py::module& m) {
           py::arg("print_large_constants") = false)
       .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
       .def("hash", &HashComputation)
-      .def("as_hlo_module", &GetHloModule);
+      .def("as_hlo_module", &GetHloModule)
+      .def("setup_alias",
+           [](XlaComputation& computation, const std::vector<int64_t>& output_index,
+              int64_t param_number, const std::vector<int64_t>& param_index) {
+             HloInputOutputAliasProto::AliasEntryProto entry;
+             for (auto i : output_index) {
+               entry.add_output_shape_index(i);
+             }
+             entry.set_parameter_number(param_number);
+             for (auto i : param_index) {
+               entry.add_parameter_shape_index(i);
+             }
+             entry.set_kind(Kind::MAY_ALIAS);
+             computation.mutable_proto()->mutable_input_output_alias()
+                 ->add_entries()->Swap(&entry);
+           });
 
   py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
   hlo_print_options_class.def(py::init<>())
@@ -459,6 +480,18 @@ void BuildXlaCompilerSubmodule(py::module& m) {
                     &HloPrintOptions::is_in_nested_computation,
                     &HloPrintOptions::set_is_in_nested_computation);
 
+  // Added by Alpa
+  py::class_<HloSharding> hlo_sharding_class(m, "HloSharding");
+  hlo_sharding_class
+      .def(py::init([](const py::bytes& serialized_hlo_sharding_proto) {
+        OpSharding proto;
+        proto.ParseFromString(std::string(serialized_hlo_sharding_proto));
+        return ValueOrThrow(HloSharding::FromProto(proto));
+      }))
+      .def("proto_tuple", [](const HloSharding& hlo_sharding) {
+        return hlo_sharding.ToProto();
+      });
+
   py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
       m, "HloModule");
   hlo_module_class.def_property_readonly("name", &HloModule::name)
@@ -469,23 +502,32 @@ void BuildXlaCompilerSubmodule(py::module& m) {
py::arg("options") = HloPrintOptions()) .def("as_serialized_hlo_module_proto", &GetHloModuleSerializedProto) .def("from_serialized_hlo_module_proto", &HloModuleFromSerializedProto) - .def_property_readonly( - "spmd_output_sharding", - [](const HloModule& m) -> std::optional { - if (!m.has_spmd_output_sharding()) return std::nullopt; - return m.spmd_output_sharding().ToProto(); + // Added by Alpa + .def("has_schedule", &HloModule::has_schedule) + .def("spmd_output_sharding", &HloModule::spmd_output_sharding) + .def("spmd_parameters_shardings", &HloModule::spmd_parameters_shardings) + .def("set_spmd_output_sharding", &HloModule::set_spmd_output_sharding) + .def("set_spmd_parameters_shardings", &HloModule::set_spmd_parameters_shardings) + .def("infer_spmd_shardings", &HloModule::infer_spmd_shardings) + .def("setup_alias", [](std::shared_ptr hlo_module, + const std::vector& output_index, + int64_t param_number, + const std::vector& param_index) { + hlo_module->input_output_alias_config().SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); }) - .def_property_readonly( - "spmd_parameters_shardings", - [](const HloModule& m) - -> std::optional> { - if (!m.has_spmd_parameters_shardings()) return std::nullopt; - std::vector param_shardings; - for (const auto& parameter_sharding : - m.spmd_parameters_shardings()) { - param_shardings.push_back(parameter_sharding.ToProto()); + .def("program_shape", [](const HloModule& hlo_module) { + return hlo_module.entry_computation_layout().ComputeProgramShape(); + }) + .def("parameter_shapes", [](const HloModule& hlo_module) -> std::vector{ + const auto params = hlo_module.entry_computation()->parameter_instructions(); + std::vector ret(params.size()); + for (size_t i = 0; i < params.size(); ++i) { + ret[i] = params[i]->shape(); } - return param_shardings; + return ret; }); m.def("hlo_module_to_dot_graph", @@ -694,6 +736,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { : std::nullopt; }, &ExecutableBuildOptions::set_result_layout) + .def_property("seed", &ExecutableBuildOptions::seed, + &ExecutableBuildOptions::set_seed) .def_property("num_replicas", &ExecutableBuildOptions::num_replicas, &ExecutableBuildOptions::set_num_replicas) .def_property("num_partitions", &ExecutableBuildOptions::num_partitions, @@ -768,6 +812,42 @@ void BuildXlaCompilerSubmodule(py::module& m) { (*attr->mutable_map())[key] = value; }); + /***** Alpa Functions Begin *****/ + m.def("set_pass_context", &pass_context::SetPassContext); + m.def("clear_pass_context", &pass_context::ClearPassContext); + m.def("estimate_hlo_module_cost", &gpu::EstimateHloModuleCost); + m.def("set_hlo_module_output_shardings", &spmd::SetHloModuleOutputShardings); + m.def("set_hlo_module_input_shardings", &spmd::SetHloModuleInputShardings); + m.def("get_grad_sync_channel_ids", &spmd::GetGradSyncChannelIds); + m.def("get_alpa_jaxlib_version", [] { return "0.1.1"; }); + + m.def("run_auto_sharding", + [](HloModule* hlo_module, const CompileOptions& options) { + py::gil_scoped_release gil_release; + TF_RETURN_IF_ERROR(spmd::RunAutoShardingPass(hlo_module, options)); + return Status::OK(); + }, + py::arg("hlo_module"), py::arg("compile_options") = CompileOptions()); + + m.def("run_spmd_partitioner", + [](HloModule* hlo_module, const CompileOptions& options) { + py::gil_scoped_release gil_release; + TF_RETURN_IF_ERROR(spmd::RunSpmdPartitionerPass(hlo_module, options)); + return Status::OK(); + }, + py::arg("hlo_module"), 
py::arg("compile_options") = CompileOptions()); + + m.def( + "hlo_module_count_flop_dot_conv_only", + [](const HloModule& module) -> double { + double ret = 0.0; + for (HloComputation* computation : module.computations()) { + ret += CountFlopDotConvOnly(*computation); + } + return ret; + }); + /***** Alpa Functions End *****/ + py::enum_(m, "PrecisionConfig_Precision") .value("DEFAULT", PrecisionConfig::DEFAULT) .value("HIGH", PrecisionConfig::HIGH)