Add Pop2Piano (#21785)

* init commit

* config updated also some modeling

* Processor and Model config combined

* extraction pipeline(upto before spectogram & mel_conditioner) added but not properly tested

* model loading successful!

* feature extractor done!

* FE can now be called from HF

* postprocessing added in fe file

* same as prev commit

* Pop2PianoConfig doc done

* cfg docs slightly changed

* fe docs done

* batched

* batched working!

* temp

* v1

* checking

* trying to go with generate

* with generate and model tests passed

* before rebasing

* .

* tests done docs done remaining others & nits

* nits

* LogMelSpectogram shifted to FeatureExtractor

* is_tf rmeoved from pop2piano/init

* import solved

* tokenization tests added

* minor fixed regarding modeling_pop2piano

* tokenizer changed to only return midi_object and other changes

* Updated paper abstract(Camera-ready version) (#2)

* more comments and nits

* ruff changes

* code quality fix

* sg comments

* t5 change added and rebased

* comments except batching

* batching done

* comments

* small doc fix

* example removed from modeling

* ckpt

* forward it compatible with fe and generation done

* comments

* comments

* code-quality fix(maybe)

* ckpts changed

* doc file changed from mdx to md

* test fixes

* tokenizer test fix

* changes

* nits done main changes remaining

* code modified

* Pop2PianoProcessor added with tests

* other comments

* added Pop2PianoProcessor to dummy_objects

* added require_onnx to modeling file

* changes

* update .md file

* remove extra line in index.md

* back to the main index

* added pop2piano to index

* Added tokenizer.__call__ with valid args and batch_decode and aligned the processor part too

* changes

* added return types to 2 tokenizer methods

* the PR build test might work now

* added backends

* PR build fix

* vocab added

* comments

* refactored vocab into 1 file

* added conversion script

* comments

* essentia version changed in .md

* comments

* more tokenizer tests added

* minor fix

* tests extended for outputs acc check

* small fix

---------

Co-authored-by: Jongho Choi <sweetcocoa@snu.ac.kr>
This commit is contained in:
Susnato Dhar
2023-08-21 21:05:00 +05:30
committed by GitHub
parent 6f041fcbb8
commit 450a181d8b
38 changed files with 5284 additions and 0 deletions

View File

View File

@@ -0,0 +1,291 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_essentia,
require_librosa,
require_scipy,
require_tf,
require_torch,
)
from transformers.utils.import_utils import (
is_essentia_available,
is_librosa_available,
is_scipy_available,
is_torch_available,
)
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
requirements_available = (
is_torch_available() and is_essentia_available() and is_scipy_available() and is_librosa_available()
)
if requirements_available:
import torch
from transformers import Pop2PianoFeatureExtractor
class Pop2PianoFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
n_bars=2,
sample_rate=22050,
use_mel=True,
padding_value=0,
vocab_size_special=4,
vocab_size_note=128,
vocab_size_velocity=2,
vocab_size_time=100,
):
self.parent = parent
self.n_bars = n_bars
self.sample_rate = sample_rate
self.use_mel = use_mel
self.padding_value = padding_value
self.vocab_size_special = vocab_size_special
self.vocab_size_note = vocab_size_note
self.vocab_size_velocity = vocab_size_velocity
self.vocab_size_time = vocab_size_time
def prepare_feat_extract_dict(self):
return {
"n_bars": self.n_bars,
"sample_rate": self.sample_rate,
"use_mel": self.use_mel,
"padding_value": self.padding_value,
"vocab_size_special": self.vocab_size_special,
"vocab_size_note": self.vocab_size_note,
"vocab_size_velocity": self.vocab_size_velocity,
"vocab_size_time": self.vocab_size_time,
}
@require_torch
@require_essentia
@require_librosa
@require_scipy
class Pop2PianoFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Pop2PianoFeatureExtractor if requirements_available else None
def setUp(self):
self.feat_extract_tester = Pop2PianoFeatureExtractionTester(self)
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_call(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input = np.zeros([1000000], dtype=np.float32)
input_features = feature_extractor(speech_input, sampling_rate=16_000, return_tensors="np")
self.assertTrue(input_features.input_features.ndim == 3)
self.assertEqual(input_features.input_features.shape[-1], 512)
self.assertTrue(input_features.beatsteps.ndim == 2)
self.assertTrue(input_features.extrapolated_beatstep.ndim == 2)
def test_integration(self):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select([0])["audio"]
input_speech = [x["array"] for x in speech_samples][0]
sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
feaure_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
input_features = feaure_extractor(
input_speech, sampling_rate=sampling_rate, return_tensors="pt"
).input_features
EXPECTED_INPUT_FEATURES = torch.tensor(
[[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
)
self.assertTrue(torch.allclose(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_attention_mask(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2],
sampling_rate=[44_100, 16_000],
return_tensors="np",
return_attention_mask=True,
)
self.assertTrue(hasattr(input_features, "attention_mask"))
# check shapes
self.assertTrue(input_features["attention_mask"].ndim == 2)
self.assertEqual(input_features["attention_mask_beatsteps"].shape[0], 2)
self.assertEqual(input_features["attention_mask_extrapolated_beatstep"].shape[0], 2)
# check if they are any values except 0 and 1
self.assertTrue(np.max(input_features["attention_mask"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_beatsteps"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_extrapolated_beatstep"]) == 1)
self.assertTrue(np.min(input_features["attention_mask"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_beatsteps"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_extrapolated_beatstep"]) == 0)
def test_batch_feature(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_attention_mask=True,
)
self.assertEqual(len(input_features["input_features"].shape), 3)
# check shape
self.assertEqual(input_features["beatsteps"].shape[0], 3)
self.assertEqual(input_features["extrapolated_beatstep"].shape[0], 3)
def test_batch_feature_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="np",
return_attention_mask=True,
)
# check np array or not
self.assertEqual(type(input_features["input_features"]), np.ndarray)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
def test_batch_feature_pt(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="pt",
return_attention_mask=True,
)
# check pt tensor or not
self.assertEqual(type(input_features["input_features"]), torch.Tensor)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@require_tf
def test_batch_feature_tf(self):
import tensorflow as tf
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="tf",
return_attention_mask=True,
)
# check tf tensor or not
self.assertTrue(tf.is_tensor(input_features["input_features"]))
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_pt(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_tf(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_list(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not support truncation")
def test_attention_mask_with_truncation(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_list(self):
pass

View File

@@ -0,0 +1,778 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Pop2Piano model. """
import copy
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers import Pop2PianoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import (
require_essentia,
require_librosa,
require_onnx,
require_scipy,
require_torch,
slow,
torch_device,
)
from transformers.utils import is_essentia_available, is_librosa_available, is_scipy_available, is_torch_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import Pop2PianoForConditionalGeneration
from transformers.models.pop2piano.modeling_pop2piano import POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.pytorch_utils import is_torch_1_8_0
else:
is_torch_1_8_0 = False
@require_torch
class Pop2PianoModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
# For common tests
is_training=False,
use_attention_mask=True,
use_labels=True,
hidden_size=64,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
decoder_layers=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
self.decoder_layers = decoder_layers
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
decoder_attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = (
ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) if self.use_labels else None
)
return self.get_config(), input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels
def get_pipeline_config(self):
return Pop2PianoConfig(
vocab_size=166, # Pop2Piano forces 100 extra tokens
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
)
def get_config(self):
return Pop2PianoConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
)
def check_prepare_lm_labels_via_shift_left(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
# make sure that lm_labels are correctly padded from the right
lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
# add causal pad token mask
triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
decoder_input_ids = model._shift_right(lm_labels)
for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
# first item
self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
if i < decoder_input_ids_slice.shape[-1]:
if i < decoder_input_ids.shape[-1] - 1:
# items before diagonal
self.parent.assertListEqual(
decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
)
# pad items after diagonal
if i < decoder_input_ids.shape[-1] - 2:
self.parent.assertListEqual(
decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
)
else:
# all items after square
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
def create_and_check_model(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_past = result.past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
# There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
self.parent.assertEqual(len(decoder_past[0]), 4)
def create_and_check_with_lm_head(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=lm_labels,
)
self.parent.assertEqual(len(outputs), 4)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_decoder_model_past(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_values = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder()
model.to(torch_device)
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_generate_with_past_key_values(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
torch.manual_seed(0)
output_without_past_cache = model.generate(
input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
)
torch.manual_seed(0)
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_model_fp16_forward(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[
"encoder_last_hidden_state"
]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_encoder_decoder_shared_weights(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
for model_class in [Pop2PianoForConditionalGeneration]:
torch.manual_seed(0)
model = model_class(config=config).to(torch_device).eval()
# load state dict copies weights but does not tie them
model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
torch.manual_seed(0)
tied_config = copy.deepcopy(config)
tied_config.tie_encoder_decoder = True
tied_model = model_class(config=tied_config).to(torch_device).eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that models has less parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = model_class.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that models has less parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx],
tied_model_result[0][0, :, random_slice_idx],
atol=1e-4,
)
)
def check_resize_embeddings_pop2piano_v1_1(
self,
config,
):
prev_vocab_size = config.vocab_size
config.tie_word_embeddings = False
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
model.resize_token_embeddings(prev_vocab_size - 10)
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"use_cache": False,
}
return config, inputs_dict
@require_torch
class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Pop2PianoForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = ()
all_parallelizable_model_classes = ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
test_model_parallel = False
is_encoder_decoder = True
def setUp(self):
self.model_tester = Pop2PianoModelTester(self)
self.config_tester = ConfigTester(self, config_class=Pop2PianoConfig, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_shift_right(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_config_and_model_silu_gated(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.feed_forward_proj = "gated-silu"
self.model_tester.create_and_check_model(*config_and_inputs)
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_decoder_model_past_with_3d_attn_mask(self):
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = self.model_tester.prepare_config_and_inputs()
attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
vocab_size=2,
)
decoder_attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
vocab_size=2,
)
self.model_tester.create_and_check_decoder_model_attention_mask_past(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_encoder_decoder_shared_weights(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_v1_1_resize_embeddings(self):
config = self.model_tester.prepare_config_and_inputs()[0]
self.model_tester.check_resize_embeddings_pop2piano_v1_1(config)
@slow
def test_model_from_pretrained(self):
for model_name in POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_onnx
@unittest.skipIf(
is_torch_1_8_0,
reason="Test has a segmentation fault on torch 1.8.0",
)
def test_export_to_onnx(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = Pop2PianoForConditionalGeneration(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model,
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/Pop2Piano_test.onnx",
export_params=True,
opset_version=9,
input_names=["input_ids", "decoder_input_ids"],
)
def test_pass_with_input_features(self):
input_features = BatchFeature(
{
"input_features": torch.rand((75, 100, 512)).type(torch.float32),
"beatsteps": torch.randint(size=(1, 955), low=0, high=100).type(torch.float32),
"extrapolated_beatstep": torch.randint(size=(1, 900), low=0, high=100).type(torch.float32),
}
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
model_opts = model.generate(input_features=input_features["input_features"], return_dict_in_generate=True)
self.assertEqual(model_opts.sequences.ndim, 2)
def test_pass_with_batched_input_features(self):
input_features = BatchFeature(
{
"input_features": torch.rand((220, 70, 512)).type(torch.float32),
"beatsteps": torch.randint(size=(5, 955), low=0, high=100).type(torch.float32),
"extrapolated_beatstep": torch.randint(size=(5, 900), low=0, high=100).type(torch.float32),
"attention_mask": torch.concatenate(
[
torch.ones([120, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
torch.ones([50, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
torch.ones([47, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
],
axis=0,
),
"attention_mask_beatsteps": torch.ones((5, 955)).type(torch.int32),
"attention_mask_extrapolated_beatstep": torch.ones((5, 900)).type(torch.int32),
}
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
model_opts = model.generate(
input_features=input_features["input_features"],
attention_mask=input_features["attention_mask"],
return_dict_in_generate=True,
)
self.assertEqual(model_opts.sequences.ndim, 2)
@require_torch
class Pop2PianoModelIntegrationTests(unittest.TestCase):
@slow
def test_mel_conditioner_integration(self):
composer = "composer1"
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
input_embeds = torch.ones([10, 100, 512])
composer_value = model.generation_config.composer_to_feature_token[composer]
composer_value = torch.tensor(composer_value)
composer_value = composer_value.repeat(input_embeds.size(0))
outputs = model.mel_conditioner(
input_embeds, composer_value, min(model.generation_config.composer_to_feature_token.values())
)
# check shape
self.assertEqual(outputs.size(), torch.Size([10, 101, 512]))
# check values
EXPECTED_OUTPUTS = torch.tensor(
[[1.0475305318832397, 0.29052114486694336, -0.47778210043907166], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
)
self.assertTrue(torch.allclose(outputs[0, :3, :3], EXPECTED_OUTPUTS, atol=1e-4))
@slow
@require_essentia
@require_librosa
@require_scipy
def test_full_model_integration(self):
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
from transformers import Pop2PianoProcessor
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
sampling_rate = 44_100
processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
input_features = processor.feature_extractor(
speech_input1, sampling_rate=sampling_rate, return_tensors="pt"
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
outputs = model.generate(
input_features=input_features["input_features"], return_dict_in_generate=True
).sequences
# check for shapes
self.assertEqual(outputs.size(0), 70)
# check for values
self.assertEqual(outputs[0, :2].detach().cpu().numpy().tolist(), [0, 1])
# This is the test for a real music from K-Pop genre.
@slow
@require_essentia
@require_librosa
@require_scipy
def test_real_music(self):
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
model.eval()
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
output_fe = feature_extractor(
ds["audio"][0]["array"], sampling_rate=ds["audio"][0]["sampling_rate"], return_tensors="pt"
)
output_model = model.generate(input_features=output_fe["input_features"], composer="composer1")
output_tokenizer = tokenizer.batch_decode(token_ids=output_model, feature_extractor_output=output_fe)
pretty_midi_object = output_tokenizer["pretty_midi_objects"][0]
# Checking if no of notes are same
self.assertEqual(len(pretty_midi_object.instruments[0].notes), 59)
predicted_timings = []
for i in pretty_midi_object.instruments[0].notes:
predicted_timings.append(i.start)
# Checking note start timings(first 6)
EXPECTED_START_TIMINGS = [
0.4876190423965454,
0.7314285635948181,
0.9752380847930908,
1.4396371841430664,
1.6718367338180542,
1.904036283493042,
]
np.allclose(EXPECTED_START_TIMINGS, predicted_timings[:6])
# Checking note end timings(last 6)
EXPECTED_END_TIMINGS = [
12.341403007507324,
12.567797183990479,
12.567797183990479,
12.567797183990479,
12.794191360473633,
12.794191360473633,
]
np.allclose(EXPECTED_END_TIMINGS, predicted_timings[-6:])

View File

@@ -0,0 +1,266 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
import pytest
from datasets import load_dataset
from transformers.testing_utils import (
require_essentia,
require_librosa,
require_pretty_midi,
require_scipy,
require_torch,
)
from transformers.tokenization_utils import BatchEncoding
from transformers.utils.import_utils import (
is_essentia_available,
is_librosa_available,
is_pretty_midi_available,
is_scipy_available,
is_torch_available,
)
requirements_available = (
is_torch_available()
and is_essentia_available()
and is_scipy_available()
and is_librosa_available()
and is_pretty_midi_available()
)
if requirements_available:
import pretty_midi
from transformers import (
Pop2PianoFeatureExtractor,
Pop2PianoForConditionalGeneration,
Pop2PianoProcessor,
Pop2PianoTokenizer,
)
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
@require_scipy
@require_torch
@require_librosa
@require_essentia
@require_pretty_midi
class Pop2PianoProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
processor = Pop2PianoProcessor(feature_extractor, tokenizer)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return Pop2PianoTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_feature_extractor(self, **kwargs):
return Pop2PianoFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_additional_features(self):
processor = Pop2PianoProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(
unk_token="-1",
eos_token="1",
pad_token="0",
bos_token="2",
)
feature_extractor_add_kwargs = self.get_feature_extractor()
processor = Pop2PianoProcessor.from_pretrained(
self.tmpdirname,
unk_token="-1",
eos_token="1",
pad_token="0",
bos_token="2",
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, Pop2PianoTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, Pop2PianoFeatureExtractor)
def get_inputs(self):
"""get inputs for both feature extractor and tokenizer"""
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select([0])["audio"]
input_speech = [x["array"] for x in speech_samples][0]
sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
feature_extractor_outputs = self.get_feature_extractor()(
audio=input_speech, sampling_rate=sampling_rate, return_tensors="pt"
)
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
token_ids = model.generate(input_features=feature_extractor_outputs["input_features"], composer="composer1")
dummy_notes = [
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
],
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
],
]
return input_speech, sampling_rate, token_ids, dummy_notes
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
input_speech, sampling_rate, _, _ = self.get_inputs()
feature_extractor_outputs = feature_extractor(
audio=input_speech, sampling_rate=sampling_rate, return_tensors="np"
)
processor_outputs = processor(audio=input_speech, sampling_rate=sampling_rate, return_tensors="np")
for key in feature_extractor_outputs.keys():
self.assertTrue(np.allclose(feature_extractor_outputs[key], processor_outputs[key], atol=1e-4))
def test_processor_batch_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, token_ids, _ = self.get_inputs()
feature_extractor_output = feature_extractor(audio=audio, sampling_rate=sampling_rate, return_tensors="pt")
encoded_processor = processor.batch_decode(
token_ids=token_ids,
feature_extractor_output=feature_extractor_output,
return_midi=True,
)
encoded_tokenizer = tokenizer.batch_decode(
token_ids=token_ids,
feature_extractor_output=feature_extractor_output,
return_midi=True,
)
# check start timings
encoded_processor_start_timings = [token.start for token in encoded_processor["notes"]]
encoded_tokenizer_start_timings = [token.start for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_start_timings, encoded_tokenizer_start_timings)
# check end timings
encoded_processor_end_timings = [token.end for token in encoded_processor["notes"]]
encoded_tokenizer_end_timings = [token.end for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_end_timings, encoded_tokenizer_end_timings)
# check pitch
encoded_processor_pitch = [token.pitch for token in encoded_processor["notes"]]
encoded_tokenizer_pitch = [token.pitch for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_pitch, encoded_tokenizer_pitch)
# check velocity
encoded_processor_velocity = [token.velocity for token in encoded_processor["notes"]]
encoded_tokenizer_velocity = [token.velocity for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_velocity, encoded_tokenizer_velocity)
def test_tokenizer_call(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
_, _, _, notes = self.get_inputs()
encoded_processor = processor(
notes=notes,
)
self.assertTrue(isinstance(encoded_processor, BatchEncoding))
def test_processor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, _, notes = self.get_inputs()
inputs = processor(
audio=audio,
sampling_rate=sampling_rate,
notes=notes,
)
self.assertListEqual(
list(inputs.keys()),
["input_features", "beatsteps", "extrapolated_beatstep", "token_ids"],
)
# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, _, notes = self.get_inputs()
feature_extractor(audio, sampling_rate, return_tensors="pt")
inputs = processor(
audio=audio,
sampling_rate=sampling_rate,
notes=notes,
)
self.assertListEqual(
list(inputs.keys()),
["input_features", "beatsteps", "extrapolated_beatstep", "token_ids"],
)

View File

@@ -0,0 +1,418 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Please note that Pop2PianoTokenizer is too far from our usual tokenizers and thus cannot use the TokenizerTesterMixin class.
"""
import os
import pickle
import shutil
import tempfile
import unittest
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import (
is_pretty_midi_available,
is_torch_available,
require_pretty_midi,
require_torch,
)
from transformers.tokenization_utils import BatchEncoding
if is_torch_available():
import torch
requirements_available = is_torch_available() and is_pretty_midi_available()
if requirements_available:
import pretty_midi
from transformers import Pop2PianoTokenizer
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
@require_torch
@require_pretty_midi
class Pop2PianoTokenizerTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
def get_input_notes(self):
notes = [
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
],
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
],
]
return notes
def test_call(self):
notes = self.get_input_notes()
output = self.tokenizer(
notes,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=10,
return_attention_mask=True,
)
# check the output type
self.assertTrue(isinstance(output, BatchEncoding))
# check the values
expected_output_token_ids = torch.tensor(
[[134, 133, 74, 135, 77, 132, 77, 133, 77, 82], [134, 133, 74, 136, 132, 74, 134, 134, 134, 134]]
)
expected_output_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
self.assertTrue(torch.allclose(output["token_ids"], expected_output_token_ids, atol=1e-4))
self.assertTrue(torch.allclose(output["attention_mask"], expected_output_attention_mask, atol=1e-4))
def test_batch_decode(self):
# test batch decode with model, feature-extractor outputs(beatsteps, extrapolated_beatstep)
# Please note that this test does not test the accuracy of the outputs, instead it is designed to make sure that
# the tokenizer's batch_decode can deal with attention_mask in feature-extractor outputs. For the accuracy check
# please see the `test_batch_decode_outputs` test.
model_output = torch.concatenate(
[
torch.randint(size=[120, 96], low=0, high=70, dtype=torch.long),
torch.zeros(size=[1, 96], dtype=torch.long),
torch.randint(size=[50, 96], low=0, high=40, dtype=torch.long),
torch.zeros(size=[1, 96], dtype=torch.long),
],
axis=0,
)
input_features = BatchFeature(
{
"beatsteps": torch.ones([2, 955]),
"extrapolated_beatstep": torch.ones([2, 1000]),
"attention_mask": torch.concatenate(
[
torch.ones([120, 96], dtype=torch.long),
torch.zeros([1, 96], dtype=torch.long),
torch.ones([50, 96], dtype=torch.long),
torch.zeros([1, 96], dtype=torch.long),
],
axis=0,
),
"attention_mask_beatsteps": torch.ones([2, 955]),
"attention_mask_extrapolated_beatstep": torch.ones([2, 1000]),
}
)
output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)[
"pretty_midi_objects"
]
# check length
self.assertTrue(len(output) == 2)
# check object type
self.assertTrue(isinstance(output[0], pretty_midi.pretty_midi.PrettyMIDI))
self.assertTrue(isinstance(output[1], pretty_midi.pretty_midi.PrettyMIDI))
def test_batch_decode_outputs(self):
# test batch decode with model, feature-extractor outputs(beatsteps, extrapolated_beatstep)
# Please note that this test tests the accuracy of the outputs of the tokenizer's `batch_decode` method.
model_output = torch.tensor(
[
[134, 133, 74, 135, 77, 82, 84, 136, 132, 74, 77, 82, 84],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
]
)
input_features = BatchEncoding(
{
"beatsteps": torch.tensor([[0.0697, 0.1103, 0.1509, 0.1916]]),
"extrapolated_beatstep": torch.tensor([[0.0000, 0.0406, 0.0813, 0.1219]]),
}
)
output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)
# check outputs
self.assertEqual(len(output["notes"]), 4)
predicted_start_timings, predicted_end_timings = [], []
for i in output["notes"]:
predicted_start_timings.append(i.start)
predicted_end_timings.append(i.end)
# Checking note start timings
expected_start_timings = torch.tensor(
[
0.069700,
0.110300,
0.110300,
0.110300,
]
)
predicted_start_timings = torch.tensor(predicted_start_timings)
self.assertTrue(torch.allclose(expected_start_timings, predicted_start_timings, atol=1e-4))
# Checking note end timings
expected_end_timings = torch.tensor(
[
0.191600,
0.191600,
0.191600,
0.191600,
]
)
predicted_end_timings = torch.tensor(predicted_end_timings)
self.assertTrue(torch.allclose(expected_end_timings, predicted_end_timings, atol=1e-4))
def test_get_vocab(self):
vocab_dict = self.tokenizer.get_vocab()
self.assertIsInstance(vocab_dict, dict)
self.assertGreaterEqual(len(self.tokenizer), len(vocab_dict))
vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
self.assertEqual(len(vocab), len(self.tokenizer))
self.tokenizer.add_tokens(["asdfasdfasdfasdf"])
vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
self.assertEqual(len(vocab), len(self.tokenizer))
def test_save_and_load_tokenizer(self):
tmpdirname = tempfile.mkdtemp()
sample_notes = self.get_input_notes()
self.tokenizer.add_tokens(["bim", "bambam"])
additional_special_tokens = self.tokenizer.additional_special_tokens
additional_special_tokens.append("new_additional_special_token")
self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
before_token_ids = self.tokenizer(sample_notes)["token_ids"]
before_vocab = self.tokenizer.get_vocab()
self.tokenizer.save_pretrained(tmpdirname)
after_tokenizer = self.tokenizer.__class__.from_pretrained(tmpdirname)
after_token_ids = after_tokenizer(sample_notes)["token_ids"]
after_vocab = after_tokenizer.get_vocab()
self.assertDictEqual(before_vocab, after_vocab)
self.assertListEqual(before_token_ids, after_token_ids)
self.assertIn("bim", after_vocab)
self.assertIn("bambam", after_vocab)
self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
shutil.rmtree(tmpdirname)
def test_pickle_tokenizer(self):
tmpdirname = tempfile.mkdtemp()
notes = self.get_input_notes()
subwords = self.tokenizer(notes)["token_ids"]
filename = os.path.join(tmpdirname, "tokenizer.bin")
with open(filename, "wb") as handle:
pickle.dump(self.tokenizer, handle)
with open(filename, "rb") as handle:
tokenizer_new = pickle.load(handle)
subwords_loaded = tokenizer_new(notes)["token_ids"]
self.assertListEqual(subwords, subwords_loaded)
def test_padding_side_in_kwargs(self):
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="left")
self.assertEqual(tokenizer_p.padding_side, "left")
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="right")
self.assertEqual(tokenizer_p.padding_side, "right")
self.assertRaises(
ValueError,
Pop2PianoTokenizer.from_pretrained,
"susnato/pop2piano_dev",
padding_side="unauthorized",
)
def test_truncation_side_in_kwargs(self):
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="left")
self.assertEqual(tokenizer_p.truncation_side, "left")
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="right")
self.assertEqual(tokenizer_p.truncation_side, "right")
self.assertRaises(
ValueError,
Pop2PianoTokenizer.from_pretrained,
"susnato/pop2piano_dev",
truncation_side="unauthorized",
)
def test_right_and_left_padding(self):
tokenizer = self.tokenizer
notes = self.get_input_notes()
notes = notes[0]
max_length = 20
padding_idx = tokenizer.pad_token_id
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer.padding_side = "right"
padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
padded_notes_length = len(padded_notes)
notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
padding_size = max_length - len(notes_without_padding)
self.assertEqual(padded_notes_length, max_length)
self.assertEqual(notes_without_padding + [padding_idx] * padding_size, padded_notes)
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer.padding_side = "left"
padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
padded_notes_length = len(padded_notes)
notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
padding_size = max_length - len(notes_without_padding)
self.assertEqual(padded_notes_length, max_length)
self.assertEqual([padding_idx] * padding_size + notes_without_padding, padded_notes)
# RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding'
notes_without_padding = tokenizer(notes)["token_ids"]
tokenizer.padding_side = "right"
padded_notes_right = tokenizer(notes, padding=False)["token_ids"]
self.assertEqual(len(padded_notes_right), len(notes_without_padding))
self.assertEqual(padded_notes_right, notes_without_padding)
tokenizer.padding_side = "left"
padded_notes_left = tokenizer(notes, padding="longest")["token_ids"]
self.assertEqual(len(padded_notes_left), len(notes_without_padding))
self.assertEqual(padded_notes_left, notes_without_padding)
tokenizer.padding_side = "right"
padded_notes_right = tokenizer(notes, padding="longest")["token_ids"]
self.assertEqual(len(padded_notes_right), len(notes_without_padding))
self.assertEqual(padded_notes_right, notes_without_padding)
tokenizer.padding_side = "left"
padded_notes_left = tokenizer(notes, padding=False)["token_ids"]
self.assertEqual(len(padded_notes_left), len(notes_without_padding))
self.assertEqual(padded_notes_left, notes_without_padding)
def test_right_and_left_truncation(self):
tokenizer = self.tokenizer
notes = self.get_input_notes()
notes = notes[0]
truncation_size = 3
# RIGHT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the truncation flag set to True
tokenizer.truncation_side = "right"
full_encoded_notes = tokenizer(notes)["token_ids"]
full_encoded_notes_length = len(full_encoded_notes)
truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
"token_ids"
]
self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
self.assertEqual(full_encoded_notes[:-truncation_size], truncated_notes)
# LEFT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the truncation flag set to True
tokenizer.truncation_side = "left"
full_encoded_notes = tokenizer(notes)["token_ids"]
full_encoded_notes_length = len(full_encoded_notes)
truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
"token_ids"
]
self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
self.assertEqual(full_encoded_notes[truncation_size:], truncated_notes)
# RIGHT & LEFT TRUNCATION - Check that nothing is done for 'longest' and 'no_truncation'
tokenizer.truncation_side = "right"
truncated_notes_right = tokenizer(notes, truncation=True)["token_ids"]
self.assertEqual(full_encoded_notes_length, len(truncated_notes_right))
self.assertEqual(full_encoded_notes, truncated_notes_right)
tokenizer.truncation_side = "left"
truncated_notes_left = tokenizer(notes, truncation="longest_first")["token_ids"]
self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
self.assertEqual(truncated_notes_left, full_encoded_notes)
tokenizer.truncation_side = "right"
truncated_notes_right = tokenizer(notes, truncation="longest_first")["token_ids"]
self.assertEqual(len(truncated_notes_right), full_encoded_notes_length)
self.assertEqual(truncated_notes_right, full_encoded_notes)
tokenizer.truncation_side = "left"
truncated_notes_left = tokenizer(notes, truncation=True)["token_ids"]
self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
self.assertEqual(truncated_notes_left, full_encoded_notes)
def test_padding_to_multiple_of(self):
notes = self.get_input_notes()
if self.tokenizer.pad_token is None:
self.skipTest("No padding token.")
else:
normal_tokens = self.tokenizer(notes[0], padding=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
normal_tokens = self.tokenizer(notes[0], pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
# Should also work with truncation
normal_tokens = self.tokenizer(notes[0], padding=True, truncation=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
# truncation to something which is not a multiple of pad_to_multiple_of raises an error
self.assertRaises(
ValueError,
self.tokenizer.__call__,
notes[0],
padding=True,
truncation=True,
max_length=12,
pad_to_multiple_of=8,
)
def test_padding_with_attention_mask(self):
if self.tokenizer.pad_token is None:
self.skipTest("No padding token.")
if "attention_mask" not in self.tokenizer.model_input_names:
self.skipTest("This model does not use attention mask.")
features = [
{"token_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 0]},
{"token_ids": [1, 2, 3], "attention_mask": [1, 1, 0]},
]
padded_features = self.tokenizer.pad(features)
if self.tokenizer.padding_side == "right":
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0]])
else:
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0]])