权值热更新,即在runtime期间,支持在不重新加载模型的前提下,更新模型中的权值。
- Linux 常见操作系统版本(如 Ubuntu16.04,Ubuntu18.04,CentOS7.x 等),安装 docker(>=v18.00.0)应用程序;
- 服务器装配好寒武纪 300 系列及以上的智能加速卡,并安装好驱动(>=v4.20.6);
- 若不具备以上软硬件条件,可前往寒武纪开发者社区申请试用;
若基于寒武纪云平台环境可跳过该环节。否则需运行以下步骤:
请前往寒武纪开发者社区下载 MagicMind(version >= 0.13.0)镜像,名字如下:
magicmind_version_os.tar.gz, 例如 magicmind_0.13.1-1_ubuntu18.04.tar.gz
docker load -i magicmind_version_os.tar.gz
docker run -it --name=dockername \
--network=host --cap-add=sys_ptrace \
-v /your/host/path/weight_refit:/weight_refit \
-v /usr/bin/cnmon:/usr/bin/cnmon \
--device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl \
-w /weight_refit/ magicmind_version_image_name:tag_name /bin/bash
启用该功能需要在生成模型阶段,在build_config中添加"enable_refit": true
auto build_config = CreateIBuilderConfig();
build_config->ParseFromString(R"({"enable_refit": true})");
auto fit = magicmind::CreateIRefitter(engine);
int new_weight[9] = {9,8,7,6,5,4,3,2,1,0};
auto weight_tensor = magicmind::CreateIRTTensor(magicmind::DataType::INT32,
"const_data",
Layout::NONE,
TensorLocation::kHost);
weight_tensor->SetDimensions(Dims({10}));
weight_tensor->SetData(&new_weight);
fit->SetNamedWeights("tensor_name", weight_tensor);
fit->RefitEngine();
编译项目中的所有c++代码
./build.sh
gen_pytorch_model.py定义了,只有一个卷积层的模型
input shape为(1,1,3,3),conv kernel size(3,3),stride(1,1),没有bias,权重初始化为全0,此时卷积的计算过程为,input和kernel对应位置相乘相加。
python gen_pytorch_model.py
./bin/gen_model
./bin/single_thread
该demo加载了模型的初始权重,即conv的权重全为0,此时计算结果等于0
通过权值热更新接口,将卷积的权重设为全1,模型的输出变为9
input: 1 1 1 conv kernel: 0 0 0 result: 0
1 1 1 0 0 0
1 1 1 0 0 0
--- weight refit ---
input: 1 1 1 conv kernel: 1 1 1 result: 9
1 1 1 1 1 1
1 1 1 1 1 1
./bin/multi_thread
该demo启动了两个线程,线程1先启动,循环推理并打印结果,线程2通接口替换权值,从全0一直到全9(每秒+1), 控制台可以观察到输出结果的变化
-------------------------
input: 1 1 1 result: 0
1 1 1
1 1 1
... ...
(a lot of log)
... ...
-------------------------
input: 1 1 1 result: 9
1 1 1
1 1 1
... ...
(a lot of log)
... ...
-------------------------
input: 1 1 1 result: 18
1 1 1
1 1 1
... ...
(a lot of log)
(a lot of log)
... ...
-------------------------
input: 1 1 1 result: 81
1 1 1
1 1 1