Improve TF weight loading, especially PT crossloading (#21792)

* First commit for the improved PT-TF weight loading

* Remove workarounds from TFEncoderDecoder tests

* Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder

* make fixup

* First attempt at visionencoderdecoder

* Disable tensorfloat32 in tests to get consistent outputs

* Quick fix to tf_vision_encoder_decoder tests

* make fixup

* Update Blenderbot tests

* Remove unused arg in modeling_tf_opt

* load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False.

* Support prefixes when loading sharded TF checkpoints

* make fixup

* Add test to load sharded models with a weight prefix

* Fix sharded weight loading test

* Add a test for transfer from a sharded checkpoint

* make fixup

* Add test to check that crossloading from PT with a prefix works

* Refactor from_pretrained in the encoderdecoder classes

* Refactor from_pretrained in the encoderdecoder classes

* missmatched -> mismatched

* Explicitly check for None

* No comments showing my very impressive and attractive knowledge of Py3.9+

* Disable TF32 across all TF tests
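
A minimal sketch of why the `strict=True` and prefix items above matter for transfer learning. This is a hypothetical illustration only: `strip_prefix`, `load_weights`, and the weight names are assumptions made for this example and are not the transformers implementation. The direction of the prefix handling (stripping it from checkpoint names) is also chosen only for illustration.

```python
# Hypothetical illustration only: these helpers are NOT the transformers API.

def strip_prefix(name, prefix):
    """Drop a leading 'prefix/' from a weight name if it is present."""
    head = prefix + "/"
    return name[len(head):] if name.startswith(head) else name


def load_weights(model_weights, checkpoint, prefix=None, strict=False):
    """Copy checkpoint values into model_weights, matching by name.

    Returns (missing, unexpected). With strict=True any mismatch raises,
    which rules out loading a backbone into a model that has a new head.
    """
    if prefix is not None:
        checkpoint = {strip_prefix(k, prefix): v for k, v in checkpoint.items()}
    missing = sorted(set(model_weights) - set(checkpoint))
    unexpected = sorted(set(checkpoint) - set(model_weights))
    if strict and (missing or unexpected):
        raise ValueError(f"missing={missing}, unexpected={unexpected}")
    # Only names present on both sides are overwritten.
    for name in set(model_weights) & set(checkpoint):
        model_weights[name] = checkpoint[name]
    return missing, unexpected


# Assumed toy data: a prefixed checkpoint and a model with a fresh classifier.
checkpoint = {"encoder/dense/kernel": 1.0, "encoder/dense/bias": 2.0}
model = {"dense/kernel": 0.0, "dense/bias": 0.0, "classifier/kernel": 0.0}
missing, unexpected = load_weights(model, checkpoint, prefix="encoder")
```

In this sketch the non-strict call loads the two backbone weights and reports `classifier/kernel` as missing, leaving it at its initial value. The same call with `strict=True` raises instead.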
Matt
2023-02-28 18:41:34 +00:00
committed by GitHub
parent 871c31a6f1
commit acfb714bdf
7 changed files with 147 additions and 148 deletions

@@ -925,16 +925,14 @@ class TFViT2GPT2ModelIntegrationTest(unittest.TestCase):
         self.assertLessEqual(max_diff, 1e-4)

         def generate_step(pixel_values):
-            outputs = model.generate(
-                pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
-            )
+            outputs = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)
             output_ids = outputs.sequences
             preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
             preds = [pred.strip() for pred in preds]

-            return preds, outputs.scores.numpy()
+            return preds

-        preds, scores = generate_step(pixel_values)
+        preds = generate_step(pixel_values)

         # should produce
         # ["a cat laying on top of a couch next to another cat"]