Fix Encoder-Decoder testing issue about repo. names (#19250)
* Change "../gpt2" to "gpt2" Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -673,9 +673,7 @@ class TFVisionEncoderDecoderMixin:
|
||||
@require_tf
|
||||
class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"google/vit-base-patch16-224-in21k", "../gpt2"
|
||||
)
|
||||
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = TFViTModel(config, name="encoder")
|
||||
@@ -720,12 +718,10 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te
|
||||
@require_tf
|
||||
class TFVisionEncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"google/vit-base-patch16-224-in21k", "../gpt2"
|
||||
)
|
||||
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")
|
||||
|
||||
def get_decoder_config(self):
|
||||
config = AutoConfig.from_pretrained("../gpt2")
|
||||
config = AutoConfig.from_pretrained("gpt2")
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
return config
|
||||
@@ -735,7 +731,7 @@ class TFVisionEncoderDecoderModelTest(unittest.TestCase):
|
||||
|
||||
def get_encoder_decoder_models(self):
|
||||
encoder_model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k", name="encoder")
|
||||
decoder_model = TFGPT2LMHeadModel.from_pretrained("../gpt2", config=self.get_decoder_config(), name="decoder")
|
||||
decoder_model = TFGPT2LMHeadModel.from_pretrained("gpt2", config=self.get_decoder_config(), name="decoder")
|
||||
return {"encoder": encoder_model, "decoder": decoder_model}
|
||||
|
||||
def _check_configuration_tie(self, model):
|
||||
@@ -764,7 +760,7 @@ def prepare_img():
|
||||
class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
def get_encoder_decoder_config(self):
|
||||
encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
decoder_config = AutoConfig.from_pretrained("../gpt2", is_decoder=True, add_cross_attention=True)
|
||||
decoder_config = AutoConfig.from_pretrained("gpt2", is_decoder=True, add_cross_attention=True)
|
||||
return VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
||||
|
||||
def get_encoder_decoder_config_small(self):
|
||||
@@ -879,7 +875,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
|
||||
config = self.get_encoder_decoder_config()
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
decoder_tokenizer = AutoTokenizer.from_pretrained("../gpt2")
|
||||
decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
img = prepare_img()
|
||||
pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values
|
||||
@@ -896,7 +892,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
encoder = TFAutoModel.from_pretrained("google/vit-base-patch16-224-in21k", name="encoder")
|
||||
# It's necessary to specify `add_cross_attention=True` here.
|
||||
decoder = TFAutoModelForCausalLM.from_pretrained(
|
||||
"../gpt2", is_decoder=True, add_cross_attention=True, name="decoder"
|
||||
"gpt2", is_decoder=True, add_cross_attention=True, name="decoder"
|
||||
)
|
||||
pretrained_encoder_dir = os.path.join(tmp_dirname, "pretrained_encoder")
|
||||
pretrained_decoder_dir = os.path.join(tmp_dirname, "pretrained_decoder")
|
||||
|
||||
Reference in New Issue
Block a user