updating examples and doc
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
|
||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
@@ -230,6 +230,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -242,7 +245,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
@@ -410,7 +413,7 @@ def main():
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
# Distributed and parrallel training
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
|
||||
Reference in New Issue
Block a user