Patch-past-refactor (#21050)
* small patches, forgot a line
* refactor PT
* the actual fix
This commit is contained in:
@@ -729,13 +729,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||||
):
|
):
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
past_key_values = decoder_inputs.get("past_key_values")
|
past_key_values = decoder_inputs.get("past_key_values")
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2
|
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
|
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
|||||||
@@ -649,9 +649,9 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||||
):
|
):
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
|||||||
@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
if past:
|
if past_key_values:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
# first step, decoder_cached_states are empty
|
# first step, decoder_cached_states are empty
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
|
|||||||
)
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@unittest.skip("Arthur will fix me!")
|
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
|
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
|
||||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
|||||||
Reference in New Issue
Block a user