Skip to content

Commit

Permalink
[serialization] Add support to serialize to memory, and a basic seria…
Browse files Browse the repository at this point in the history
…lization tutorial (halide#7760)

* Add in-memory buffer serialize/deserialize support.

* Add basic serialization tutorial

* Clang format pass

* Update doc strings to use Doxygen formatted args

* Clear out data buffer during serialization

* Update serialization tutorial to use simple blur example with ImageParam

* Make parameter map optional for serialize halide#7849
Add error messages to deserializer for missing params
Update tutorial

* Clang format pass

---------

Co-authored-by: Derek Gerstmann <[email protected]>
Co-authored-by: Steven Johnson <[email protected]>
  • Loading branch information
3 people authored and ardier committed Mar 3, 2024
1 parent 18820f9 commit 4ce8831
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 11 deletions.
36 changes: 34 additions & 2 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ class Deserializer {
: external_params(external_params) {
}

// Deserialize a pipeline from the given filename
Pipeline deserialize(const std::string &filename);

// Deserialize a pipeline from the given input stream
Pipeline deserialize(std::istream &in);

// Deserialize a pipeline from the given buffer of bytes
Pipeline deserialize(const std::vector<uint8_t> &data);

private:
// Helper function to deserialize a homogenous vector from a flatbuffer vector,
// does not apply to union types like Stmt and Expr or enum types like MemoryType
Expand Down Expand Up @@ -445,6 +450,8 @@ void Deserializer::deserialize_function(const Serialize::Func *function, Functio
output_buffer = it->second;
} else if (auto it = parameters_in_pipeline.find(output_buffer_name); it != parameters_in_pipeline.end()) {
output_buffer = it->second;
} else if (!output_buffer_name.empty()) {
user_error << "unknown output buffer used in pipeline '" << output_buffer_name << "'\n";
}
output_buffers.push_back(output_buffer);
}
Expand Down Expand Up @@ -514,6 +521,8 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto alignment = deserialize_modulus_remainder(store_stmt->alignment());
return Store::make(name, value, index, param, predicate, alignment);
Expand Down Expand Up @@ -771,6 +780,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto alignment = deserialize_modulus_remainder(load_expr->alignment());
const auto type = deserialize_type(load_expr->type());
Expand Down Expand Up @@ -820,6 +831,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto type = deserialize_type(call_expr->type());
return Call::make(type, name, args, call_type, func_ptr, value_index, image, param);
Expand All @@ -834,6 +847,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
auto image_name = deserialize_string(variable_expr->image_name());
Buffer<> image;
Expand Down Expand Up @@ -1031,6 +1046,8 @@ PrefetchDirective Deserializer::deserialize_prefetch_directive(const Serialize::
Parameter param;
if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
auto hl_prefetch_directive = PrefetchDirective();
hl_prefetch_directive.name = name;
Expand Down Expand Up @@ -1209,6 +1226,8 @@ ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serializ
image_param = it->second;
} else if (auto it = parameters_in_pipeline.find(image_param_name); it != parameters_in_pipeline.end()) {
image_param = it->second;
} else if (!image_param_name.empty()) {
user_error << "unknown image parameter used in pipeline '" << image_param_name << "'\n";
}
return ExternFuncArgument(image_param);
}
Expand Down Expand Up @@ -1304,9 +1323,12 @@ Pipeline Deserializer::deserialize(std::istream &in) {
in.seekg(0, std::ios::end);
int size = in.tellg();
in.seekg(0, std::ios::beg);
std::vector<char> data(size);
in.read(data.data(), size);
std::vector<uint8_t> data(size);
in.read((char *)data.data(), size);
return deserialize(data);
}

Pipeline Deserializer::deserialize(const std::vector<uint8_t> &data) {
const auto *pipeline_obj = Serialize::GetPipeline(data.data());
if (pipeline_obj == nullptr) {
user_warning << "deserialized pipeline is empty\n";
Expand Down Expand Up @@ -1385,6 +1407,11 @@ Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Para
return deserializer.deserialize(in);
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
return deserializer.deserialize(buffer);
}

} // namespace Halide

#else // WITH_SERIALIZATION
Expand All @@ -1401,6 +1428,11 @@ Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Para
return Pipeline();
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

} // namespace Halide

#endif // WITH_SERIALIZATION
20 changes: 14 additions & 6 deletions src/Deserialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@

