Skip to content

Commit

Permalink
Add special build for testing serialization via a serialization round…
Browse files Browse the repository at this point in the history
…trip in JIT compilation and fix serialization leaks (halide#7763)

* add back JIT testing, enclosed in #ifdef blocks

* fix typo

* nits

* WITH_SERIALIZATION_JIT->WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING

* fix self-reference leaks: now uses weak function ptr in reverse function mappings

* Move clang-tidy checks back to Linux

Recent changes in the GHA runners for macOS don't play well with clang-tidy; rather than sink any more time into debugging it, I'm going to revert the relevant parts of halide#7746 so that it runs on the less-finicky Linux runners instead.

* bogus

* Update Generator.cpp

* Update Generator.cpp

* call copy_to_host before serializing buffers

* throw an error if we serialize on-device buffer

* Skip specialize_to_gpu

* Update Pipeline.cpp

* Skip two more tests

* use serialize to memory during jit testing

* makefile update

* makefile fix

* skip the tutorial if flatc is not there

* fix

* fix signature

* fix makefile

* trigger buildbot

---------

Co-authored-by: Steven Johnson <[email protected]>
  • Loading branch information
2 people authored and ardier committed Mar 3, 2024
1 parent 47433d8 commit d5e8592
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 4 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,12 @@ tutorial_%: $(BIN_DIR)/tutorial_% $(TMP_DIR)/images/rgb.png $(TMP_DIR)/images/gr
cd $(TMP_DIR) ; $(CURDIR)/$<
@-echo

# Skip the serialization tutorial, if we didn't build -DWITH_SERIALIZATION
ifeq (,$(shell which flatc))
tutorial_lesson_23_serialization:
@echo "Skipping tutorial lesson 23 (serialization not enabled) ..."
endif

test_mullapudi2016: $(MULLAPUDI2016_TESTS:$(ROOT_DIR)/test/autoschedulers/mullapudi2016/%.cpp=mullapudi2016_%)

mullapudi2016_%: $(BIN_DIR)/mullapudi2016_% $(BIN_MULLAPUDI2016)
Expand Down
7 changes: 7 additions & 0 deletions cmake/HalideTestHelpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ function(add_halide_test TARGET)
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN TRUE)


if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
if (WITH_SERIALIZATION)
target_compile_definitions(${TARGET} PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
endif ()
endif ()

# Add a meta-target for each group, to allow us to build by group easily
foreach (GROUP IN LISTS args_GROUPS)
set(META_TARGET build_${GROUP})
Expand Down
9 changes: 9 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,15 @@ if (WITH_SERIALIZATION)
target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION)
endif ()

