Fix gradient checkpointing + fp16 autocast for most models (#24247)
* fix gc bug * continue PoC on OPT * fixes * 🤯 * fix tests * remove pytest.mark * fixup * forward contrib credits from discussions * forward contrib credits from discussions * reverting changes on untouched files. --------- Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com> Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import gc
|
||||
@@ -549,6 +548,41 @@ class ModelTesterMixin:
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_training_gradient_checkpointing_autocast(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
|
||||
if (
|
||||
model_class.__name__
|
||||
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
with torch.cuda.amp.autocast(True, dtype=torch.float16):
|
||||
output = model(**inputs)[0]
|
||||
loss = output.mean()
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for n, param in model.named_parameters():
|
||||
self.assertTrue(param.grad is not None, f"None gradient in param {n}")
|
||||
|
||||
def test_attention_outputs(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model does not output attentions")
|
||||
|
||||
Reference in New Issue
Block a user