Zero-shot distillation script CUDA patch (#10284)
This commit is contained in:
@@ -174,7 +174,7 @@ def get_teacher_predictions(
|
|||||||
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
if not no_cuda and torch.cuda.is_available():
|
if not no_cuda and torch.cuda.is_available():
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model.cuda())
|
||||||
batch_size *= len(model.device_ids)
|
batch_size *= len(model.device_ids)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user