Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
This commit is contained in:
Sylvain Gugger
2023-02-06 18:10:56 -05:00
committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions

View File

@@ -33,6 +33,7 @@ if is_flax_available():
import jax.numpy as jnp
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict
from transformers import (
FlaxBartForCausalLM,
FlaxBertForCausalLM,
@@ -73,7 +74,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
@@ -103,7 +104,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -142,7 +143,7 @@ class FlaxEncoderDecoderMixin:
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
@@ -169,7 +170,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
@@ -208,7 +209,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
# assert that loading encoder and decoder models from configs has been correctly executed
@@ -253,7 +254,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
@@ -336,7 +337,7 @@ class FlaxEncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
**kwargs,
):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
@@ -406,7 +407,6 @@ class FlaxEncoderDecoderMixin:
self.assertTrue((grad == grad_frozen).all())
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
@@ -448,7 +448,6 @@ class FlaxEncoderDecoderMixin:
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
@@ -460,7 +459,6 @@ class FlaxEncoderDecoderMixin:
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
@@ -508,7 +506,6 @@ class FlaxEncoderDecoderMixin:
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")

View File

@@ -62,7 +62,7 @@ class EncoderDecoderMixin:
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
@@ -95,7 +95,7 @@ class EncoderDecoderMixin:
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -135,7 +135,7 @@ class EncoderDecoderMixin:
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
inputs = input_values if input_features is None else input_features
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
@@ -173,7 +173,7 @@ class EncoderDecoderMixin:
return_dict,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
@@ -202,7 +202,7 @@ class EncoderDecoderMixin:
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -245,7 +245,7 @@ class EncoderDecoderMixin:
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -292,7 +292,7 @@ class EncoderDecoderMixin:
labels=None,
input_values=None,
input_features=None,
**kwargs
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]