adapt style to predefined style layout
This commit is contained in:
@@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if 'past' in kwargs and kwargs['past']:
|
if "past" in kwargs and kwargs["past"]:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|||||||
@@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if 'past' in kwargs and kwargs['past']:
|
if "past" in kwargs and kwargs["past"]:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|||||||
@@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
# if past is defined in model kwargs then use it for faster decoding
|
# if past is defined in model kwargs then use it for faster decoding
|
||||||
if 'past' in model_kwargs and model_kwargs['past']:
|
if "past" in model_kwargs and model_kwargs["past"]:
|
||||||
inputs['mems'] = model_kwargs['past']
|
inputs["mems"] = model_kwargs["past"]
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
def _do_output_past(self, outputs):
|
def _do_output_past(self, outputs):
|
||||||
has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
|
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
|
||||||
has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len
|
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
|
||||||
|
|
||||||
if has_output_past and not has_mem_len and len(outputs) > 1:
|
if has_output_past and not has_mem_len and len(outputs) > 1:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||||
|
|
||||||
# if past is defined in model kwargs then use it for faster decoding
|
# if past is defined in model kwargs then use it for faster decoding
|
||||||
if 'past' in model_kwargs and model_kwargs['past']:
|
if "past" in model_kwargs and model_kwargs["past"]:
|
||||||
inputs['mems'] = model_kwargs['past']
|
inputs["mems"] = model_kwargs["past"]
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user