fix tests
This commit is contained in:
@@ -492,11 +492,11 @@ class CommonTestCases:
|
|||||||
return equal
|
return equal
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not hasattr(model_class, 'tie_weights'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
config.torchscript = True
|
config.torchscript = True
|
||||||
model_not_tied = model_class(config)
|
model_not_tied = model_class(config)
|
||||||
|
if model_not_tied.get_output_embeddings() is None:
|
||||||
|
continue
|
||||||
|
|
||||||
params_not_tied = list(model_not_tied.parameters())
|
params_not_tied = list(model_not_tied.parameters())
|
||||||
|
|
||||||
config_tied = copy.deepcopy(config)
|
config_tied = copy.deepcopy(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user