@@ -130,7 +130,7 @@ from torch import nn
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
@@ -156,9 +156,7 @@ class EarlyStoppingCallback(TrainerCallback):
|
|||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
if state.global_step >= self.num_steps:
|
if state.global_step >= self.num_steps:
|
||||||
return {"should_training_stop": True}
|
control.should_training_stop = True
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then pass it to the [`Trainer`]'s `callback` parameter.
|
Then pass it to the [`Trainer`]'s `callback` parameter.
|
||||||
@@ -737,7 +735,7 @@ accelerate launch --num_processes=2 \
|
|||||||
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
|
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
|
||||||
--fsdp_sharding_strategy=1 \
|
--fsdp_sharding_strategy=1 \
|
||||||
--fsdp_state_dict_type=FULL_STATE_DICT \
|
--fsdp_state_dict_type=FULL_STATE_DICT \
|
||||||
./examples/pytorch/text-classification/run_glue.py
|
./examples/pytorch/text-classification/run_glue.py \
|
||||||
--model_name_or_path google-bert/bert-base-cased \
|
--model_name_or_path google-bert/bert-base-cased \
|
||||||
--task_name $TASK_NAME \
|
--task_name $TASK_NAME \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
|||||||
Reference in New Issue
Block a user