@@ -130,7 +130,7 @@ from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
@@ -156,9 +156,7 @@ class EarlyStoppingCallback(TrainerCallback):
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.global_step >= self.num_steps:
|
||||
return {"should_training_stop": True}
|
||||
else:
|
||||
return {}
|
||||
control.should_training_stop = True
|
||||
```
|
||||
|
||||
Then pass it to the [`Trainer`]'s `callback` parameter.
|
||||
@@ -737,7 +735,7 @@ accelerate launch --num_processes=2 \
|
||||
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
|
||||
--fsdp_sharding_strategy=1 \
|
||||
--fsdp_state_dict_type=FULL_STATE_DICT \
|
||||
./examples/pytorch/text-classification/run_glue.py
|
||||
./examples/pytorch/text-classification/run_glue.py \
|
||||
--model_name_or_path google-bert/bert-base-cased \
|
||||
--task_name $TASK_NAME \
|
||||
--do_train \
|
||||
|
||||
Reference in New Issue
Block a user