Pass device in Logits Processor's init (#29804)
* add device in logits processor * remove device when not needed * codestyle * tests * forgot `melody` version * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * codestyle * updates --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c73ee1333d
commit
83238eeebc
@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# check that first input is skipped (min new length applying)
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
|
||||
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
|
||||
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device
|
||||
)
|
||||
|
||||
expected_eos_scores_before_min_length = batch_size * [-float("inf")]
|
||||
@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
eta_warp = EtaLogitsWarper(0.0625)
|
||||
eta_warp = EtaLogitsWarper(0.0625, device=torch_device)
|
||||
filtered_dist = torch.exp(eta_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
|
||||
@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
|
||||
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device)
|
||||
filtered_dist = eta_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores_comp = scores.clone()
|
||||
|
||||
# instantiate all dist processors
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
|
||||
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
logits_processor = ForcedEOSTokenLogitsProcessor(
|
||||
max_length=max_length, eos_token_id=eos_token_id, device=torch_device
|
||||
)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
|
||||
@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
|
||||
Reference in New Issue
Block a user