Improve TF weight loading, especially PT crossloading (#21792)
* First commit for the improved PT-TF weight loading * Remove workarounds from TFEncoderDecoder tests * Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder * make fixup * First attempt at visionencoderdecoder * Disable tensorfloat32 in tests to get consistent outputs * Quick fix to tf_vision_encoder_decoder tests * make fixup * Update Blenderbot tests * Remove unused arg in modeling_tf_opt * load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False. * Support prefixes when loading sharded TF checkpoints * make fixup * Add test to load sharded models with a weight prefix * Fix sharded weight loading test * Add a test for transfer from a sharded checkpoint * make fixup * Add test to check that crossloading from PT with a prefix works * Refactor from_pretrained in the encoderdecoder classes * Refactor from_pretrained in the encoderdecoder classes * missmatched -> mismatched * Explicitly check for None * No comments showing my very impressive and attractive knowledge of Py3.9+ * Disable TF32 across all TF tests
This commit is contained in:
@@ -925,16 +925,14 @@ class TFViT2GPT2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
|
||||
def generate_step(pixel_values):
|
||||
outputs = model.generate(
|
||||
pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
outputs = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)
|
||||
output_ids = outputs.sequences
|
||||
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
preds = [pred.strip() for pred in preds]
|
||||
|
||||
return preds, outputs.scores.numpy()
|
||||
return preds
|
||||
|
||||
preds, scores = generate_step(pixel_values)
|
||||
preds = generate_step(pixel_values)
|
||||
|
||||
# should produce
|
||||
# ["a cat laying on top of a couch next to another cat"]
|
||||
|
||||
Reference in New Issue
Block a user