namespace Halide {

/**
* Deserialize a Halide pipeline from a file.
* filename should always end in .hlpipe suffix.
* external_params is an optional map, all parameters in the map
* will be treated as external parameters so won't be deserialized.
*/
/// @brief Deserialize a Halide pipeline from a file.
/// @param filename The location of the file to deserialize. Must use .hlpipe extension.
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params);

/// @brief Deserialize a Halide pipeline from an input stream.
/// @param in The input stream to read from containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params);

/// @brief Deserialize a Halide pipeline from a byte buffer containing a serizalized pipeline in binary format
/// @param data The data buffer containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::vector<uint8_t> &data, const std::map<std::string, Parameter> &external_params);

} // namespace Halide

#endif
49 changes: 47 additions & 2 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ class Serializer {
public:
Serializer() = default;

// Serialize the given pipeline into the given filename
void serialize(const Pipeline &pipeline, const std::string &filename);

// Serialize the given pipeline into given the data buffer
void serialize(const Pipeline &pipeline, std::vector<uint8_t> &data);

const std::map<std::string, Parameter> &get_external_parameters() const {
return external_parameters;
}
Expand Down Expand Up @@ -1394,7 +1398,7 @@ void Serializer::build_function_mappings(const std::map<std::string, Function> &
}
}

void Serializer::serialize(const Pipeline &pipeline, const std::string &filename) {
void Serializer::serialize(const Pipeline &pipeline, std::vector<uint8_t> &result) {
FlatBufferBuilder builder(1024);

// extract the DAG, unwrap function from Funcs
Expand Down Expand Up @@ -1459,17 +1463,46 @@ void Serializer::serialize(const Pipeline &pipeline, const std::string &filename

uint8_t *buf = builder.GetBufferPointer();
int size = builder.GetSize();

if (buf != nullptr && size > 0) {
result.clear();
result.reserve(size);
result.insert(result.begin(), buf, buf + size);
} else {
user_error << "failed to serialize pipeline!\n";
}
}

void Serializer::serialize(const Pipeline &pipeline, const std::string &filename) {
std::vector<uint8_t> data;
serialize(pipeline, data);
std::ofstream out(filename, std::ios::out | std::ios::binary);
if (!out) {
user_error << "failed to open file " << filename << "\n";
exit(1);
}
out.write((char *)(buf), size);
out.write((char *)(data.data()), data.size());
out.close();
}

} // namespace Internal

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data) {
Internal::Serializer serializer;
serializer.serialize(pipeline, data);
}

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params) {
Internal::Serializer serializer;
serializer.serialize(pipeline, data);
params = serializer.get_external_parameters();
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename) {
Internal::Serializer serializer;
serializer.serialize(pipeline, filename);
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params) {
Internal::Serializer serializer;
serializer.serialize(pipeline, filename);
Expand All @@ -1482,6 +1515,18 @@ void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, s

namespace Halide {

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}
Expand Down
21 changes: 21 additions & 0 deletions src/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@

namespace Halide {

/// @brief Serialize a Halide pipeline into the given data buffer.
/// @param pipeline The Halide pipeline to serialize.
/// @param data The data buffer to store the serialized Halide pipeline into. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data);

/// @brief Serialize a Halide pipeline into the given data buffer.
/// @param pipeline The Halide pipeline to serialize.
/// @param data The data buffer to store the serialized Halide pipeline into. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params);

/// @brief Serialize a Halide pipeline into the given filename.
/// @param pipeline The Halide pipeline to serialize.
/// @param filename The location of the file to write into to store the serialized pipeline. Any existing contents will be destroyed.
void serialize_pipeline(const Pipeline &pipeline, const std::string &filename);

/// @brief Serialize a Halide pipeline into the given filename.
/// @param pipeline The Halide pipeline to serialize.
/// @param filename The location of the file to write into to store the serialized pipeline. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params);

} // namespace Halide
Expand Down
3 changes: 2 additions & 1 deletion tutorial/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,6 @@ if (TARGET Halide::Mullapudi2016)
set_tests_properties(tutorial_lesson_21_auto_scheduler_run PROPERTIES LABELS "tutorial;multithreaded")
endif ()

# Lesson 22
# Lessons 22-23
add_tutorial(lesson_22_jit_performance.cpp)
add_tutorial(lesson_23_serialization.cpp WITH_IMAGE_IO)
Loading

0 comments on commit 4ce8831

Please sign in to comment.