[Deberta/Deberta-v2] Refactor code base to support compile, export, and fix LLM (#22105)
* some modification for roadmap * revert some changes * yups * weird * make it work * sttling * fix-copies * fixup * renaming * more fix-copies * move stuff around * remove torch script warnings * ignore copies * revert bad changes * woops * just styling * nit * revert * style fixup * nits configuration style * fixup * nits * will this fix the tf pt issue? * style * ??????? * update * eval? * update error message * updates * style * grumble grumble * update * style * nit * skip torch fx tests that were failing * style * skip the failing tests * skip another test and make style
This commit is contained in:
@@ -2539,7 +2539,11 @@ class ModelTesterMixin:
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, tol, f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}).")
|
||||
self.assertLessEqual(
|
||||
max_diff,
|
||||
tol,
|
||||
f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}) for {model_class.__name__}",
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `tf.Tensor`. Got"
|
||||
@@ -2615,7 +2619,7 @@ class ModelTesterMixin:
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
pt_model = model_class(config)
|
||||
pt_model = model_class(config).eval()
|
||||
tf_model = tf_model_class(config)
|
||||
|
||||
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
Reference in New Issue
Block a user