[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile (#5395)
* add first version of clm tf
* make style
* add more tests for bert
* update tf clm loss
* fix tests
* correct tf ner script
* add mlm loss
* delete bogus file
* clean tf auto model + add tests
* finish adding clm loss everywhere
* fix training in distilbert
* fix flake8
* save intermediate
* fix tf t5 naming
* remove prints
* finish up
* up
* fix tf gpt2
* fix new test utils import
* fix flake8
* keep backward compatibility
* Update src/transformers/modeling_tf_albert.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_auto.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_electra.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_roberta.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_mobilebert.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_auto.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_bert.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_distilbert.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* apply sylvains suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
committed by GitHub
parent 33e43edddc
commit 4dc65591b5
@@ -17,6 +17,7 @@
 
 import logging
 import os
+import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
 
@@ -184,7 +185,12 @@ def main():
 
         for i in range(batch_size):
             for j in range(seq_len):
-                if label_ids[i, j] != -1:
+                if label_ids[i, j] == -1:
+                    label_ids[i, j] = -100
+                    warnings.warn(
+                        "Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead."
+                    )
+                if label_ids[i, j] != -100:
                     out_label_list[i].append(label_map[label_ids[i][j]])
                     preds_list[i].append(label_map[preds[i][j]])
 
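The second hunk changes the mask value for ignored tokens from `-1` to `-100`, while still accepting `-1` with a deprecation warning. The following is a minimal standalone sketch of that alignment step, not the script's actual code. It assumes plain nested lists instead of NumPy arrays, and the names `IGNORE_INDEX`, `LEGACY_IGNORE_INDEX`, and the three-argument `align_predictions` signature are hypothetical, introduced only for illustration. Unlike the diff, it leaves `label_ids` unmodified and warns once per call rather than once per masked token.

```python
import warnings
from typing import Dict, List, Tuple

IGNORE_INDEX = -100        # mask value the diff switches to (hypothetical constant name)
LEGACY_IGNORE_INDEX = -1   # old mask value, still accepted with a warning


def align_predictions(
    preds: List[List[int]],
    label_ids: List[List[int]],
    label_map: Dict[int, str],
) -> Tuple[List[List[str]], List[List[str]]]:
    """Drop masked positions and map the remaining ids to label strings."""
    preds_list: List[List[str]] = [[] for _ in preds]
    out_label_list: List[List[str]] = [[] for _ in label_ids]
    saw_legacy_mask = False

    for i, (pred_row, label_row) in enumerate(zip(preds, label_ids)):
        for pred_id, label_id in zip(pred_row, label_row):
            if label_id == LEGACY_IGNORE_INDEX:
                # Treat the legacy mask like the new one; the input is not mutated.
                saw_legacy_mask = True
                label_id = IGNORE_INDEX
            if label_id != IGNORE_INDEX:
                out_label_list[i].append(label_map[label_id])
                preds_list[i].append(label_map[pred_id])

    if saw_legacy_mask:
        # Warn once per call instead of once per masked token.
        warnings.warn(
            "Using `-1` to mask the loss for the token is deprecated. "
            "Please use `-100` instead.",
            DeprecationWarning,
        )
    return preds_list, out_label_list


# Toy data: the second sequence mixes the legacy (-1) and current (-100) masks.
label_map = {0: "O", 1: "B-PER", 2: "I-PER"}
preds = [[0, 1, 2, 0], [1, 0, 0, 0]]
label_ids = [[0, 1, 2, -100], [1, -1, 0, -100]]
preds_list, out_label_list = align_predictions(preds, label_ids, label_map)
```

With this toy input, positions labelled `-100` or `-1` are skipped in both returned lists, so only the unmasked tokens reach the evaluation metric.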