Simplify Tensor Parallel implementation with PyTorch TP (#34184)

* Simplify Tensor Parallel implementation with PyTorch TP

* Move tp_plan to config

* Lint

* Format and warning

* Disable copy-from check

* Conditionally get attr from config

* make fix-copies

* Move base_model_tp_plan to PretrainedConfig

* Move TP into from_pretrained

* Add device context for load

* Do not serialize

* Move _tp_plan setting to post_init

* Add has_tp_plan

* Add test_tp

* Add 'Multi-gpu inference' doc

* Add backward support for device type identification

* Auto-detect accelerator

* supports_tp_plan

* copyright year

* Fix copy
This commit is contained in:
Ke Wen
2024-11-18 10:51:49 -08:00
committed by GitHub
parent 7df93d6ffb
commit 20142ab542
18 changed files with 357 additions and 92 deletions

View File

@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se
* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)