From 196cce6e9b10c5749daf05fdc02d75b924639b00 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 27 Jan 2022 17:58:55 +0300 Subject: [PATCH] Add a device argument to the eval script (#15371) * Device argument for the eval script * Default to none * isort --- .../research_projects/robust-speech-event/eval.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/robust-speech-event/eval.py b/examples/research_projects/robust-speech-event/eval.py index 9952f7cb46..53cd244daf 100755 --- a/examples/research_projects/robust-speech-event/eval.py +++ b/examples/research_projects/robust-speech-event/eval.py @@ -3,6 +3,7 @@ import argparse import re from typing import Dict +import torch from datasets import Audio, Dataset, load_dataset, load_metric from transformers import AutoFeatureExtractor, pipeline @@ -78,7 +79,9 @@ def main(args): dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate)) # load eval pipeline - asr = pipeline("automatic-speech-recognition", model=args.model_id) + if args.device is None: + args.device = 0 if torch.cuda.is_available() else -1 + asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device) # map function to decode audio def map_to_pred(batch): @@ -123,6 +126,12 @@ if __name__ == "__main__": parser.add_argument( "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis." ) + parser.add_argument( + "--device", + type=int, + default=None, + help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.", + ) args = parser.parse_args() main(args)