@@ -30,7 +30,6 @@ if is_pretty_midi_available():
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
|
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
"vocab": "vocab.json",
|
"vocab": "vocab.json",
|
||||||
@@ -38,7 +37,7 @@ VOCAB_FILES_NAMES = {
|
|||||||
|
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
"vocab": {
|
"vocab": {
|
||||||
"susnato/pop2piano_dev": "https://huggingface.co/susnato/pop2piano_dev/blob/main/vocab.json",
|
"sweetcocoa/pop2piano": "https://huggingface.co/sweetcocoa/pop2piano/blob/main/vocab.json",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -734,10 +734,10 @@ class Pop2PianoModelIntegrationTests(unittest.TestCase):
|
|||||||
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
|
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
|
||||||
from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
|
from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
|
||||||
|
|
||||||
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
|
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
|
||||||
model.eval()
|
model.eval()
|
||||||
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
|
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
|
||||||
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
|
tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
|
||||||
ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
|
ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
|
||||||
|
|
||||||
output_fe = feature_extractor(
|
output_fe = feature_extractor(
|
||||||
|
|||||||
@@ -55,8 +55,6 @@ if requirements_available:
|
|||||||
Pop2PianoTokenizer,
|
Pop2PianoTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
|
|
||||||
|
|
||||||
|
|
||||||
@require_scipy
|
@require_scipy
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -67,8 +65,8 @@ class Pop2PianoProcessorTest(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|
||||||
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
|
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
|
||||||
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
|
tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
|
||||||
processor = Pop2PianoProcessor(feature_extractor, tokenizer)
|
processor = Pop2PianoProcessor(feature_extractor, tokenizer)
|
||||||
|
|
||||||
processor.save_pretrained(self.tmpdirname)
|
processor.save_pretrained(self.tmpdirname)
|
||||||
@@ -121,7 +119,7 @@ class Pop2PianoProcessorTest(unittest.TestCase):
|
|||||||
feature_extractor_outputs = self.get_feature_extractor()(
|
feature_extractor_outputs = self.get_feature_extractor()(
|
||||||
audio=input_speech, sampling_rate=sampling_rate, return_tensors="pt"
|
audio=input_speech, sampling_rate=sampling_rate, return_tensors="pt"
|
||||||
)
|
)
|
||||||
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
|
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
|
||||||
token_ids = model.generate(input_features=feature_extractor_outputs["input_features"], composer="composer1")
|
token_ids = model.generate(input_features=feature_extractor_outputs["input_features"], composer="composer1")
|
||||||
dummy_notes = [
|
dummy_notes = [
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -43,15 +43,12 @@ if requirements_available:
|
|||||||
from transformers import Pop2PianoTokenizer
|
from transformers import Pop2PianoTokenizer
|
||||||
|
|
||||||
|
|
||||||
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_pretty_midi
|
@require_pretty_midi
|
||||||
class Pop2PianoTokenizerTest(unittest.TestCase):
|
class Pop2PianoTokenizerTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
|
self.tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
|
||||||
|
|
||||||
def get_input_notes(self):
|
def get_input_notes(self):
|
||||||
notes = [
|
notes = [
|
||||||
@@ -246,30 +243,30 @@ class Pop2PianoTokenizerTest(unittest.TestCase):
|
|||||||
self.assertListEqual(subwords, subwords_loaded)
|
self.assertListEqual(subwords, subwords_loaded)
|
||||||
|
|
||||||
def test_padding_side_in_kwargs(self):
|
def test_padding_side_in_kwargs(self):
|
||||||
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="left")
|
tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", padding_side="left")
|
||||||
self.assertEqual(tokenizer_p.padding_side, "left")
|
self.assertEqual(tokenizer_p.padding_side, "left")
|
||||||
|
|
||||||
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="right")
|
tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", padding_side="right")
|
||||||
self.assertEqual(tokenizer_p.padding_side, "right")
|
self.assertEqual(tokenizer_p.padding_side, "right")
|
||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError,
|
ValueError,
|
||||||
Pop2PianoTokenizer.from_pretrained,
|
Pop2PianoTokenizer.from_pretrained,
|
||||||
"susnato/pop2piano_dev",
|
"sweetcocoa/pop2piano",
|
||||||
padding_side="unauthorized",
|
padding_side="unauthorized",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_truncation_side_in_kwargs(self):
|
def test_truncation_side_in_kwargs(self):
|
||||||
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="left")
|
tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", truncation_side="left")
|
||||||
self.assertEqual(tokenizer_p.truncation_side, "left")
|
self.assertEqual(tokenizer_p.truncation_side, "left")
|
||||||
|
|
||||||
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="right")
|
tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", truncation_side="right")
|
||||||
self.assertEqual(tokenizer_p.truncation_side, "right")
|
self.assertEqual(tokenizer_p.truncation_side, "right")
|
||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError,
|
ValueError,
|
||||||
Pop2PianoTokenizer.from_pretrained,
|
Pop2PianoTokenizer.from_pretrained,
|
||||||
"susnato/pop2piano_dev",
|
"sweetcocoa/pop2piano",
|
||||||
truncation_side="unauthorized",
|
truncation_side="unauthorized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user