Compare commits

...

3 Commits

Author SHA1 Message Date
Sylvain Gugger
cd48078ce5 Release: v4.3.2
Some checks failed
Model templates runner / run_tests_templates (push) Has been cancelled
Release - Conda / build_and_package (push) Has been cancelled
2021-02-09 14:07:52 -05:00
Suraj Patil
727ab9d398 [RAG] fix generate (#10094)
* fix rag generate and tests

* put back adjust_logits_during_generation

* tests are okay

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-02-09 14:04:32 -05:00
Patrick von Platen
c95fae6d65 fix import (#10103) 2021-02-09 14:03:17 -05:00
4 changed files with 13 additions and 3 deletions

View File

@@ -282,7 +282,7 @@ install_requires = [
setup(
name="transformers",
version="4.3.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.3.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="thomas@huggingface.co",
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@@ -22,7 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.3.1"
__version__ = "4.3.2"
# Workaround to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.

View File

@@ -144,7 +144,11 @@ try:
_faiss_version = importlib_metadata.version("faiss")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
try:
_faiss_version = importlib_metadata.version("faiss-cpu")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_scatter_available = importlib.util.find_spec("torch_scatter") is not None

View File

@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None,
num_return_sequences=None,
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids (:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,