updating examples and doc

This commit is contained in:
thomwolf
2019-07-14 23:20:10 +02:00
parent c490f5ce87
commit 2397f958f9
16 changed files with 601 additions and 667 deletions

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
@@ -230,6 +230,9 @@ def evaluate(args, model, tokenizer, prefix=""):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if args.local_rank in [-1, 0]:
tb_writer.close()
return results
@@ -242,7 +245,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
list(filter(None, args.model_name.split('/'))).pop(),
str(args.max_seq_length),
str(task)))
if os.path.exists(cached_features_file) and not args.overwrite_cache:
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
@@ -410,7 +413,7 @@ def main():
if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
# Distributed and parrallel training
# Distributed and parallel training
model.to(args.device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],