Python 2 must DIE
This commit is contained in:
@@ -58,7 +58,7 @@ class RobertaEmbeddings(BertEmbeddings):
|
||||
# cf. fairseq's `utils.make_positions`
|
||||
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
return super().forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
|
||||
return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
|
||||
|
||||
|
||||
class RobertaConfig(BertConfig):
|
||||
@@ -109,8 +109,8 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
||||
class RobertaLMHead(nn.Module):
|
||||
"""Roberta Head for masked language modeling."""
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
def __init__(self, config):
|
||||
super(RobertaLMHead, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import (absolute_import, division, print_function,
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
import six
|
||||
|
||||
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
@@ -31,6 +32,7 @@ class RobertaTokenizationTest(unittest.TestCase):
|
||||
tokenizer.encode('Hello world!'),
|
||||
[0, 31414, 232, 328, 2]
|
||||
)
|
||||
if six.PY3:
|
||||
self.assertListEqual(
|
||||
tokenizer.encode('Hello world! cécé herlolip'),
|
||||
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
|
||||
|
||||
@@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function,
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from io import open
|
||||
import six
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
@@ -125,7 +127,7 @@ class Dictionary(object):
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols
|
||||
to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
if isinstance(f, six.string_types):
|
||||
try:
|
||||
if not ignore_utf_errors:
|
||||
with open(f, 'r', encoding='utf-8') as fd:
|
||||
|
||||
Reference in New Issue
Block a user