Fix beam search when using model parallel (#24969)

* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Dong-Yong Lee
2023-09-15 00:00:52 +09:00
committed by GitHub
parent 0dd06c3f78
commit 8881f38a4f
53 changed files with 191 additions and 95 deletions

View File

@@ -1180,7 +1180,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
return reordered_past
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
@@ -2898,7 +2898,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
@@ -3335,6 +3335,6 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
{% endif -%}