MAGIS

Experimental Setup

主要软件包版本。基于以下环境运行了 Resnet50 + CIFAR Demo 进行验证。

  • CUDA 11.6
  • CuDNN 8.4.0
  • PyTorch 1.13.1
  • Python 3.10.5

一些说明

任务

根据 Renze Chen 描述,任务是跑一遍 forward + backward,测时间和内存开销。Samples 输出的 results.csv 里有实验一节里的各种评估指标。数据集不重要,可能是随机生成。

nn/

实现了 ResNet,ViT 等网络。Transformer.py 中有 Bert 等网络定义。两个大模型也基于 config.pyTransformer.py 作用即可。

torch_cuda.py

相当于把 nn/ 中定义的网络转为 pytorch,跑 baseline,然后再进入 MAGIS 工作流。所以 Pytorch Baseline 相当于已经实现了。待研读代码。

Run

第三次实验求助了作者,将 nn.Bert() 参数做了修改,可惜仍然 exceed memory limit.

第四次实验进一步降 batch_size 为 16,首次运行成功。

Debug

这个工作尚有一些部分没有实现,运行时会报一些参数上的错误

1
2
3
TypeError: OpGraph.conv2d() takes from 3 to 7 positional arguments but 8 were given
...
TypeError: TorchCudaBackend._gen_ewise_uni() missing 1 required positional argument: 'x'

问了作者,答复是有输出就不用管。

另有一些数据类型上的问题

1
RuntimeError: expected scalar type Float but found BFloat16(Half)

按照 paper 改成 torch.bfloat16float16 都会有此问题. 推测是 torch 1.13 的锅。暂时按 float32 运行,不出意外地 OOM 了。用 float32, 降至 batch_size = 8 仍无法运行,暂时搁置。

Result

测试了 ResNet50、UNet、ViT、BERT 四个网络,结果如下。

name device memory limit latency limit memory limit latency limit ratio memory limit ratio weight memory opt-is-prof-result opt-latency opt-memory opt-simul-latency opt-simul-memory ori-is-prof-result ori-latency ori-memory ori-simul-latency ori-simul-memory
ResNet50 3157573632 None 307137126.4 None 0.8 25502912 TRUE 61.81041336 306195712 98.23434734 378773952 TRUE 61.37547684 383921408 64.40793696 455844288
UNet 3157573632 None 1234749030 None 0.8 31372994 TRUE 239.9398092 1237252096 219.9666572 1163835074 TRUE 236.343043 1543436288 202.1358391 1470019266
ViT 3157573632 None 872511897.6 None 0.8 85702656 TRUE 169.5556183 751625984 140.1534884 856602880 TRUE 149.9205729 1090639872 136.0365616 1243764736
BERT 3157573632 None 1442303181 None 0.8 84934656 TRUE 243.1806081 1220583424 226.0288952 1337982976 TRUE 243.7401377 1802878976 216.6313542 1953890304

Latency

与 PyTorch 动态计算图 (Eager Mode) 对比。

平均 Latency 降低 3.39%,在三个 Benchmarks 上取得了加速,在 BERT 上有 0.23% 的时延增加。

Peak Memory Usage

与 PyTorch 动态计算图 (Eager Mode) 对比。

平均降低 25.85% 的峰值内存,在 ViT 和 BERT 明显,分别优化 31.05% 和 32.28%。看来对 Transformer 十分有效。

附 | 环境配置

点我直达 ☞