Add Information Gain Filtration algorithm (#16953)
* Add information gain filtration algorithm
* Complying with black requirements
* Added author
* Fixed import order
* flake8 corrections

Co-authored-by: Javier Turek <javier.turek@intel.com>
100
examples/research_projects/information-gain-filtration/README.md
Normal file
@@ -0,0 +1,100 @@
# Information Gain Filtration (IGF)

Authors @Tuko @mraunak

This folder contains code showing how to implement IGF for fine-tuning GPT-2.

## What is IGF?

Here we present a general fine-tuning method, which we call information gain filtration (IGF), for improving the overall training efficiency and final
performance of language model fine-tuning (see the paper below). The method is an alternative to standard fine-tuning: it trains
a secondary model (e.g., a simple convolutional network) to predict the amount of information
gained over a given pre-trained model. The secondary model is lightweight and trained to
predict the Information Gain measure. Information Gain is defined as the change in a loss
function for a model before and after an SGD update with a sample (Equation X in the paper).
A small subset of the training set, named the “objective” set, is used to measure information
gain on the pre-trained model, and consequently to train the secondary model. After
training, the secondary model is used to filter samples for the fine-tuning process:
a high information gain value suggests a sample is informative, whereas a low value
suggests a non-informative sample that should be filtered out. A thresholding
strategy is therefore defined to select informative samples. With such a strategy, samples are filtered,
and once enough samples have been selected to form a mini-batch, a usual fine-tuning/optimization
step is applied. The filtration process is repeated until the fine-tuning process is over.

Paper: [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)
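
To make the definition of Information Gain concrete, here is a minimal sketch for a single sample. It assumes a toy one-parameter linear model with a squared-error loss in place of GPT-2 and perplexity; the helper names (`mse`, `sgd_step`, `information_gain`) and the toy data are hypothetical and are not part of this folder's code. The sign convention (loss before the update minus loss after it) follows `igf.py`, so a positive value means the update helped on the objective set.

```python
def mse(w, data):
    """Mean squared error of the 1-D linear model y = w * x on `data`."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def sgd_step(w, sample, lr=0.05):
    """Return the weight after one SGD update on a single (x, y) sample."""
    x, y = sample
    grad = 2 * (w * x - y) * x  # derivative of (w*x - y)^2 with respect to w
    return w - lr * grad


def information_gain(w, sample, objective_set, lr=0.05):
    """Objective-set loss before the update minus the loss after it."""
    return mse(w, objective_set) - mse(sgd_step(w, sample, lr), objective_set)


# Toy objective set drawn from y = 2x; the "pre-trained" weight is 1.0.
objective_set = [(x, 2.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]
w0 = 1.0

helpful = (1.0, 2.0)   # agrees with the objective set
harmful = (1.0, -2.0)  # pulls the weight away from the objective set

ig_helpful = information_gain(w0, helpful, objective_set)
ig_harmful = information_gain(w0, harmful, objective_set)
print(ig_helpful > 0, ig_harmful < 0)
```

In this toy setting the sample that agrees with the objective set gets a positive gain and the conflicting sample gets a negative one, which is the signal the secondary model is trained to predict.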

## Results

Several experiments were conducted to show the robustness of the IGF method versus the
standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
Books dataset compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
implemented using the Transformers library and PyTorch. While the method may seem more
expensive, we saw enough evidence that it may lead to a performance benefit in the final models.
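
Perplexity, the metric reported above, is the exponential of the average negative log-likelihood per token. A minimal sketch, assuming the per-token losses are already available as a plain list (the function name `perplexity` is illustrative, not part of this folder's code):

```python
import math


def perplexity(token_losses):
    """Perplexity = exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_losses) / len(token_losses))


# A model that is uniformly unsure among 4 tokens has loss log(4) per token,
# which corresponds to a perplexity of about 4.
uniform_over_4 = [math.log(4)] * 10
ppl = perplexity(uniform_over_4)
print(ppl)
```

Lower perplexity means the model assigns higher probability to the held-out text.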

Figure 1: Comparing IGF to standard fine-tuning.
IGF with constant (p < 10⁻³, t-test) and shifting (p < 10⁻⁶, t-test) thresholding significantly outperforms standard fine-tuning. The left-hand figure shows
test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote ± one standard error). The right-hand figure shows the perplexity of each
method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam.
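
The constant and shifting thresholding strategies compared in Figure 1 can be sketched as a simple filter over scored samples. This is an illustration under stated assumptions, not the code in `run_clm_igf.py`: the function `filter_into_batches` and its parameters are hypothetical, the predicted gains are assumed to be standardized (so 1.0 means one standard deviation above average), and an incomplete final batch is simply dropped.

```python
def filter_into_batches(scored_samples, batch_size, threshold=1.0,
                        shifted_threshold=-1.0, shift_after=None):
    """Group samples whose predicted gain meets the threshold into mini-batches.

    scored_samples: iterable of (sample, predicted_gain) pairs.
    shift_after: number of completed batches after which the threshold
        switches to shifted_threshold (None keeps it constant).
    """
    batches, current = [], []
    for sample, predicted_gain in scored_samples:
        if shift_after is not None and len(batches) >= shift_after:
            threshold = shifted_threshold
        if predicted_gain < threshold:
            continue  # filtered out as uninformative
        current.append(sample)
        if len(current) == batch_size:
            batches.append(current)  # an optimizer step would happen here
            current = []
    return batches


scores = [1.5, 0.2, 1.1, -0.5, 0.3, -2.0, 0.0, 1.2]
stream = list(enumerate(scores))  # sample ids paired with predicted gains
constant = filter_into_batches(stream, batch_size=2)
shifting = filter_into_batches(stream, batch_size=2, shift_after=1)
print(constant)
print(shifting)
```

With these toy scores the constant threshold fills a single batch, while the shifting threshold becomes less selective after the first batch and fills three.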

## How to use this project?

To fine-tune a transformer model with IGF on a language modeling task, use the `run_clm_igf.py` script with the following options:

- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
- `data_file`: A jbl (joblib) file containing tokenized data which can be split into the objective dataset,
  train_dataset and test_dataset
- `igf_data_file`: A jbl (joblib) file containing the context and information gain pairs used to train the secondary learner
- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
  than this will be truncated, sequences shorter will be padded
- `size_objective_set`: Number of articles that are long enough to be used as our objective set
- `min_len`: The minimum length of an article to be used in the objective set
- `trim`: Truncate the example if it exceeds the context length
- `eval_freq`: Secondary model evaluation is triggered every `eval_freq` steps
- `max_steps`: To calculate training epochs
- `number`: The number of examples split off to be used as objective_set/test_data
- `secondary_learner_batch_size`: The batch size of training data for the secondary learner
- `secondary_learner_max_epochs`: The number of epochs to train the secondary learner
- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
- `eval_interval`: Decay the selectivity of our secondary learner filter from
  1 standard deviation above average to 1 below average after `eval_interval` (10) batches

```bash
python run_clm_igf.py \
--model_name_or_path "gpt2" \
--data_file="data/tokenized_stories_train_wikitext103" \
--igf_data_file="data/IGF_values" \
--context_len 32 \
--size_objective_set 100 \
--min_len 1026 \
--trim True \
--eval_freq 100 \
--max_steps 1000 \
--secondary_learner_batch_size 128 \
--secondary_learner_max_epochs 15 \
--number 100 \
--recopy_model \
--eval_interval 10
```
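
The way `size_objective_set`, `min_len`, `context_len` and `trim` interact can be sketched as follows. This is a simplified illustration that assumes each article is a plain list of token ids; the function name `split_objective_and_train` and the `seed` argument are hypothetical. It takes one random window of `context_len` tokens from each of the first `number` sufficiently long articles as the objective set and keeps the remaining long articles as training data.

```python
import random


def split_objective_and_train(articles, context_len, number, min_len, seed=0):
    """Split tokenized articles into an objective set and training data."""
    rng = random.Random(seed)
    objective_set, train_data = [], []
    last_used = -1
    for i, tokens in enumerate(articles):
        if len(tokens) > min_len:
            # Trim the article to one random window of context_len tokens.
            start = rng.randint(0, len(tokens) - context_len - 1)
            objective_set.append(tokens[start:start + context_len])
            last_used = i
            if len(objective_set) >= number:
                break
    # Every later article that is long enough becomes training data.
    for tokens in articles[last_used + 1:]:
        if len(tokens) > min_len:
            train_data.append(tokens)
    return train_data, objective_set


# Toy "articles": token ids 0..n-1 with different lengths.
articles = [list(range(n)) for n in (5, 50, 8, 60, 70, 80)]
train_data, objective_set = split_objective_and_train(
    articles, context_len=4, number=2, min_len=10
)
print(len(objective_set), [len(t) for t in train_data])
```

Short articles are skipped entirely, so `min_len` should be larger than `context_len` for the random window to be well defined.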

## Citation

If you find the resource useful, please cite the following paper:

```
@inproceedings{antonello-etal-2021-selecting,
    title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
    author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
    booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.acl-long.87",
    doi = "10.18653/v1/2021.acl-long.87",
    pages = "1072--1085",
}
```
@@ -0,0 +1,419 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

import copy
import logging
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import joblib
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup


logger = logging.getLogger(__name__)


def set_seed(seed):
    """
    For reproducible training

    Args:
        seed: A seed for reproducible training

    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def compute_perplexity(model, test_data, context_len):
    """
    Computes perplexity of the transformer model on data in test_data

    Args:
        model: Pre-trained GPT2 model
        test_data: Data on which perplexity calculation is required
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                     than this will be truncated, sequences shorter will be padded

    Returns:
        Perplexity on input test data

    """

    model.eval()
    device = next(model.parameters()).device
    eval_batch_size = 1
    context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
    eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
    eval_loss = torch.zeros(1, device=device)
    nb_eval_examples = 0
    for batch in eval_dataloader:
        # Tensor.to() is not in-place, so keep the returned tensor
        batch = batch.to(device)
        # pad
        context.zero_()
        for i in range(eval_batch_size):
            context[i, :] = batch[i]
        outputs = model(context, labels=context)
        eval_loss += outputs[0].sum().item()
        nb_eval_examples += batch.size(0)
    eval_loss = eval_loss / nb_eval_examples
    perplexity = torch.exp(eval_loss)
    model.train()
    return perplexity


def load_gpt2(model_name="gpt2"):
    """
    Load the original GPT-2 and save its weights locally for quicker loading

    Args:
        model_name: Name of the GPT-2 checkpoint to load

    Returns:
        GPT-2 model

    """

    model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
    torch.save(model.state_dict(), model_name + "local.pt")
    return model


def recopy_gpt2(orig_model, device, max_steps):
    """
    Reset the model to the original pretrained GPT-2 weights after each iteration

    Args:
        orig_model: Original pretrained GPT-2 model imported from Transformers library
        device: CPU/GPU
        max_steps: number of training steps

    Returns:
        Original PreTrained GPT-2 model,
        lm_optimizer: Adam optimizer with Decoupled weight decay
        lm_scheduler: linear scheduler with the appropriate schedule

    """
    model = copy.deepcopy(orig_model)
    model.to(device)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
    torch.cuda.empty_cache()
    return model, lm_optimizer, lm_scheduler


def intermittent_save(contexts, real_perps, past_perps, filename):

    """
    Save the perplexity differences to filename

    Args:
        contexts: Example on which the perplexity is calculated
        real_perps: Perplexity improvement (original minus updated) after back-propagating on the selected context
        past_perps: Perplexity of model before training on the context
        filename: File to store perplexity differences

    Returns:
        Nothing; the perplexity differences are written to filename

    """
    # standardize the differences (zero mean, unit variance) and save them to filename
    real_perps = np.array(real_perps)
    avg = real_perps.mean()
    std = real_perps.std()
    perp_diff = (real_perps - avg) / std
    data_final = list(zip(contexts, perp_diff, past_perps))
    joblib.dump(data_final, filename)


def collect_objective_set(
    model,
    orig_perp,
    context_len,
    train_data,
    objective_set,
    max_steps,
    device,
    filename="dev.jbl",
    recopy_model=recopy_gpt2,
):

    """
    Collect individual IGF values from the pre-trained transformer model on
    max_steps samples of training data, to train the secondary model

    Args:
        model: Pre-trained GPT2 model
        orig_perp: Perplexity of original pretrained GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                     than this will be truncated, sequences shorter will be padded
        train_data: Data to train model
        objective_set: Contexts used to create (X,IG(X)) pairs which is the training data for secondary learner
        max_steps: To calculate training epochs of model
        device: GPU/CPU
        filename: To store intermediate perplexity differences
        recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration

    Returns:
        Nothing; perplexity differences are stored in filename at intermediate stages

    """

    # initialize variables to record relevant information
    contexts = []
    real_perps = []
    past_perps = []

    # Initialize the transformer model
    orig_model = copy.deepcopy(model)
    orig_model.to(device="cpu")
    torch.cuda.empty_cache()

    # Start from a fresh copy of the pretrained model
    model.train()
    model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)

    for step in tqdm(range(max_steps)):
        context = torch.zeros((1, context_len), dtype=torch.long, device=device)
        story = random.choice(train_data)
        start = random.randint(0, len(story[0]) - context_len - 1)
        context[0, :] = story[0][start : start + context_len]
        lm_optimizer.zero_grad()
        outputs = model(context, labels=context)
        lm_loss = outputs[0]
        past_perp = compute_perplexity(model, context, context_len)
        model.train()
        lm_loss.backward()
        # Do LM backprop
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        lm_optimizer.step()
        lm_scheduler.step()  # Update learning rate schedule

        # Compute perplexity after back-propagating on the selected context
        real_perp = compute_perplexity(model, objective_set, context_len)

        # Periodically save the stored (X, IG(X)) pairs
        if step % 1000 == 0 and step > 1:
            intermittent_save(contexts, real_perps, past_perps, filename)

        # Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
        model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)

        past_perps.append(past_perp.item())
        real_perps.append(orig_perp - real_perp.item())
        contexts.append(np.array(context.cpu()))

    intermittent_save(contexts, real_perps, past_perps, filename)


def generate_datasets(
    context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
):
    """
    Generate objective set and training set

    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                     than this will be truncated, sequences shorter will be padded
        file: Tokenized data split into training set and objective set
        number: size of objective dataset
        min_len: minimum length of a context in objective set
        trim: If True truncate the context if it exceeds context length

    Returns:
        Generated objective set and training data

    """
    # Generate objective set and training set
    # Designate the first number (100) articles that are long enough to be used
    # as our objective set, rest (that are long enough) are training data for
    # secondary learner

    data = joblib.load(file)
    print("data loaded")
    objective_set = []
    if trim:
        for i, example in enumerate(data):
            if len(example[0]) > min_len:
                start = random.randint(0, len(example[0]) - context_len - 1)
                objective_set.append(example[0, start : start + context_len])
            if len(objective_set) >= number:
                break
        train_data = []
        for j in range(i + 1, len(data)):
            if len(data[j][0]) > min_len:
                train_data.append(data[j])
    else:
        objective_set = data[0:number]
        train_data = data[number:]

    joblib.dump(objective_set, "objective_set.jbl")
    print("objective set saved")
    return train_data, objective_set


def train_secondary_learner(
    secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
):

    """
    Train the secondary learner (igf_model)

    Args:
        secondary_learner: secondary learner
        train_dataset: data to train secondary learner
        max_epochs: number of epochs to train secondary learner
        batch_size: batch size of training data of secondary learner
        eval_freq: secondary model evaluation is triggered every eval_freq steps
        igf_model_path: path to store trained secondary learner

    Returns:
        Trained secondary learner

    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # We will use the first 512 pairs from our dataset as a test set for
    # our secondary learner and the rest to train
    test_dataset = train_dataset[:512]
    train_dataset = train_dataset[512:]
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

    # secondary learner model set up
    loss = nn.MSELoss()
    test_loss = nn.MSELoss(reduction="sum")
    secondary_learner.to(device)
    q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
    secondary_learner.train()

    # TODO in original code this is written as number of actual batches seen
    # not number of items seen but other places it is number of items instead.
    # improve consistency! changed this to epochs for clarity
    best_test_loss = float("inf")
    # Iterate over the training data for max_epochs epochs
    for epoch in range(int(max_epochs)):
        tr_q_loss = 0.0
        secondary_learner.train()
        for step, batch in enumerate(train_dataloader):
            context = batch[0].to(device)
            real_q = batch[1].to(device)
            predicted_q = secondary_learner(context)
            q_optimizer.zero_grad()
            q_loss = loss(predicted_q, real_q.float())
            q_loss.backward()
            q_optimizer.step()
            tr_q_loss += q_loss.item()

            # model trains fairly quickly so we won't wait for a full epoch
            # eval is triggered at eval_freq and end of epochs
            if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
                tr_loss = tr_q_loss / (step + 1)

                secondary_learner.eval()
                q_loss2 = 0.0
                sum_q2 = 0.0
                predicted = []
                actual = []
                # Compute performance of the secondary learner after this batch
                for step2, batch2 in enumerate(test_dataloader):
                    features2 = batch2[0].to(device)
                    real_q2 = batch2[1].to(device)
                    predicted_q2 = secondary_learner(features2)
                    q_loss2 += test_loss(predicted_q2, real_q2).item()
                    sum_q2 += torch.sum(predicted_q2).item()
                    for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
                        predicted.append(i.item())
                    for ei, i in enumerate(real_q2.cpu().detach().numpy()):
                        actual.append(i.item())

                q_loss2 /= len(test_dataset)
                print(
                    "Epoch: ",
                    epoch,
                    "step: ",
                    step,
                    "Avg. q:",
                    sum_q2 / len(test_dataset),
                    "Train Loss: ",
                    tr_loss,
                    "Test Loss: ",
                    q_loss2,
                )
                if q_loss2 < best_test_loss:
                    joblib.dump((predicted, actual), "pred_vs_actual.jbl")
                    torch.save(secondary_learner.state_dict(), igf_model_path)
                    best_test_loss = q_loss2

            secondary_learner.train()
    return secondary_learner


class SecondaryLearner(nn.Module):
    """
    Our secondary learner
    """

    def __init__(self, model):
        """
        We use a simple convolutional network as our secondary learner

        Args:
            model: Pre-trained GPT2 model
        """
        # embeddings are from the pretrained model
        super(SecondaryLearner, self).__init__()
        self.embeddings = model.transformer.wte
        self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
        self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
        self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))

    def forward(self, context):
        """
        Forward pass through the secondary learner

        Args:
            context: Context input to the secondary learner

        Returns:
            tensor after squeeze operation

        """
        pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
        qs = self.fc(pooled)
        return qs.squeeze(1)

    @classmethod
    def from_pretrained(cls, state_path, model):
        """
        Load the secondary learner

        Args:
            state_path: Path to the saved state of the secondary learner
            model: Pretrained GPT-2

        Returns:
            secondary learner
        """

        secondary_learner = cls(model)  # this calls __init__
        state_dict = torch.load(state_path)
        secondary_learner.load_state_dict(state_dict)
        secondary_learner.embeddings = model.transformer.wte
        secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
        return secondary_learner
@@ -0,0 +1,6 @@
matplotlib
numpy>=1.17.2
joblib>=0.13.2
scipy
torch>=1.10.1
transformers>=3.5
Binary file not shown.
|
After Width: | Height: | Size: 34 KiB |
@@ -0,0 +1,438 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration 'IGF' on the WikiText data set, with a comparison of the results
against the standard fine-tuning method

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)). IG(X)--Informativeness of context 'X'.
Our IG (information gain) model is learning to predict the ‘informativeness’ of a particular
context. Informativeness is the change in metric between the model’s accuracy on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.

2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).

3) The learner created in (2) is used to inform the fine-tuning process and filter out low informative samples.

Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering

"""

# Prerequisite libraries:

import argparse
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler

import joblib
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):

    """
    Collecting *n* pairs for training the secondary learner
    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                     than this will be truncated, sequences shorter will be padded
        max_steps: To calculate training epochs of secondary learner
        size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
        min_len: The minimum length of the article to be used as objective set
        trim: If True truncate the context if it exceeds context length
        data_file: Tokenized data set split for training and evaluation of model
        igf_data_file: file to store (X,IG(X)) paired data set to train secondary learner

    Returns:
        Data stored in igf_data_file

    """
    # generates same data every time
    set_seed(3)
    # generate train_data and objective_set (pass the caller's min_len and trim through)
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
    )
    # keeps model same across runs
    set_seed(4)
    # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
    # can we train on GPU?
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model
    model = load_gpt2("gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

    # collect igf pairs and save to igf_data_file
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # clean up, delete model and data we don't need anymore
    del model, train_data, objective_set
    torch.cuda.empty_cache()


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set with (X,IG(X)) pairs to train secondary learner where IG(X) - measure of informativeness and X- context
        secondary_learner_max_epochs: Number of epochs to train secondary learner
        secondary_learner_batch_size: Batch size to train secondary learner
        eval_freq (object): secondary model evaluation is triggered every eval_freq steps
        igf_model_path: path to store trained secondary learner

    Returns:
        Trained secondary learner
    """

    set_seed(42)

    # Load pre-trained model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Initialize secondary learner to use embedding weights of model
    secondary_learner = SecondaryLearner(model)

    # Train secondary learner (forward the caller's eval_freq instead of a fixed value)
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=eval_freq,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner
|
||||||
|
|
||||||
|
|
||||||
|
def finetune(
|
||||||
|
model,
|
||||||
|
train_dataset,
|
||||||
|
test_dataset,
|
||||||
|
context_len=32,
|
||||||
|
max_steps=1000,
|
||||||
|
batch_size=16,
|
||||||
|
threshold=1.0,
|
||||||
|
recopy_model=recopy_gpt2,
|
||||||
|
secondary_learner=None,
|
||||||
|
eval_interval=10,
|
||||||
|
finetuned_model_name="gpt2_finetuned.pt",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
fine-tune with IGF if secondary_learner is not None, else standard fine-tuning
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: pre-trained GPT-2 model
|
||||||
|
train_dataset: Data set to train GPT-2 model
|
||||||
|
test_dataset: Evaluate GPT-2 model
|
||||||
|
context_len: The maximum total input sequence length after tokenization. Sequences longer
|
||||||
|
than this will be truncated, sequences shorter will be padded
|
||||||
|
max_steps: To calculate training epochs
|
||||||
|
batch_size: Batch size to train GPT-2 model
|
||||||
|
threshold: The threshold value used by secondary learner to filter the train_data and allow only"
|
||||||
|
informative data as input to the model
|
||||||
|
recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
|
||||||
|
secondary_learner: Selection of IGF as fine-tuning method if not None
|
||||||
|
eval_interval: number of batches after which decay the selectivity of our secondary learner filter from
|
||||||
|
1 standard deviation above average to 1 below average
|
||||||
|
fine-tuned_model_name: name of the final final-tuned GPT-2 model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fine-tuned GPT-2 model
|
||||||
|
|
||||||
|
"""

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // (len(train_dataset)) + 1
    global_step = 0
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()
    contexts = []
    examples = 0

    observed_qs = []
    test_perps = []

    # Compute the performance of the transformer model at the beginning
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)
    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                predicted_q = secondary_learner.forward(
                    torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
                )[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value
                # We will decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after 10 batches.

                if global_step == 10:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If we passed the filter, add the context to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, backprop on the batch.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update learning rate schedule
                global_step += 1
                # Compute the performance of the transformer model at this batch
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)

                    print("Test perplexity, step", global_step, ":", real_perp)
            # Break out of the loop after 60 batches
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # save finetuned transformer model
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()
    # Do some cleaning up so we can reinitialize for the next run of this function
    del lm_optimizer
    del lm_scheduler
    return model


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help="A jbl file containing tokenized data which can be split as objective dataset, "
        "train_dataset and test_dataset.",
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for secondary learner",
    )

    parser.add_argument(
        "--batch_size", default=16, type=int, help="batch size of training data of language model (gpt2)"
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help="number of training batches between test-perplexity evaluations of the language model",
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
    )

    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help="The threshold value used by secondary learner to filter the train_data and allow only"
        " informative data as input to the model",
    )

    parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")

    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        type=str,
        help="Reset the model to the original pretrained GPT-2 weights after each iteration",
    )

    # function calls
    # Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=32,
        max_steps=10,
        size_objective_set=100,
        min_len=1026,
        trim=True,
        data_file="data/tokenized_stories_train_wikitext103.jbl",
        igf_data_file="igf_context_pairs.jbl",
    )

    # Load train data for secondary learner
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=15,
        secondary_learner_batch_size=128,
        eval_freq=100,
        igf_model_path="igf_model.pt",
    )

    # load pretrained gpt2 model
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    set_seed(42)

    # Generate train and test data to train and evaluate gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
    )

    # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=32,
        max_steps=1000,
        batch_size=16,
        threshold=1.0,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=10,
        finetuned_model_name="gpt2_finetuned.pt",
    )


if __name__ == "__main__":
    main()
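
The sample-filtering rule inside `finetune` can be isolated as a small pure-Python sketch. It assumes, as the training loop does, that the threshold starts at the caller-supplied value and drops to -1 once `global_step` reaches 10, and that a context is kept unless its predicted information gain is below the threshold. The names `current_threshold` and `passes_filter` and the `decay_step` / `decayed_threshold` parameters are hypothetical and do not appear in the script.

```python
def current_threshold(global_step, initial_threshold=1.0, decay_step=10, decayed_threshold=-1.0):
    """Return the IG threshold in effect at a given optimizer step.

    Mirrors `if global_step == 10: threshold = -1`: because global_step only
    ever increases by one, the lower threshold stays in effect from step 10 on.
    """
    return decayed_threshold if global_step >= decay_step else initial_threshold


def passes_filter(predicted_q, global_step, initial_threshold=1.0):
    """True if a context with predicted information gain `predicted_q` is backpropagated."""
    return predicted_q >= current_threshold(global_step, initial_threshold)


# Hypothetical predicted IG values for three contexts, before and after the decay step.
for q in (1.5, 0.2, -2.0):
    print(q, passes_filter(q, global_step=0), passes_filter(q, global_step=10))
```

Keeping the rule stateless makes it easy to check in isolation; the script instead mutates `threshold` in place, which behaves the same way as long as `global_step` passes through 10.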
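
As written, `main()` builds an `ArgumentParser` but never calls `parse_args()`, and the calls to `generate_n_pairs`, `training_secondary_learner`, and `finetune` use literal values, so the command-line flags have no effect. The following is a minimal sketch of how parsed values could be forwarded instead, covering only a subset of the flags. `build_parser`, `str2bool`, and `finetune_kwargs` are hypothetical helpers introduced for illustration. `str2bool` is swapped in for `type=bool` because `bool("False")` evaluates to `True`, so `--trim False` would not disable trimming.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans; type=bool would map the string "False" to True."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Build a parser with a subset of the script's flags (same names and defaults)."""
    parser = argparse.ArgumentParser(description="Sketch: IGF fine-tuning options")
    parser.add_argument("--context_len", default=32, type=int)
    parser.add_argument("--max_steps", default=1000, type=int)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--threshold", default=1.0, type=float)
    parser.add_argument("--eval_interval", default=10, type=int)
    parser.add_argument("--trim", default=True, type=str2bool)
    parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str)
    return parser


def finetune_kwargs(args):
    """Collect the keyword arguments passed to finetune() from the parsed namespace."""
    return {
        "context_len": args.context_len,
        "max_steps": args.max_steps,
        "batch_size": args.batch_size,
        "threshold": args.threshold,
        "eval_interval": args.eval_interval,
        "finetuned_model_name": args.finetuned_model_name,
    }


# An explicit argument list stands in for sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--batch_size", "8", "--trim", "False"])
kwargs = finetune_kwargs(args)
# In the script this could become: finetune(model, train_dataset, test_dataset, **kwargs)
```

Whether every flag should be forwarded this way is a design decision for the script's authors; the sketch only shows the mechanics.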