From cbadb5243c8cfdb5d0b3984f166387bbb2c9d418 Mon Sep 17 00:00:00 2001 From: Joe Davison Date: Fri, 19 Feb 2021 14:06:57 -0500 Subject: [PATCH] Zero shot distillation script cuda patch (#10284) --- .../zero-shot-distillation/distill_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/zero-shot-distillation/distill_classifier.py b/examples/research_projects/zero-shot-distillation/distill_classifier.py index f160387619..5012630a55 100644 --- a/examples/research_projects/zero-shot-distillation/distill_classifier.py +++ b/examples/research_projects/zero-shot-distillation/distill_classifier.py @@ -174,7 +174,7 @@ def get_teacher_predictions( model = AutoModelForSequenceClassification.from_pretrained(model_path) model_config = model.config if not no_cuda and torch.cuda.is_available(): - model = nn.DataParallel(model) + model = nn.DataParallel(model.cuda()) batch_size *= len(model.device_ids) tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)