improve device usage
This commit is contained in:
committed by
Julien Chaumond
parent
c0707a85d2
commit
2a64107e44
@@ -29,7 +29,7 @@ And move all the stories to the same folder. We will refer as `$DATA_PATH` the p
|
|||||||
python run_summarization.py \
|
python run_summarization.py \
|
||||||
--documents_dir $DATA_PATH \
|
--documents_dir $DATA_PATH \
|
||||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||||
--visible_gpus 0,1,2 \
|
--to_cpu false \
|
||||||
--batch_size 4 \
|
--batch_size 4 \
|
||||||
--min_length 50 \
|
--min_length 50 \
|
||||||
--max_length 200 \
|
--max_length 200 \
|
||||||
@@ -39,7 +39,7 @@ python run_summarization.py \
|
|||||||
--compute_rouge true
|
--compute_rouge true
|
||||||
```
|
```
|
||||||
|
|
||||||
The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file.
|
The script executes on GPU if one is available and if `to_cpu` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
|
||||||
|
|
||||||
## Summarize any text
|
## Summarize any text
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ Put the documents that you would like to summarize in a folder (the path to whic
|
|||||||
python run_summarization.py \
|
python run_summarization.py \
|
||||||
--documents_dir $DATA_PATH \
|
--documents_dir $DATA_PATH \
|
||||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||||
--visible_gpus 0,1,2 \
|
--to_cpu false \
|
||||||
--batch_size 4 \
|
--batch_size 4 \
|
||||||
--min_length 50 \
|
--min_length 50 \
|
||||||
--max_length 200 \
|
--max_length 200 \
|
||||||
@@ -58,4 +58,4 @@ python run_summarization.py \
|
|||||||
--block_trigram true \
|
--block_trigram true \
|
||||||
```
|
```
|
||||||
|
|
||||||
If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py`
|
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
|
||||||
|
|||||||
@@ -12,10 +12,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Convert BertExtAbs's checkpoints
|
""" Convert BertExtAbs's checkpoints.
|
||||||
|
|
||||||
The file currently does not do much as we ended up copying the exact model
|
The script looks like it is doing something trivial but it is not. The "weights"
|
||||||
structure, but I leave it here in case we ever want to refactor the model.
|
proposed by the authors are actually the entire model pickled. We need to load
|
||||||
|
the model within the original codebase to be able to only save its `state_dict`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
@@ -847,14 +847,12 @@ class Translator(object):
|
|||||||
global_scores (:obj:`GlobalScorer`):
|
global_scores (:obj:`GlobalScorer`):
|
||||||
object to rescore final translations
|
object to rescore final translations
|
||||||
copy_attn (bool): use copy attention during translation
|
copy_attn (bool): use copy attention during translation
|
||||||
cuda (bool): use cuda
|
|
||||||
beam_trace (bool): trace beam search for debugging
|
beam_trace (bool): trace beam search for debugging
|
||||||
logger(logging.Logger): logger.
|
logger(logging.Logger): logger.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
|
def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cuda = args.visible_gpus != "-1"
|
|
||||||
|
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ def save_summaries(summaries, path, original_document_name):
|
|||||||
def build_data_iterator(args, tokenizer):
|
def build_data_iterator(args, tokenizer):
|
||||||
dataset = load_and_cache_examples(args, tokenizer)
|
dataset = load_and_cache_examples(args, tokenizer)
|
||||||
sampler = SequentialSampler(dataset)
|
sampler = SequentialSampler(dataset)
|
||||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
|
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||||
iterator = DataLoader(
|
iterator = DataLoader(
|
||||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
@@ -198,7 +198,7 @@ def load_and_cache_examples(args, tokenizer):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def collate(data, tokenizer, block_size):
|
def collate(data, tokenizer, block_size, device):
|
||||||
""" Collate formats the data passed to the data loader.
|
""" Collate formats the data passed to the data loader.
|
||||||
|
|
||||||
In particular we tokenize the data batch after batch to avoid keeping them
|
In particular we tokenize the data batch after batch to avoid keeping them
|
||||||
@@ -224,9 +224,9 @@ def collate(data, tokenizer, block_size):
|
|||||||
batch = Batch(
|
batch = Batch(
|
||||||
document_names=names,
|
document_names=names,
|
||||||
batch_size=len(encoded_stories),
|
batch_size=len(encoded_stories),
|
||||||
src=encoded_stories,
|
src=encoded_stories.to(device),
|
||||||
segs=encoder_token_type_ids,
|
segs=encoder_token_type_ids.to(device),
|
||||||
mask_src=encoder_mask,
|
mask_src=encoder_mask.to(device),
|
||||||
tgt_str=summaries,
|
tgt_str=summaries,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -271,10 +271,10 @@ def main():
|
|||||||
)
|
)
|
||||||
# EVALUATION options
|
# EVALUATION options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--visible_gpus",
|
"--to_cpu",
|
||||||
default=-1,
|
default=False,
|
||||||
type=int,
|
type=bool,
|
||||||
help="Number of GPUs with which to do the training.",
|
help="Whether to force the execution on CPU.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||||
@@ -311,8 +311,11 @@ def main():
|
|||||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
|
|
||||||
|
|
||||||
|
# Select device (distributed not available)
|
||||||
|
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.to_cpu else "cpu")
|
||||||
|
|
||||||
|
# Check the existence of directories
|
||||||
if not args.summaries_output_dir:
|
if not args.summaries_output_dir:
|
||||||
args.summaries_output_dir = args.documents_dir
|
args.summaries_output_dir = args.documents_dir
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user