0%

Distiller 模型剪枝教程

Distiller 是由 Intel AI Lab 维护的基于 PyTorch 的开源神经网络压缩框架。主要包括:

  • 用于集成剪枝(pruning),正则化(regularization)和量化(quantization )算法的框架。
  • 一套用于分析和评估压缩性能的工具。
  • 现有技术压缩算法的示例实现。

1. 安装

安装教程建议按照官方链接 installation来,对于虚拟环境,笔者使用的 conda 方便管理,也可以安全按照教程中步骤。

由于最近GitHub网络原因,很多 *.github.io 网站不能访问,distiller 自带的官方文档可用本地浏览器打开。将 distiller 代码 clone 后,使用浏览器打开该文件即可:’ …/distiller/docs/index.html’(请使用绝对路径)。这种方法其实是通过本地文件协议,而不是 http 访问,而且本地访问很快,特殊时期挺有用。

2. 工程简介

2.1 Distiller 的工作流程

Distiller 的剪枝过程与以往的剪枝论文相似,而且设计出普适性的函数接口,方便用户自定义更多的剪枝方法,大致过程如下所示。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
For each epoch:
compression_scheduler.on_epoch_begin(epoch) # <<< [1]: 训练一个 epoch 之前
train()
validate()
save_checkpoint()
compression_scheduler.on_epoch_end(epoch) # <<< [5]: 训练一个 epoch 之后

train():
For each training step:
compression_scheduler.on_minibatch_begin(epoch) # <<< [2]: 传入一个 batch 数据之前
output = model(input_var)
loss = criterion(output, target_var)
compression_scheduler.before_backward_pass(epoch) # <<< [3]: 进行 BP 之前
loss.backward()
optimizer.step()
compression_scheduler.on_minibatch_end(epoch) # <<< [4]: 训练一个 batch 之后

与正常的训练过程相比,多出了代码中标记出的步骤,分别按照其序号如上所示。之前有个剪枝经验或阅读过剪枝论文的同学对于这个过程应该比较熟悉。

2.2 distiller包含模块

distiller主要包含5大模块,每个模块可看做一个policy:

  • Regularization:正则化训练
  • Pruning:剪枝
  • Knowledge Distillation:知识蒸馏
  • Quantization:量化
  • Conditional Computation

5大模块都集成policy基类,也就是上文中描述的,插入在训练的过程中。比如以 Network Slimming 为例,需要在BP之前对BN的 gamma weight 进行梯度调整就发生在上文中的 [3]过程。

3.运行

3.1 任务配置文件

distiller 在运行前会读取任务配置文件,examples文件夹下有许多示例供用户参考。本文以Mask RCNN稀疏化训练为例,配置文件路径在:examples/object_detection_compression/maskrcnn.scheduler_agp.yaml(以下省略部分)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
version: 1

pruners:

fc_pruner:
class: AutomatedGradualPruner
initial_sparsity : 0.01
final_sparsity: 0.85
weights: [
roi_heads.box_head.fc6.weight,
roi_heads.box_head.fc7.weight
]

agp_pruner_75:
class: AutomatedGradualPruner
initial_sparsity : 0.01
final_sparsity: 0.75
weights: [
backbone.body.layer1.0.conv1.weight,
...]

...
...

policies:
- pruner:
instance_name : agp_pruner_75
starting_epoch: 0
ending_epoch: 45
frequency: 1

- pruner:
instance_name : fc_pruner
starting_epoch: 0
ending_epoch: 45
frequency: 3
...
...

pruner列表: 定义剪枝过程会用到的pruner,每个pruner包含如下几个key需要定义:

  • name: pruner名称,方便查找使用
  • class:pruner 所使用剪枝类型,即剪枝方法。如本例中的AutomatedGradualPruner
  • weights:剪枝表,即需要剪枝的layer
  • 其他:由于不同剪枝方法所需变量不同,额外的关键字需要根据需要给出。如本例中AGP剪枝所需的:initial_sparsityfinal_sparsity

需要值得注意的是,由于剪枝任务的高度定制化特性,distiller剪枝表需要提前通过 model 的state_dict(以pytoch为例)获取,并对需要剪枝的layer填入上述剪枝表中。例如本文中的 Mask RCNN 模型,如果只对backbone进行剪枝,那么只需在对应pruner的剪枝表中填入model.state_dict()backbone所包含的layer name即可,如backbone.body.layer1.0.conv1.weight等等。

distiller包含的剪枝方法有:https://intellabs.github.io/distiller/algo_pruning.html

algorithms

policies 列表:定义pruner运行期间的规则,每个policy必须包含如下变量:

  • instance_name: 即前文中定义的prunername
  • starting_epochpruner 开始运行epoch
  • ending_epochpruner结束运行epoch
  • frequencypruner运行频率(每多少epoch运行一次)

3.2 开始剪枝训练

运行命令行如下:

1
2
3
4
5
6
7
8
9
export CUDA_VISIBLE_DEVICES='3'

python compress_detector.py --data-path /path/to/coco \ # you should specify your coco path here
--model maskrcnn_resnet50_fpn \
--pretrained \
--lr 0.01 \
--epochs 45 \
--output-dir ./log/mask_struc_prun/ \
--compress ./configs/maskrcnn.scheduler_agp_structure.yaml \

如果正常运行,过程中会看到如下log(仅贴部分):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
...
...
...
Parameters:
+----+--------------------------------------------+--------------------+---------------|
| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
|----+--------------------------------------------+--------------------+---------------|
| 0 | backbone.body.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12322 | -0.00050 | 0.07605 |
| 1 | backbone.body.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07083 | -0.00439 | 0.04032 |
| 2 | backbone.body.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02915 | 0.00079 | 0.01717 |
| 3 | backbone.body.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03527 | 0.00042 | 0.02105 |
...
...
+----------------+------------+------------+----------+----------+----------+------------|
Total sparsity: 0.52
...
...

同时会保存一个 csv文件,包含训练过程的 sparsity 信息,位置在用户自定义log文件夹下,如下图所示:

1604049867920

总结

distiller 开源库包含了模型压缩主流的一些算法,包括正则化,剪枝,量化,蒸馏等,用户也可以根据需要,定制化加入压缩算法。本文主要以Mask RCNN为例,介绍了使用disitiller进行剪枝过程,读者参照该过程就可进行试验。

Reference

  1. Distiller:神经网络压缩研究框架
  2. Learning Efficient CNN through Network Slimming