dilbert -> distilbert
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocessing script before training DilBERT.
|
||||
Preprocessing script before training DistilBERT.
|
||||
"""
|
||||
from pytorch_transformers import BertForPreTraining
|
||||
import torch
|
||||
@@ -33,32 +33,32 @@ if __name__ == '__main__':
|
||||
compressed_sd = {}
|
||||
|
||||
for w in ['word_embeddings', 'position_embeddings']:
|
||||
compressed_sd[f'dilbert.embeddings.{w}.weight'] = \
|
||||
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
|
||||
state_dict[f'bert.embeddings.{w}.weight']
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'dilbert.embeddings.LayerNorm.{w}'] = \
|
||||
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
|
||||
state_dict[f'bert.embeddings.LayerNorm.{w}']
|
||||
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']
|
||||
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
|
||||
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
|
||||
compressed_sd[f'dilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
|
||||
state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
|
||||
std_idx += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user