add prepare inputs for transfo_xl and xlnet
This commit is contained in:
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def _do_output_past(self, outputs):
    """Return True when ``outputs`` should be treated as carrying a cached state.

    The decision is driven by two optional config attributes and by the
    number of entries in ``outputs``:

    * If ``self.config.mem_len`` is set to a truthy value, the result is
      True only when ``mem_len > 0`` and ``outputs`` has more than one entry.
    * Otherwise (``mem_len`` missing, ``None`` or ``0``), the result is True
      only when ``self.config.output_past`` is truthy and ``outputs`` has
      more than one entry.

    Args:
        outputs: sized collection returned by the model's forward pass.
            NOTE(review): presumably the cached state sits at ``outputs[1]``;
            confirm against the caller.

    Returns:
        bool: always a real ``bool``, never the raw config value.
    """
    # TODO: might be better to write a self.do_output_past method for each
    # individual class as is done for prepare_inputs_for_generation

    # A single-entry output cannot carry anything besides its first element.
    if len(outputs) <= 1:
        return False

    # `getattr` with a default covers configs that do not define the field.
    mem_len = getattr(self.config, "mem_len", None)
    if mem_len:
        # Configs with a memory length: only a positive length counts.
        return bool(mem_len > 0)

    # No usable mem_len: fall back to the `output_past` flag.
    return bool(getattr(self.config, "output_past", False))
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
|
||||
if past:
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# copy the relevant beam idx past to past
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
|
||||
Reference in New Issue
Block a user