# Enable serialization testing by intercepting JIT compilation with a serialization roundtrip;
# This is used only for special builds made specifically for testing, and must be disabled by default.
option(WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING "Intercepting JIT compilation with a serialization roundtrip, for test only" OFF)
if (WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
if (WITH_SERIALIZATION)
target_compile_definitions(Halide PRIVATE WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING)
endif ()
endif ()

add_library(Halide::Halide ALIAS Halide)

target_link_libraries(Halide PRIVATE Halide::LLVM)
Expand Down
7 changes: 6 additions & 1 deletion src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,12 @@ void Deserializer::build_reverse_function_mappings(const std::vector<Function> &
}
int count = 0;
for (const auto &f : functions) {
this->reverse_function_mappings[count++] = f.get_contents();
// The reverse function mappings are used in places where only weak references are needed.
FunctionPtr ptr;
ptr.strong = nullptr;
ptr.weak = f.get_contents().group();
ptr.idx = f.get_contents().idx;
this->reverse_function_mappings[count++] = ptr;
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,24 @@ void Pipeline::compile_jit(const Target &target_arg) {
// Clear all cached info in case there is an error.
contents->invalidate_cache();

#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
std::map<std::string, Parameter> external_params;
std::vector<uint8_t> data;
serialize_pipeline(*this, data, external_params);
Pipeline deserialized_pipe = deserialize_pipeline(data, external_params);
std::vector<Function> outputs;
for (const Func &f : deserialized_pipe.outputs()) {
outputs.push_back(f.function());
}
// We save the original output functions and requirements,
// and restore them once all lowering is done,
// so that reschedule/reorder storage can be properly handled.
std::vector<Function> origin_outputs = contents->outputs;
std::vector<Internal::Stmt> origin_requirements = contents->requirements;
contents->outputs = outputs;
contents->requirements = deserialized_pipe.requirements();
#endif

// Infer an arguments vector
infer_arguments();

Expand All @@ -596,6 +614,11 @@ void Pipeline::compile_jit(const Target &target_arg) {
Module module = compile_to_module(args, generate_function_name(), target).resolve_submodules();
std::map<std::string, JITExtern> lowered_externs = contents->jit_externs;
contents->jit_cache = compile_jit_cache(module, std::move(args), contents->outputs, contents->jit_externs, target);
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
// Restore the original outputs and requirements.
contents->outputs = origin_outputs;
contents->requirements = origin_requirements;
#endif
}

Callable Pipeline::compile_to_callable(const std::vector<Argument> &args_in, const Target &target_arg) {
Expand Down
10 changes: 7 additions & 3 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class Serializer {

Offset<Serialize::ExternFuncArgument> serialize_extern_func_argument(FlatBufferBuilder &builder, const ExternFuncArgument &extern_func_argument);

Offset<Serialize::Buffer> serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer);
Offset<Serialize::Buffer> serialize_buffer(FlatBufferBuilder &builder, Buffer<> buffer);

std::vector<Offset<Serialize::WrapperRef>> serialize_wrapper_refs(FlatBufferBuilder &builder, const std::map<std::string, FunctionPtr> &wrappers);

Expand Down Expand Up @@ -1380,10 +1380,14 @@ Offset<Serialize::ExternFuncArgument> Serializer::serialize_extern_func_argument
}
}

Offset<Serialize::Buffer> Serializer::serialize_buffer(FlatBufferBuilder &builder, const Buffer<> &buffer) {
Offset<Serialize::Buffer> Serializer::serialize_buffer(FlatBufferBuilder &builder, Buffer<> buffer) {
if (!buffer.defined()) {
return Serialize::CreateBuffer(builder, false);
}
if (buffer.device_dirty()) {
user_error << "Cannot serialize on-device buffer: " << buffer.name() << "\n";
}
buffer.copy_to_host();
const auto name_serialized = serialize_string(builder, buffer.name());
const auto type_serialized = serialize_type(builder, buffer.type());
const int32_t dimensions = buffer.dimensions();
Expand Down Expand Up @@ -1475,7 +1479,7 @@ void Serializer::serialize(const Pipeline &pipeline, std::vector<uint8_t> &resul

std::vector<Offset<Serialize::Buffer>> buffers_serialized;
buffers_serialized.reserve(buffers_in_pipeline.size());
for (const auto &buffer : buffers_in_pipeline) {
for (auto &buffer : buffers_in_pipeline) {
buffers_serialized.push_back(serialize_buffer(builder, buffer.second));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using namespace Halide;

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

Target t(get_jit_target_from_environment());
if (!t.has_gpu_feature()) {
printf("[SKIP] No GPU target enabled.\n");
Expand Down
4 changes: 4 additions & 0 deletions test/correctness/leak_device_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ void halide_print(JITUserContext *user_context, const char *str) {
}

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

Target target = get_jit_target_from_environment();

Expand Down
5 changes: 5 additions & 0 deletions test/correctness/specialize_to_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using namespace Halide;

int main(int argc, char **argv) {
#ifdef WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING
printf("[SKIP] Serialization won't preserve GPU buffers, skipping.\n");
return 0;
#endif

if (!get_jit_target_from_environment().has_gpu_feature()) {
printf("[SKIP] No GPU target enabled.\n");
return 0;
Expand Down

0 comments on commit d5e8592

Please sign in to comment.