update float()

This commit is contained in:
thomwolf
2018-11-04 21:25:36 +01:00
parent c6207d85b6
commit d9d7d1a462

View File

@@ -1015,7 +1015,7 @@
" print(input_mask)\n", " print(input_mask)\n",
" print(example_indices)\n", " print(example_indices)\n",
" input_ids = input_ids.to(device)\n", " input_ids = input_ids.to(device)\n",
" input_mask = input_mask.float().to(device)\n", " input_mask = input_mask.to(device)\n",
"\n", "\n",
" all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n", " all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n",
"\n", "\n",