From d5e8592c7ce01cf3ceb86e528a3f4a7c08edf5ca Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 7 Nov 2023 07:36:56 +0800 Subject: [PATCH] Add special build for testing serialization via a serialization roundtrip in JIT compilation and fix serialization leaks (#7763) * add back JIT testing, enclosed in #ifdef blocks * fix typo * nits * WITH_SERIALIZATION_JIT->WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING * fix self-reference leaks: now uses weak function ptr in reverse function mappings * Move clang-tidy checks back to Linux Recent changes in the GHA runners for macOS don't play well with clang-tidy; rather than sink any more time into debugging it, I'm going to revert the relevant parts of #7746 so that it runs on the less-finicky Linux runners instead. * bogus * Update Generator.cpp * Update Generator.cpp * call copy_to_host before serializing buffers * throw an error if we serialize on-device buffer * Skip specialize_to_gpu * Update Pipeline.cpp * Skip two more tests * use serialize to memory during jit testing * makefile update * makefile fix * skip the tutorial if flatc is not there * fix * fix signature * fix makefile * trigger buildbot --------- Co-authored-by: Steven Johnson --- Makefile | 6 +++++ cmake/HalideTestHelpers.cmake | 7 ++++++ src/CMakeLists.txt | 9 ++++++++ src/Deserialization.cpp | 7 +++++- src/Pipeline.cpp | 23 +++++++++++++++++++ src/Serialization.cpp | 10 +++++--- ..._give_input_buffers_device_allocations.cpp | 5 ++++ test/correctness/leak_device_memory.cpp | 4 ++++ test/correctness/specialize_to_gpu.cpp | 5 ++++ 9 files changed, 72 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 200742525cc5..67a0818b176a 100644 --- a/Makefile +++ b/Makefile @@ -2083,6 +2083,12 @@ tutorial_%: $(BIN_DIR)/tutorial_% $(TMP_DIR)/images/rgb.png $(TMP_DIR)/images/gr cd $(TMP_DIR) ; $(CURDIR)/$< @-echo +# Skip the serialization tutorial, if we didn't build -DWITH_SERIALIZATION +ifeq (,$(shell which flatc)) +tutorial_lesson_23_serialization: + @echo "Skipping tutorial lesson 23 (serialization not enabled) ..." +endif + test_mullapudi2016: $(MULLAPUDI2016_TESTS:$(ROOT_DIR)/test/autoschedulers/mullapudi2016/%.cpp=mullapudi2016_%) mullapudi2016_%: $(BIN_DIR)/mullapudi2016_% $(BIN_MULLAPUDI2016) diff --git a/cmake/HalideTestHelpers.cmake b/cmake/HalideTestHelpers.cmake index c23aba75fea6..e938d11d53ec 100644 --- a/cmake/HalideTestHelpers.cmake +++ b/cmake/HalideTestHelpers.cmake @@ -77,6 +77,13 @@ function(add_halide_test TARGET) CXX_VISIBILITY_PRESET hidden VISIBILITY_INLINES_HIDDEN TRUE) + + if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING) + if (WITH_SERIALIZATION) + target_compile_definitions(${TARGET} PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING) + endif () + endif () + # Add a meta-target for each group, to allow us to build by group easily foreach (GROUP IN LISTS args_GROUPS) set(META_TARGET build_${GROUP}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d5d6a8a3832e..74e44de3c163 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -511,6 +511,15 @@ if (WITH_SERIALIZATION) target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION) endif () +# Enable serialization testing by intercepting JIT compilation with a serialization roundtrip; +# This is used only for special builds made specifically for testing, and must be disabled by default. +option(WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING "Intercepting JIT compilation with a serialization roundtrip, for test only" OFF) +if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING) + if (WITH_SERIALIZATION) + target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING) + endif () +endif () + add_library(Halide::Halide ALIAS Halide) target_link_libraries(Halide PRIVATE Halide::LLVM) diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index 5d3979fc7f52..eda8ad93338b 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -1321,7 +1321,12 @@ void Deserializer::build_reverse_function_mappings(const std::vector & } int count = 0; for (const auto &f : functions) { - this->reverse_function_mappings[count++] = f.get_contents(); + // The reverse function mappings are used in places where only weak references are needed. + FunctionPtr ptr; + ptr.strong = nullptr; + ptr.weak = f.get_contents().group(); + ptr.idx = f.get_contents().idx; + this->reverse_function_mappings[count++] = ptr; } } diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 631033404137..c605d2038248 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -581,6 +581,24 @@ void Pipeline::compile_jit(const Target &target_arg) { // Clear all cached info in case there is an error. contents->invalidate_cache(); +#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING + std::map external_params; + std::vector data; + serialize_pipeline(*this, data, external_params); + Pipeline deserialized_pipe = deserialize_pipeline(data, external_params); + std::vector outputs; + for (const Func &f : deserialized_pipe.outputs()) { + outputs.push_back(f.function()); + } + // We save the original output functions and requirements, + // and restore them once all lowering is done, + // so that reschedule/reorder storage can be properly handled. + std::vector origin_outputs = contents->outputs; + std::vector origin_requirements = contents->requirements; + contents->outputs = outputs; + contents->requirements = deserialized_pipe.requirements(); +#endif + // Infer an arguments vector infer_arguments(); @@ -596,6 +614,11 @@ void Pipeline::compile_jit(const Target &target_arg) { Module module = compile_to_module(args, generate_function_name(), target).resolve_submodules(); std::map lowered_externs = contents->jit_externs; contents->jit_cache = compile_jit_cache(module, std::move(args), contents->outputs, contents->jit_externs, target); +#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING + // Restore the original outputs and requirements. + contents->outputs = origin_outputs; + contents->requirements = origin_requirements; +#endif } Callable Pipeline::compile_to_callable(const std::vector &args_in, const Target &target_arg) { diff --git a/src/Serialization.cpp b/src/Serialization.cpp index f38f79c5464a..038d5a1323e0 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -129,7 +129,7 @@ class Serializer { Offset serialize_extern_func_argument(FlatBufferBuilder &builder, const ExternFuncArgument &extern_func_argument); - Offset serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer); + Offset serialize_buffer(FlatBufferBuilder &builder, Buffer<> buffer); std::vector> serialize_wrapper_refs(FlatBufferBuilder &builder, const std::map &wrappers); @@ -1380,10 +1380,14 @@ Offset Serializer::serialize_extern_func_argument } } -Offset Serializer::serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer) { +Offset Serializer::serialize_buffer(FlatBufferBuilder &builder, Buffer<> buffer) { if (!buffer.defined()) { return Serialize::CreateBuffer(builder, false); } + if (buffer.device_dirty()) { + user_error << "Cannot serialize on-device buffer: " << buffer.name() << "\n"; + } + buffer.copy_to_host(); const auto name_serialized = serialize_string(builder, buffer.name()); const auto type_serialized = serialize_type(builder, buffer.type()); const int32_t dimensions = buffer.dimensions(); @@ -1475,7 +1479,7 @@ void Serializer::serialize(const Pipeline &pipeline, std::vector &resul std::vector> buffers_serialized; buffers_serialized.reserve(buffers_in_pipeline.size()); - for (const auto &buffer : buffers_in_pipeline) { + for (auto &buffer : buffers_in_pipeline) { buffers_serialized.push_back(serialize_buffer(builder, buffer.second)); } diff --git a/test/correctness/gpu_give_input_buffers_device_allocations.cpp b/test/correctness/gpu_give_input_buffers_device_allocations.cpp index a2f4d9618f63..666ce86d9b3f 100644 --- a/test/correctness/gpu_give_input_buffers_device_allocations.cpp +++ b/test/correctness/gpu_give_input_buffers_device_allocations.cpp @@ -4,6 +4,11 @@ using namespace Halide; int main(int argc, char **argv) { +#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING + printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n"); + return 0; +#endif + Target t(get_jit_target_from_environment()); if (!t.has_gpu_feature()) { printf("[SKIP] No GPU target enabled.\n"); diff --git a/test/correctness/leak_device_memory.cpp b/test/correctness/leak_device_memory.cpp index 086bb1cd5810..567aeddb5fd8 100644 --- a/test/correctness/leak_device_memory.cpp +++ b/test/correctness/leak_device_memory.cpp @@ -14,6 +14,10 @@ void halide_print(JITUserContext *user_context, const char *str) { } int main(int argc, char **argv) { +#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING + printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n"); + return 0; +#endif Target target = get_jit_target_from_environment(); diff --git a/test/correctness/specialize_to_gpu.cpp b/test/correctness/specialize_to_gpu.cpp index 0890e2ad6eae..8e9644114c6f 100644 --- a/test/correctness/specialize_to_gpu.cpp +++ b/test/correctness/specialize_to_gpu.cpp @@ -4,6 +4,11 @@ using namespace Halide; int main(int argc, char **argv) { +#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING + printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n"); + return 0; +#endif + if (!get_jit_target_from_environment().has_gpu_feature()) { printf("[SKIP] No GPU target enabled.\n"); return 0;