Add bloom flax (#25094)
* First commit * step 1 working * add alibi * placeholder for `scan` * add matrix mult alibi * beta scaling factor for bmm * working v1 - simple forward pass * move layer_number from attribute to arg in call * partial functioning scan * hacky working scan * add more modifs * add test * update scan for new kwarg order * fix position_ids problem * fix bug in attention layer * small fix - do the alibi broadcasting only once * prelim refactor * finish refactor * alibi shifting * incorporate dropout_add to attention module * make style * make padding work again * update * remove bogus file * up * get generation to work * clean code a bit * added small tests * adding albii test * make CI tests pass: - change init weight - add correct tuple for output attention - add scan test - make CI tests work * fix few nits * fix nit onnx * fix onnx nit * add missing dtype args to nn.Modules * remove debugging statements * fix scan generate * Update modeling_flax_bloom.py * Update test_modeling_flax_bloom.py * Update test_modeling_flax_bloom.py * Update test_modeling_flax_bloom.py * fix small test issue + make style * clean up * Update tests/models/bloom/test_modeling_flax_bloom.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fix function name * small fix test * forward contrib credits from PR17761 * Fix failing test * fix small typo documentation * fix non passing test - remove device from build alibi * refactor call - refactor `FlaxBloomBlockCollection` module * make style * upcast to fp32 * cleaner way to upcast * remove unused args * remove layer number * fix scan test * make style * fix i4 casting * fix slow test * Update src/transformers/models/bloom/modeling_flax_bloom.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove `layer_past` * refactor a bit * fix `scan` slow test * remove useless import * major changes - remove unused code - refactor a bit - revert import `torch` * major refactoring - change build alibi * remove scan * fix tests * make style * clean-up alibi * add integration tests * up * fix batch norm conversion * style * style * update pt-fx cross tests * update copyright * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * per-weight check * style * line formats --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: haileyschoelkopf <haileyschoelkopf@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -487,6 +487,33 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True),
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_batch_generated_text(self):
|
||||
path_560m = "bigscience/bloom-560m"
|
||||
|
||||
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").cuda()
|
||||
model = model.eval()
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(path_560m, padding_side="left")
|
||||
|
||||
input_sentences = [
|
||||
"Hello what is",
|
||||
"Running a quick test with the",
|
||||
]
|
||||
inputs = tokenizer(input_sentences, return_tensors="pt", padding=True, truncation=True)
|
||||
generated_ids = model.generate(
|
||||
inputs["input_ids"].cuda(), attention_mask=inputs["attention_mask"], max_length=20
|
||||
)
|
||||
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# these generations match those of the PyTorch model
|
||||
EXPECTED_GENERATIONS = [
|
||||
"Hello what is the best way to get the data from the server? I have tried",
|
||||
"Running a quick test with the following command:\nsudo apt-get install python3\nsudo apt-get install python2",
|
||||
]
|
||||
|
||||
self.assertListEqual(generated_text, EXPECTED_GENERATIONS)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BloomEmbeddingTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user