本記事は、当社オウンドメディア「Doors」に移転しました。
約5秒後に自動的にリダイレクトします。
この記事では、2023年3月に発表されたPyTorch 2.0のcompileを試し、速度改善がどの程度になるのか調査した結果をご紹介します。
こんにちは。アナリティクスサービス部の加茂です。
昨今のLLMの潮流に乗るように、機械学習の速度を向上させる取り組みも活発に行われており、JAXを利用したWhisperの高速化やmojo言語の発表など、日々賑わいを見せています。 今回は少しリリースから時間が経っていますが、2023年3月に発表されたPyTorch 2.0のcompileを試し、速度改善がどの程度になるのかを簡単に調査してみたいと思います。
検証内容
推論よりも学習において速度の向上の効果が期待できると公式のページに記載があったため、タスクとしては言語モデルのfine-tuningを選択しました。
方法は主にこちらのurl を参考にさせていただきました。
transformersのgithubリポジトリのexampleにある、examples/pytorch/language-modeling/run_clm.py
を利用して、
同様の学習データで少し軽量なモデル(rinna社のjapanese-gpt2-xsmall)のfine-tuningを行いました。
run_clm.py
は実行時に引数を渡すだけでcompileの有無などを指定できます。
下表は関連する3つのオプションとその簡単な説明です。これらのcompileに関連するオプションは試験的に導入しており、将来的に変更される可能性があると書かれているため、もし試される方がいらっしゃったらバージョンにご注意ください。
オプション | 説明 |
---|---|
torch_compile (bool, *optional*, defaults to False) |
PyTorch 2.0によるモデルのコンパイルを行うか否かを指定する |
torch_compile_backend (str, *optional*) |
torch.compile で使用するbackendを指定する |
torch_compile_mode (str, *optional*) |
torch.compile で使用するmodeを指定する |
実際にオプションが使われているコードの箇所を確認すると、transfomers.Trainer
の_wrap_model
メソッド内で .compile
が呼び出されていることが確認できます。
# torch.compile() needs to be called after wrapping the model with FSDP or DDP # to ensure that it accounts for the graph breaks required by those wrappers if self.args.torch_compile: model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
今回はtorch_compile_mode
の各パターンも試してみたいと思います。
実行環境
- マシン
- GCP A100 40GB
- 学習データ
- dolly-japanese-gpt-1b のdolly-oasst1-ja.txt
- フレームワークのバージョン
- Transformers 4.28.1
- PyTorch 2.0.0+cu117
- Datasets 2.12.0
- Tokenizers 0.13.3
実行方法
run_clm.py
を実行します。
- 「compileなし (PyTorch1.13)」もしくは「compileなし (PyTorch 2.0)」の検証は下の通り実行
- 「compileあり (default)」の場合は
--torch_compile=True
を付与して実行 - 「compileあり (reduce-overhead) 」の場合は
--torch_compile=True --torch_compile_mode=reduce-overhead
を付与して実行 - 「compileあり (max-autotune) 」の場合は
--torch_compile=Ture --torch_compile_mode=max-autotune
を付与して実行
python transformers/examples/pytorch/language-modeling/run_clm.py \ --model_name_or_path=rinna/japanese-gpt2-xsmall \ --train_file=train_data/dolly-oasst1-ja.txt \ --output_dir=output \ --do_train \ --optim=adafactor \ --num_train_epochs=5 \ --save_steps=721 \ --logging_steps=72 \ --learning_rate=1e-03 \ --per_device_train_batch_size=4 \ --save_safetensors=True \ --logging_dir=logs
結果
今回の環境では問題なく全てのオプションが動作し、各実行条件とタスクの学習時間は次の通りになりました。
実行条件 | 計算時間 (sec.) |
---|---|
compileなし (PyTorch 1.13) | 1016.0 |
compileなし (PyTorch 2.0) | 1059.0 |
compileあり (default) | 497.9 |
compileあり (reduce-overhead) | 497.6 |
compileあり (max-autotune) | 769.1 |
compileなしと比較するとおよそ計算時間が半分程度になっていることが分かります。 公式ページではNVIDIA A100 GPUで43%速くなったと報告されており、今回はそれを少し上回る改善が確認できました。 速度改善が10%未満に留まっている検証記事も散見される(*1, *2)ため、利用するGPUによってかなり改善幅が変わるようです。
2点ほど細かい点に触れておくと、まずcompileなしの場合、PyTorch 2.0 と PyTorch 1.13では若干2.0の方が遅くなっています。2.0を利用していたとしても、compileなしでは速度改善しないことが確認できます。
また、modeがdefault
と reduce-overhead
ではほぼ同程度となっておりますが、max-autotune
は他の2つと比べると計算時間が長くなっています。公式によると、max-autotune
は最も高速なモデルにcompileを行う代わりに、compileに非常に時間がかかると説明されています。今回の環境とタスクでは、compileに要する時間の影響の方が大きかったと考えられます。
本記事の結果を簡単にまとめますと、PyTorch 2.0を利用して高速化したい場合、一旦はcompileありでdefaultもしくはreduce-overheadのmodeを利用すると良いと結論付けられそうです。