Python 2 must DIE
This commit is contained in:
@@ -58,7 +58,7 @@ class RobertaEmbeddings(BertEmbeddings):
|
|||||||
# cf. fairseq's `utils.make_positions`
|
# cf. fairseq's `utils.make_positions`
|
||||||
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
|
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
|
||||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
return super().forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
|
return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
|
||||||
|
|
||||||
|
|
||||||
class RobertaConfig(BertConfig):
|
class RobertaConfig(BertConfig):
|
||||||
@@ -109,8 +109,8 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
|||||||
class RobertaLMHead(nn.Module):
|
class RobertaLMHead(nn.Module):
|
||||||
"""Roberta Head for masked language modeling."""
|
"""Roberta Head for masked language modeling."""
|
||||||
|
|
||||||
def __init__(self, config: BertConfig):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super(RobertaLMHead, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from __future__ import (absolute_import, division, print_function,
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
import six
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
|
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
|
||||||
|
|
||||||
@@ -31,10 +32,11 @@ class RobertaTokenizationTest(unittest.TestCase):
|
|||||||
tokenizer.encode('Hello world!'),
|
tokenizer.encode('Hello world!'),
|
||||||
[0, 31414, 232, 328, 2]
|
[0, 31414, 232, 328, 2]
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
if six.PY3:
|
||||||
tokenizer.encode('Hello world! cécé herlolip'),
|
self.assertListEqual(
|
||||||
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
|
tokenizer.encode('Hello world! cécé herlolip'),
|
||||||
)
|
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function,
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from io import open
|
||||||
|
import six
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
from .tokenization_gpt2 import GPT2Tokenizer
|
||||||
@@ -125,7 +127,7 @@ class Dictionary(object):
|
|||||||
Loads a pre-existing dictionary from a text file and adds its symbols
|
Loads a pre-existing dictionary from a text file and adds its symbols
|
||||||
to this instance.
|
to this instance.
|
||||||
"""
|
"""
|
||||||
if isinstance(f, str):
|
if isinstance(f, six.string_types):
|
||||||
try:
|
try:
|
||||||
if not ignore_utf_errors:
|
if not ignore_utf_errors:
|
||||||
with open(f, 'r', encoding='utf-8') as fd:
|
with open(f, 'r', encoding='utf-8') as fd:
|
||||||
|
|||||||
Reference in New Issue
Block a user