Fix import error in script to convert faisreq roberta checkpoints

This commit is contained in:
Louis MARTIN
2019-10-14 01:38:57 -07:00
parent a701c9b321
commit 49cba6e543

View File

@@ -23,15 +23,15 @@ import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer from fairseq.modules import TransformerSentenceEncoderLayer
from transformers import (BertConfig, BertEncoder, from transformers.modeling_bert import (BertConfig, BertEncoder,
BertIntermediate, BertLayer, BertIntermediate, BertLayer,
BertModel, BertOutput, BertModel, BertOutput,
BertSelfAttention, BertSelfAttention,
BertSelfOutput) BertSelfOutput)
from transformers import (RobertaEmbeddings, from transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM, RobertaForMaskedLM,
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaModel) RobertaModel)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)