BLIPs clean-up (#35560)
* blips clean up * update processor * readability * fix processor length * fix copies * tmp * update and fix copies * why keep these, delete? * fix test fetcher * irrelevant comment * fix tests * fix tests * fix copies
This commit is contained in:
committed by
GitHub
parent
4f8f51be4e
commit
75794792ad
@@ -1539,16 +1539,25 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
|
|
||||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||||
language_model_inputs = self.language_projection(query_output)
|
language_model_inputs = self.language_projection(query_output)
|
||||||
language_model_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
|
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
expected_device = language_model_attention_mask.device
|
|
||||||
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)
|
if input_ids is None:
|
||||||
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
|
)
|
||||||
|
special_image_mask = special_image_mask.all(-1)
|
||||||
|
else:
|
||||||
|
special_image_mask = input_ids == self.config.image_token_id
|
||||||
|
|
||||||
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
||||||
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
||||||
|
special_image_mask, language_model_inputs
|
||||||
|
)
|
||||||
|
|
||||||
if self.config.use_decoder_only_language_model:
|
if self.config.use_decoder_only_language_model:
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
@@ -2026,9 +2035,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||||
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||||
language_model_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
@@ -2036,34 +2042,19 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
||||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_id
|
|
||||||
|
|
||||||
special_image_mask = (
|
|
||||||
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
|
||||||
)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
|
||||||
special_image_mask, language_model_inputs
|
|
||||||
)
|
)
|
||||||
|
special_image_mask = special_image_mask.all(-1)
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.image_token_id
|
||||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
)
|
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
special_image_mask, language_model_inputs
|
||||||
attention_mask = torch.cat(
|
)
|
||||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.use_decoder_only_language_model:
|
if self.config.use_decoder_only_language_model:
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
@@ -2172,15 +2163,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
query_output = query_output.to(image_embeds.dtype)
|
query_output = query_output.to(image_embeds.dtype)
|
||||||
|
|
||||||
language_model_inputs = self.language_projection(query_output)
|
language_model_inputs = self.language_projection(query_output)
|
||||||
language_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
start_tokens = [self.config.text_config.bos_token_id]
|
image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
start_tokens = image_tokens + [self.config.text_config.bos_token_id]
|
||||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
|
||||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||||
input_ids = input_ids.repeat(batch_size, 1)
|
input_ids = input_ids.repeat(batch_size, 1)
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
@@ -2188,53 +2175,24 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
||||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_id
|
|
||||||
|
|
||||||
special_image_mask = (
|
|
||||||
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
|
||||||
)
|
)
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
|
special_image_mask = special_image_mask.all(-1)
|
||||||
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
|
||||||
special_image_mask, language_model_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = attention_mask.to(language_attention_mask.device)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.image_token_id
|
||||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
|
||||||
)
|
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
# -1 is to account for the prepended BOS after `generate.`
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
|
||||||
generate_kwargs["max_length"] = (
|
|
||||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
|
||||||
)
|
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
|
||||||
|
|
||||||
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
if input_ids is not None:
|
|
||||||
input_ids = input_ids.to(language_model_inputs.device)
|
|
||||||
inputs["input_ids"] = input_ids
|
inputs["input_ids"] = input_ids
|
||||||
|
|
||||||
outputs = self.language_model.generate(**inputs, **generate_kwargs)
|
outputs = self.language_model.generate(**inputs, **generate_kwargs)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -2362,8 +2320,13 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
|||||||
|
|
||||||
if use_image_text_matching_head:
|
if use_image_text_matching_head:
|
||||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
|
if self.config.image_token_index is not None:
|
||||||
attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)
|
input_ids = input_ids[:, self.config.num_query_tokens :]
|
||||||
|
else:
|
||||||
|
query_attention_mask = torch.ones(
|
||||||
|
query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device
|
||||||
|
)
|
||||||
|
attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)
|
||||||
|
|
||||||
query_embeds = self.embeddings(
|
query_embeds = self.embeddings(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -2395,6 +2358,10 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
|||||||
image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state
|
image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state
|
||||||
image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype)
|
image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype)
|
||||||
|
|
||||||
|
if self.config.image_token_index is not None:
|
||||||
|
input_ids = input_ids[:, self.config.num_query_tokens :]
|
||||||
|
attention_mask = attention_mask[:, self.config.num_query_tokens :]
|
||||||
|
|
||||||
query_embeds = self.embeddings(
|
query_embeds = self.embeddings(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -112,11 +112,13 @@ class Blip2Processor(ProcessorMixin):
|
|||||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# BC for explicit return_tensors
|
# BC for explicit return_tensors
|
||||||
if "return_tensors" in output_kwargs["common_kwargs"]:
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
|
max_length = output_kwargs["text_kwargs"].pop("max_length", None)
|
||||||
else:
|
if max_length is not None:
|
||||||
return_tensors = None
|
output_kwargs["text_kwargs"]["max_length"] = max_length - self.num_query_tokens
|
||||||
|
|
||||||
encoding = BatchFeature(tensor_type=return_tensors)
|
encoding = BatchFeature(tensor_type=return_tensors)
|
||||||
if text is not None:
|
if text is not None:
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
@@ -124,40 +126,28 @@ class Blip2Processor(ProcessorMixin):
|
|||||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
text_encoding = {}
|
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
|
||||||
|
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
if images is not None and self.num_query_tokens is not None:
|
||||||
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
# Image tokens should not be padded/truncated or prepended with special BOS token
|
||||||
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
|
|
||||||
|
|
||||||
# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
|
|
||||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
|
||||||
if self.num_query_tokens is not None:
|
|
||||||
image_tokens = self.image_token.content * self.num_query_tokens
|
image_tokens = self.image_token.content * self.num_query_tokens
|
||||||
image_token_encoding = self.tokenizer(
|
output_kwargs["text_kwargs"]["add_special_tokens"] = False
|
||||||
[image_tokens] * len(text), add_special_tokens=False, return_tensors=None
|
output_kwargs["text_kwargs"]["padding"] = False
|
||||||
)
|
output_kwargs["text_kwargs"]["truncation"] = False
|
||||||
for k in _text_encoding:
|
image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"])
|
||||||
text_encoding[k] = [
|
for k in text_encoding:
|
||||||
img_encoding + txt_encoding
|
text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]]
|
||||||
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
|
encoding.update(text_encoding)
|
||||||
]
|
|
||||||
else:
|
|
||||||
text_encoding = _text_encoding
|
|
||||||
logger.warning_once(
|
|
||||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
|
||||||
)
|
|
||||||
|
|
||||||
# cast to desired return tensors type
|
# Now add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
|
||||||
encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))
|
|
||||||
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
|
|
||||||
# else, return the text encoding.
|
# else, return the text encoding.
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
|
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
encoding.update(image_encoding)
|
encoding.update(image_encoding)
|
||||||
|
|
||||||
|
# Cast to desired return tensors type
|
||||||
|
encoding = BatchFeature(encoding, tensor_type=return_tensors)
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||||
|
|||||||
@@ -799,7 +799,7 @@ class InstructBlipQFormerLayer(GradientCheckpointingLayer):
|
|||||||
self.chunk_size_feed_forward,
|
self.chunk_size_feed_forward,
|
||||||
self.seq_len_dim,
|
self.seq_len_dim,
|
||||||
attention_output[:, query_length:, :],
|
attention_output[:, query_length:, :],
|
||||||
)
|
).to(layer_output.device)
|
||||||
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
||||||
else:
|
else:
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
@@ -1560,9 +1560,6 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
)
|
)
|
||||||
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||||
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||||
language_model_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
@@ -1570,30 +1567,17 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.image_token_id
|
||||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
)
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.use_decoder_only_language_model:
|
if self.config.use_decoder_only_language_model:
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
@@ -1682,54 +1666,29 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
language_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
start_tokens = [self.config.text_config.bos_token_id]
|
image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
start_tokens = image_tokens + [self.config.text_config.bos_token_id]
|
||||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
|
||||||
input_ids = input_ids.repeat(batch_size, 1)
|
input_ids = input_ids.repeat(batch_size, 1)
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "image_token_id", None) is not None:
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.image_token_id
|
||||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
|
||||||
)
|
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
# -1 is to account for the prepended BOS after `generate.`
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
generate_kwargs["max_length"] = (
|
|
||||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
|
||||||
)
|
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
|
||||||
|
|
||||||
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import Union
|
|||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto import AutoTokenizer
|
from ..auto import AutoTokenizer
|
||||||
|
|
||||||
@@ -78,6 +78,7 @@ class InstructBlipProcessor(ProcessorMixin):
|
|||||||
else:
|
else:
|
||||||
self.image_token = tokenizer.image_token
|
self.image_token = tokenizer.image_token
|
||||||
self.num_query_tokens = num_query_tokens
|
self.num_query_tokens = num_query_tokens
|
||||||
|
|
||||||
super().__init__(image_processor, tokenizer, qformer_tokenizer)
|
super().__init__(image_processor, tokenizer, qformer_tokenizer)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -111,52 +112,40 @@ class InstructBlipProcessor(ProcessorMixin):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoding = BatchFeature()
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
encoding = {}
|
||||||
if text is not None:
|
if text is not None:
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = [text]
|
text = [text]
|
||||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
# we have to concatenate lists - so we keep track of return_tensors here
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
|
||||||
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
|
||||||
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
|
|
||||||
# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
|
|
||||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
|
||||||
if self.num_query_tokens is not None and images is not None:
|
|
||||||
text_encoding = {}
|
|
||||||
image_tokens = self.image_token.content * self.num_query_tokens
|
|
||||||
image_token_encoding = self.tokenizer(
|
|
||||||
[image_tokens] * len(text), add_special_tokens=False, return_tensors=None
|
|
||||||
)
|
|
||||||
for k in _text_encoding:
|
|
||||||
text_encoding[k] = [
|
|
||||||
img_encoding + txt_encoding
|
|
||||||
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
text_encoding = _text_encoding
|
|
||||||
if images is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
|
||||||
)
|
|
||||||
|
|
||||||
# cast to desired return tensors type after concatenating
|
|
||||||
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
|
|
||||||
|
|
||||||
encoding.update(text_encoding)
|
|
||||||
qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
|
qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
||||||
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
||||||
|
|
||||||
|
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
|
||||||
|
if output_kwargs["text_kwargs"].get("max_length") is not None:
|
||||||
|
output_kwargs["text_kwargs"]["max_length"] -= self.num_query_tokens
|
||||||
|
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
# Image tokens should not be padded/truncated or prepended with special BOS token
|
||||||
|
image_tokens = self.image_token.content * self.num_query_tokens
|
||||||
|
output_kwargs["text_kwargs"]["add_special_tokens"] = False
|
||||||
|
output_kwargs["text_kwargs"]["padding"] = False
|
||||||
|
output_kwargs["text_kwargs"]["truncation"] = False
|
||||||
|
image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"])
|
||||||
|
for k in text_encoding:
|
||||||
|
text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]]
|
||||||
|
encoding.update(text_encoding)
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
|
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
encoding.update(image_encoding)
|
encoding.update(image_encoding)
|
||||||
|
|
||||||
|
# Cast to desired return tensors type
|
||||||
|
encoding = BatchFeature(encoding, tensor_type=return_tensors)
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||||
|
|||||||
@@ -660,7 +660,7 @@ class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
|
|||||||
self.chunk_size_feed_forward,
|
self.chunk_size_feed_forward,
|
||||||
self.seq_len_dim,
|
self.seq_len_dim,
|
||||||
attention_output[:, query_length:, :],
|
attention_output[:, query_length:, :],
|
||||||
)
|
).to(layer_output.device)
|
||||||
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
||||||
else:
|
else:
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
@@ -1527,9 +1527,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
)
|
)
|
||||||
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||||
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||||
language_model_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
@@ -1537,30 +1534,17 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.video_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.video_token_id
|
||||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
)
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.use_decoder_only_language_model:
|
if self.config.use_decoder_only_language_model:
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
@@ -1650,54 +1634,28 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
language_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
start_tokens = [self.config.text_config.bos_token_id]
|
video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
start_tokens = video_tokens + [self.config.text_config.bos_token_id]
|
||||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
|
||||||
input_ids = input_ids.repeat(batch_size, 1)
|
input_ids = input_ids.repeat(batch_size, 1)
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.video_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.video_token_id
|
||||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
|
|
||||||
)
|
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
# -1 is to account for the prepended BOS after `generate.`
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
generate_kwargs["max_length"] = (
|
|
||||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
|
||||||
)
|
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
|
||||||
|
|
||||||
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
|
|||||||
@@ -464,9 +464,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
)
|
)
|
||||||
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||||
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||||
language_model_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
@@ -474,30 +471,17 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.video_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.video_token_id
|
||||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
)
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.use_decoder_only_language_model:
|
if self.config.use_decoder_only_language_model:
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
@@ -587,54 +571,28 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
language_attention_mask = torch.ones(
|
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
start_tokens = [self.config.text_config.bos_token_id]
|
video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
start_tokens = video_tokens + [self.config.text_config.bos_token_id]
|
||||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
|
||||||
input_ids = input_ids.repeat(batch_size, 1)
|
input_ids = input_ids.repeat(batch_size, 1)
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
if input_ids is None:
|
||||||
# otherwise we expand manually by concatenating
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
if input_ids is None:
|
)
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
special_image_mask = special_image_mask.all(-1)
|
||||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.all(-1)
|
|
||||||
else:
|
|
||||||
special_image_mask = input_ids == self.config.video_token_id
|
|
||||||
|
|
||||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
||||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
special_image_mask = input_ids == self.config.video_token_id
|
||||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
|
|
||||||
)
|
|
||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
# -1 is to account for the prepended BOS after `generate.`
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||||
generate_kwargs["max_length"] = (
|
|
||||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
|
||||||
)
|
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
|
||||||
|
|
||||||
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from ...image_processing_utils import BatchFeature
|
|||||||
from ...processing_utils import ProcessorMixin
|
from ...processing_utils import ProcessorMixin
|
||||||
from ...tokenization_utils_base import (
|
from ...tokenization_utils_base import (
|
||||||
AddedToken,
|
AddedToken,
|
||||||
BatchEncoding,
|
|
||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
PreTokenizedInput,
|
PreTokenizedInput,
|
||||||
TextInput,
|
TextInput,
|
||||||
@@ -99,60 +98,13 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
|||||||
if images is None and text is None:
|
if images is None and text is None:
|
||||||
raise ValueError("You have to specify at least one of images or text.")
|
raise ValueError("You have to specify at least one of images or text.")
|
||||||
|
|
||||||
encoding = BatchFeature()
|
encoding = {}
|
||||||
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = [text]
|
text = [text]
|
||||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
_text_encoding = self.tokenizer(
|
|
||||||
text=text,
|
|
||||||
add_special_tokens=add_special_tokens,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
max_length=max_length,
|
|
||||||
stride=stride,
|
|
||||||
pad_to_multiple_of=pad_to_multiple_of,
|
|
||||||
return_attention_mask=return_attention_mask,
|
|
||||||
return_overflowing_tokens=return_overflowing_tokens,
|
|
||||||
return_special_tokens_mask=return_special_tokens_mask,
|
|
||||||
return_offsets_mapping=return_offsets_mapping,
|
|
||||||
return_token_type_ids=return_token_type_ids,
|
|
||||||
return_length=return_length,
|
|
||||||
verbose=verbose,
|
|
||||||
return_tensors=None, # required to concatenate below
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
|
|
||||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
|
||||||
if self.num_query_tokens is not None and images is not None:
|
|
||||||
text_encoding = {}
|
|
||||||
video_tokens = (
|
|
||||||
self.video_token.content * self.num_query_tokens * 4
|
|
||||||
) # InstrucBLIP works with 4 frames only
|
|
||||||
video_token_encoding = self.tokenizer(
|
|
||||||
[video_tokens] * len(text), add_special_tokens=False, return_tensors=None
|
|
||||||
)
|
|
||||||
for k in _text_encoding:
|
|
||||||
text_encoding[k] = [
|
|
||||||
img_encoding + txt_encoding
|
|
||||||
for img_encoding, txt_encoding in zip(video_token_encoding[k], _text_encoding[k])
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
text_encoding = _text_encoding
|
|
||||||
if images is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
|
||||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
|
||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
|
|
||||||
)
|
|
||||||
|
|
||||||
# cast to desired return tensors type after concatenating
|
|
||||||
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
|
|
||||||
encoding.update(text_encoding)
|
|
||||||
qformer_text_encoding = self.qformer_tokenizer(
|
qformer_text_encoding = self.qformer_tokenizer(
|
||||||
text=text,
|
text=text,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
@@ -174,10 +126,51 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
|||||||
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
||||||
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
||||||
|
|
||||||
|
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
|
||||||
|
# InstrucBLIP works with 4 frames only
|
||||||
|
if max_length is not None:
|
||||||
|
max_length -= self.num_query_tokens
|
||||||
|
text_encoding = self.tokenizer(
|
||||||
|
text=text,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
padding=padding,
|
||||||
|
truncation=truncation,
|
||||||
|
max_length=max_length,
|
||||||
|
stride=stride,
|
||||||
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
return_overflowing_tokens=return_overflowing_tokens,
|
||||||
|
return_special_tokens_mask=return_special_tokens_mask,
|
||||||
|
return_offsets_mapping=return_offsets_mapping,
|
||||||
|
return_token_type_ids=return_token_type_ids,
|
||||||
|
return_length=return_length,
|
||||||
|
verbose=verbose,
|
||||||
|
return_tensors=None, # required to concatenate below
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
video_tokens = self.video_token.content * self.num_query_tokens * 4
|
||||||
|
video_text_encoding = self.tokenizer(
|
||||||
|
video_tokens,
|
||||||
|
add_special_tokens=False, # required to concatenate below
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
return_overflowing_tokens=return_overflowing_tokens,
|
||||||
|
return_special_tokens_mask=return_special_tokens_mask,
|
||||||
|
return_offsets_mapping=return_offsets_mapping,
|
||||||
|
return_token_type_ids=return_token_type_ids,
|
||||||
|
return_length=return_length,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
for k in text_encoding:
|
||||||
|
text_encoding[k] = [video_text_encoding[k] + sample for sample in text_encoding[k]]
|
||||||
|
encoding.update(text_encoding)
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_encoding = self.video_processor(images, return_tensors=return_tensors)
|
image_encoding = self.video_processor(images, return_tensors=return_tensors)
|
||||||
encoding.update(image_encoding)
|
encoding.update(image_encoding)
|
||||||
|
|
||||||
|
encoding = BatchFeature(encoding, tensor_type=return_tensors)
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_fp16,
|
require_torch_fp16,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
require_vision,
|
require_vision,
|
||||||
@@ -777,7 +778,14 @@ class Blip2TextModelTester:
|
|||||||
# this model tester uses an encoder-decoder language model (T5)
|
# this model tester uses an encoder-decoder language model (T5)
|
||||||
class Blip2ModelTester:
|
class Blip2ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, parent, vision_kwargs=None, qformer_kwargs=None, text_kwargs=None, is_training=True, num_query_tokens=10
|
self,
|
||||||
|
parent,
|
||||||
|
vision_kwargs=None,
|
||||||
|
qformer_kwargs=None,
|
||||||
|
text_kwargs=None,
|
||||||
|
is_training=True,
|
||||||
|
num_query_tokens=10,
|
||||||
|
image_token_index=4,
|
||||||
):
|
):
|
||||||
if vision_kwargs is None:
|
if vision_kwargs is None:
|
||||||
vision_kwargs = {}
|
vision_kwargs = {}
|
||||||
@@ -792,11 +800,10 @@ class Blip2ModelTester:
|
|||||||
self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs)
|
self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs)
|
||||||
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
|
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
|
||||||
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
|
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
|
||||||
self.encoder_seq_length = (
|
self.encoder_seq_length = self.text_model_tester.encoder_seq_length
|
||||||
self.text_model_tester.encoder_seq_length + num_query_tokens
|
|
||||||
) # need enc seq_length for gen tests
|
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.num_query_tokens = num_query_tokens
|
self.num_query_tokens = num_query_tokens
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||||
@@ -819,6 +826,7 @@ class Blip2ModelTester:
|
|||||||
qformer_config=self.qformer_model_tester.get_config(),
|
qformer_config=self.qformer_model_tester.get_config(),
|
||||||
text_config=self.text_model_tester.get_config(),
|
text_config=self.text_model_tester.get_config(),
|
||||||
num_query_tokens=self.num_query_tokens,
|
num_query_tokens=self.num_query_tokens,
|
||||||
|
image_token_index=self.image_token_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_for_conditional_generation(
|
def create_and_check_for_conditional_generation(
|
||||||
@@ -1872,37 +1880,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
|
self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
|
||||||
self.assertEqual(generated_text, expected_ids_and_text[1])
|
self.assertEqual(generated_text, expected_ids_and_text[1])
|
||||||
|
|
||||||
def test_expansion_in_processing(self):
|
@require_torch_gpu
|
||||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
|
||||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
|
||||||
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
image = prepare_img()
|
|
||||||
prompt = "Question: which city is this? Answer:"
|
|
||||||
|
|
||||||
# Make sure we will go the legacy path by setting these args to None
|
|
||||||
processor.num_query_tokens = None
|
|
||||||
model.config.image_token_index = None
|
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
|
|
||||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
|
||||||
processor.num_query_tokens = model.config.num_query_tokens
|
|
||||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
|
||||||
model.config.image_token_index = len(processor.tokenizer) - 1
|
|
||||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
|
||||||
|
|
||||||
# Generate again with new inputs
|
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
self.assertTrue(generated_text_expanded == generated_text)
|
|
||||||
|
|
||||||
@require_torch_accelerator
|
|
||||||
def test_inference_itm(self):
|
def test_inference_itm(self):
|
||||||
model_name = "Salesforce/blip2-itm-vit-g"
|
model_name = "Salesforce/blip2-itm-vit-g"
|
||||||
processor = Blip2Processor.from_pretrained(model_name)
|
processor = Blip2Processor.from_pretrained(model_name)
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
def get_image_processor(self, **kwargs):
|
def get_image_processor(self, **kwargs):
|
||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def prepare_processor_dict(self):
|
||||||
|
return {"num_query_tokens": 1}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||||
@@ -84,26 +87,12 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
for key in input_feat_extract.keys():
|
for key in input_feat_extract.keys():
|
||||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
def test_tokenizer(self):
|
|
||||||
image_processor = self.get_image_processor()
|
|
||||||
tokenizer = self.get_tokenizer()
|
|
||||||
|
|
||||||
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor)
|
|
||||||
|
|
||||||
input_str = "lower newer"
|
|
||||||
|
|
||||||
encoded_processor = processor(text=input_str)
|
|
||||||
|
|
||||||
encoded_tok = tokenizer(input_str, return_token_type_ids=False)
|
|
||||||
|
|
||||||
for key in encoded_tok.keys():
|
|
||||||
self.assertListEqual(encoded_tok[key], encoded_processor[key][0])
|
|
||||||
|
|
||||||
def test_processor(self):
|
def test_processor(self):
|
||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor)
|
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
image_input = self.prepare_image_inputs()
|
image_input = self.prepare_image_inputs()
|
||||||
@@ -119,8 +108,9 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
def test_tokenizer_decode(self):
|
def test_tokenizer_decode(self):
|
||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor)
|
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs)
|
||||||
|
|
||||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||||
|
|
||||||
@@ -132,8 +122,9 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
def test_model_input_names(self):
|
def test_model_input_names(self):
|
||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor)
|
processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
image_input = self.prepare_image_inputs()
|
image_input = self.prepare_image_inputs()
|
||||||
|
|||||||
@@ -809,34 +809,3 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1]
|
predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1]
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")
|
self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")
|
||||||
|
|
||||||
def test_expansion_in_processing(self):
|
|
||||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
|
|
||||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
|
||||||
"Salesforce/instructblip-flan-t5-xl",
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
image = prepare_img()
|
|
||||||
prompt = "What's in the image?"
|
|
||||||
|
|
||||||
# Make sure we will go the legacy path by setting these args to None
|
|
||||||
processor.num_query_tokens = None
|
|
||||||
model.config.image_token_index = None
|
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
|
|
||||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
|
||||||
processor.num_query_tokens = model.config.num_query_tokens
|
|
||||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
|
||||||
model.config.image_token_index = len(processor.tokenizer) - 2
|
|
||||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
|
||||||
|
|
||||||
# Generate again with new inputs
|
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
self.assertTrue(generated_text_expanded == generated_text)
|
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
def get_qformer_tokenizer(self, **kwargs):
|
def get_qformer_tokenizer(self, **kwargs):
|
||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer
|
||||||
|
|
||||||
|
def prepare_processor_dict(self):
|
||||||
|
return {"num_query_tokens": 1}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||||
@@ -90,9 +93,13 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipProcessor(
|
processor = InstructBlipProcessor(
|
||||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_input = self.prepare_image_inputs()
|
image_input = self.prepare_image_inputs()
|
||||||
@@ -103,35 +110,17 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
for key in input_feat_extract.keys():
|
for key in input_feat_extract.keys():
|
||||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
def test_tokenizer(self):
|
|
||||||
image_processor = self.get_image_processor()
|
|
||||||
tokenizer = self.get_tokenizer()
|
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
|
||||||
|
|
||||||
processor = InstructBlipProcessor(
|
|
||||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
input_str = ["lower newer"]
|
|
||||||
|
|
||||||
encoded_processor = processor(text=input_str)
|
|
||||||
|
|
||||||
encoded_tokens = tokenizer(input_str, return_token_type_ids=False)
|
|
||||||
encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False)
|
|
||||||
|
|
||||||
for key in encoded_tokens.keys():
|
|
||||||
self.assertListEqual(encoded_tokens[key], encoded_processor[key])
|
|
||||||
|
|
||||||
for key in encoded_tokens_qformer.keys():
|
|
||||||
self.assertListEqual(encoded_tokens_qformer[key], encoded_processor["qformer_" + key])
|
|
||||||
|
|
||||||
def test_processor(self):
|
def test_processor(self):
|
||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipProcessor(
|
processor = InstructBlipProcessor(
|
||||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
@@ -141,7 +130,7 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(inputs.keys()),
|
list(inputs.keys()),
|
||||||
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
|
["qformer_input_ids", "qformer_attention_mask", "input_ids", "attention_mask", "pixel_values"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# test if it raises when no input is passed
|
# test if it raises when no input is passed
|
||||||
@@ -152,9 +141,13 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipProcessor(
|
processor = InstructBlipProcessor(
|
||||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||||
@@ -168,9 +161,13 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipProcessor(
|
processor = InstructBlipProcessor(
|
||||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
@@ -180,5 +177,5 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(inputs.keys()),
|
list(inputs.keys()),
|
||||||
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
|
["qformer_input_ids", "qformer_attention_mask", "input_ids", "attention_mask", "pixel_values"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -750,34 +750,3 @@ class InstructBlipVideoModelIntegrationTest(unittest.TestCase):
|
|||||||
generated_text,
|
generated_text,
|
||||||
"Explain what is happening in this short video. a baby girl wearing glasses is reading a book on the bed 1080p",
|
"Explain what is happening in this short video. a baby girl wearing glasses is reading a book on the bed 1080p",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_expansion_in_processing(self):
|
|
||||||
processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
|
||||||
model = InstructBlipVideoForConditionalGeneration.from_pretrained(
|
|
||||||
"Salesforce/instructblip-vicuna-7b",
|
|
||||||
load_in_8bit=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip = prepare_video()
|
|
||||||
prompt = "Explain what is happening in this short video."
|
|
||||||
|
|
||||||
# Make sure we will go the legacy path by setting these args to None
|
|
||||||
processor.num_query_tokens = None
|
|
||||||
model.config.video_token_index = None
|
|
||||||
inputs = processor(images=clip, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
|
|
||||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
|
||||||
processor.num_query_tokens = model.config.num_query_tokens
|
|
||||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<video>"]})
|
|
||||||
model.config.video_token_index = len(processor.tokenizer) - 1
|
|
||||||
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
|
|
||||||
|
|
||||||
# Generate again with new inputs
|
|
||||||
inputs = processor(images=clip, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
|
||||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
|
||||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
self.assertTrue(generated_text_expanded == generated_text)
|
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
def get_qformer_tokenizer(self, **kwargs):
|
def get_qformer_tokenizer(self, **kwargs):
|
||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer
|
||||||
|
|
||||||
|
def prepare_processor_dict(self):
|
||||||
|
return {"num_query_tokens": 1}
|
||||||
|
|
||||||
def get_video_processor(self, **kwargs):
|
def get_video_processor(self, **kwargs):
|
||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||||
|
|
||||||
@@ -93,9 +96,13 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
video_processor = self.get_video_processor()
|
video_processor = self.get_video_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipVideoProcessor(
|
processor = InstructBlipVideoProcessor(
|
||||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
video_processor=video_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_input = self.prepare_image_inputs()
|
image_input = self.prepare_image_inputs()
|
||||||
@@ -110,15 +117,17 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
video_processor = self.get_video_processor()
|
video_processor = self.get_video_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipVideoProcessor(
|
processor = InstructBlipVideoProcessor(
|
||||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
video_processor=video_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_str = ["lower newer"]
|
input_str = ["lower newer"]
|
||||||
|
|
||||||
encoded_processor = processor(text=input_str)
|
encoded_processor = processor(text=input_str)
|
||||||
|
|
||||||
encoded_tokens = tokenizer(input_str, return_token_type_ids=False)
|
encoded_tokens = tokenizer(input_str, return_token_type_ids=False)
|
||||||
encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False)
|
encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False)
|
||||||
|
|
||||||
@@ -132,9 +141,13 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
video_processor = self.get_video_processor()
|
video_processor = self.get_video_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipVideoProcessor(
|
processor = InstructBlipVideoProcessor(
|
||||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
video_processor=video_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
@@ -144,7 +157,7 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(inputs.keys()),
|
list(inputs.keys()),
|
||||||
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
|
["qformer_input_ids", "qformer_attention_mask", "input_ids", "attention_mask", "pixel_values"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# test if it raises when no input is passed
|
# test if it raises when no input is passed
|
||||||
@@ -155,9 +168,13 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
video_processor = self.get_video_processor()
|
video_processor = self.get_video_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipVideoProcessor(
|
processor = InstructBlipVideoProcessor(
|
||||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
video_processor=video_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||||
@@ -171,9 +188,13 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
video_processor = self.get_video_processor()
|
video_processor = self.get_video_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
processor = InstructBlipVideoProcessor(
|
processor = InstructBlipVideoProcessor(
|
||||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
tokenizer=tokenizer,
|
||||||
|
video_processor=video_processor,
|
||||||
|
qformer_tokenizer=qformer_tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_str = "lower newer"
|
input_str = "lower newer"
|
||||||
@@ -183,5 +204,5 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(inputs.keys()),
|
list(inputs.keys()),
|
||||||
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
|
["qformer_input_ids", "qformer_attention_mask", "input_ids", "attention_mask", "pixel_values"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -626,7 +626,7 @@
|
|||||||
"model_classes": [
|
"model_classes": [
|
||||||
"Blip2ForConditionalGeneration"
|
"Blip2ForConditionalGeneration"
|
||||||
],
|
],
|
||||||
"sha": "35e1ef43da3554af62eb29a7b3dbbef3f3bef48e"
|
"sha": "d0de11fd1f8ca481231c07ee0934924be96cb281"
|
||||||
},
|
},
|
||||||
"Blip2Model": {
|
"Blip2Model": {
|
||||||
"tokenizer_classes": [
|
"tokenizer_classes": [
|
||||||
|
|||||||
@@ -52,11 +52,9 @@ python utils/tests_fetcher.py --diff_with_last_commit
|
|||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import glob
|
import glob
|
||||||
import importlib.util
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
@@ -323,58 +321,30 @@ def get_impacted_files_from_tiny_model_summary(diff_with_last_commit: bool = Fal
|
|||||||
if key in new_keys:
|
if key in new_keys:
|
||||||
impacted_model_classes.extend(new_content[key]["model_classes"])
|
impacted_model_classes.extend(new_content[key]["model_classes"])
|
||||||
|
|
||||||
# get the module where the model classes are defined. We want to use the main `__init__` file, but it requires
|
# Add imports via `define_import_structure` after the #35167 as we remove explicit import in `__init__.py`
|
||||||
# all the framework being installed, which is not ideal for a simple script like test fetcher.
|
from transformers.utils.import_utils import define_import_structure
|
||||||
# So we create a temporary and modified main `__init__` and access its `_import_structure`.
|
|
||||||
with open(folder / "src/transformers/__init__.py") as fp:
|
|
||||||
lines = fp.readlines()
|
|
||||||
new_lines = []
|
|
||||||
# Get all the code related to `_import_structure`
|
|
||||||
for line in lines:
|
|
||||||
if line == "_import_structure = {\n":
|
|
||||||
new_lines.append(line)
|
|
||||||
elif line == "# Direct imports for type-checking\n":
|
|
||||||
break
|
|
||||||
elif len(new_lines) > 0:
|
|
||||||
# bypass the framework check so we can get all the information even if frameworks are not available
|
|
||||||
line = re.sub(r"is_.+_available\(\)", "True", line)
|
|
||||||
line = line.replace("OptionalDependencyNotAvailable", "Exception")
|
|
||||||
line = line.replace("Exception()", "Exception")
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
# create and load the temporary module
|
reversed_structure = {}
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
new_imported_modules_from_import_structure = define_import_structure("src/transformers/__init__.py")
|
||||||
with open(os.path.join(tmpdirname, "temp_init.py"), "w") as fp:
|
for mapping in new_imported_modules_from_import_structure.values():
|
||||||
fp.write("".join(new_lines))
|
for _module, _imports in mapping.items():
|
||||||
|
for _import in _imports:
|
||||||
|
reversed_structure[_import] = _module
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("temp_init", os.path.join(tmpdirname, "temp_init.py"))
|
# Get the corresponding modeling file path
|
||||||
module = importlib.util.module_from_spec(spec)
|
for model_class in impacted_model_classes:
|
||||||
spec.loader.exec_module(module)
|
module = reversed_structure[model_class]
|
||||||
# Finally, get `_import_structure` that we need
|
framework = ""
|
||||||
import_structure = module._import_structure
|
if model_class.startswith("TF"):
|
||||||
|
framework = "tf"
|
||||||
# map model classes to their defined module
|
elif model_class.startswith("Flax"):
|
||||||
reversed_structure = {}
|
framework = "flax"
|
||||||
for key, values in import_structure.items():
|
fn = (
|
||||||
for value in values:
|
f"modeling_{module.split('.')[-1]}.py"
|
||||||
reversed_structure[value] = key
|
if framework == ""
|
||||||
|
else f"modeling_{framework}_{module.split('.')[-1]}.py"
|
||||||
# Get the corresponding modeling file path
|
)
|
||||||
for model_class in impacted_model_classes:
|
files.add(f"src.transformers.{module}.{fn}".replace(".", os.path.sep).replace(f"{os.path.sep}py", ".py"))
|
||||||
module = reversed_structure[model_class]
|
|
||||||
framework = ""
|
|
||||||
if model_class.startswith("TF"):
|
|
||||||
framework = "tf"
|
|
||||||
elif model_class.startswith("Flax"):
|
|
||||||
framework = "flax"
|
|
||||||
fn = (
|
|
||||||
f"modeling_{module.split('.')[-1]}.py"
|
|
||||||
if framework == ""
|
|
||||||
else f"modeling_{framework}_{module.split('.')[-1]}.py"
|
|
||||||
)
|
|
||||||
files.add(
|
|
||||||
f"src.transformers.{module}.{fn}".replace(".", os.path.sep).replace(f"{os.path.sep}py", ".py")
|
|
||||||
)
|
|
||||||
|
|
||||||
return sorted(files)
|
return sorted(files)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user