论文复现 - MAGIS(一)
MAGIS
Experimental Setup
主要软件包版本。基于以下环境运行了 Resnet50 + CIFAR Demo 进行验证。
- CUDA 11.6
- CuDNN 8.4.0
- PyTorch 1.13.1
- Python 3.10.5
一些说明
任务
根据 Renze Chen 描述,任务是跑一遍 forward + backward,测时间和内存开销。Samples 输出的 results.csv 里有实验一节里的各种评估指标。数据集不重要,可能是随机生成。
nn/
实现了 ResNet,ViT 等网络。Transformer.py 中有 Bert 等网络定义。两个大模型也基于 config.py 对 Transformer.py 作用即可。
torch_cuda.py
相当于把 nn/ 中定义的网络转为 pytorch,跑 baseline,然后再进入 MAGIS 工作流。所以 Pytorch Baseline 相当于已经实现了。待研读代码。
Run
第三次实验求助了作者,将 nn.Bert()
参数做了修改,可惜仍然 exceed memory limit.
第四次实验进一步降 batch_size 为 16,首次运行成功。
Debug
这个工作尚有一些部分没有实现,运行时会报一些参数上的错误
1 | TypeError: OpGraph.conv2d() takes from 3 to 7 positional arguments but 8 were given |
问了作者,答复是有输出就不用管。
另有一些数据类型上的问题
1 | RuntimeError: expected scalar type Float but found BFloat16(Half) |
按照 paper 改成 torch.bfloat16
或 float16
都会有此问题. 推测是 torch 1.13 的锅。暂时按 float32
运行,不出意外地 OOM 了。用 float32
, 降至 batch_size = 8
仍无法运行,暂时搁置。
Result
测试了 ResNet50、UNet、ViT、BERT 四个网络,结果如下。
name | device memory limit | latency limit | memory limit | latency limit ratio | memory limit ratio | weight memory | opt-is-prof-result | opt-latency | opt-memory | opt-simul-latency | opt-simul-memory | ori-is-prof-result | ori-latency | ori-memory | ori-simul-latency | ori-simul-memory |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ResNet50 | 3157573632 | None | 307137126.4 | None | 0.8 | 25502912 | TRUE | 61.81041336 | 306195712 | 98.23434734 | 378773952 | TRUE | 61.37547684 | 383921408 | 64.40793696 | 455844288 |
UNet | 3157573632 | None | 1234749030 | None | 0.8 | 31372994 | TRUE | 239.9398092 | 1237252096 | 219.9666572 | 1163835074 | TRUE | 236.343043 | 1543436288 | 202.1358391 | 1470019266 |
ViT | 3157573632 | None | 872511897.6 | None | 0.8 | 85702656 | TRUE | 169.5556183 | 751625984 | 140.1534884 | 856602880 | TRUE | 149.9205729 | 1090639872 | 136.0365616 | 1243764736 |
BERT | 3157573632 | None | 1442303181 | None | 0.8 | 84934656 | TRUE | 243.1806081 | 1220583424 | 226.0288952 | 1337982976 | TRUE | 243.7401377 | 1802878976 | 216.6313542 | 1953890304 |
Latency
与 PyTorch 动态计算图 (Eager Mode) 对比。
平均 Latency 降低 3.39%,在三个 Benchmarks 上取得了加速,在 BERT 上有 0.23% 的时延增加。
Peak Memory Usage
与 PyTorch 动态计算图 (Eager Mode) 对比。
平均降低 25.85% 的峰值内存,在 ViT 和 BERT 明显,分别优化 31.05% 和 32.28%。看来对 Transformer 十分有效。