Adding imagebind #30690
base: main
Conversation
…MU) and update config classes for text and image modalities.
…h, thermal, imu).
…ImageBind follows Audio Spectrogram Transformer audio processing).
…uding audio (depth, thermal).
…s/image processors to ImageBind's __init__.py file.
…clipped images) following VideoMAE.
Co-authored-by: Pablo Montalvo <[email protected]>
Hi @molbap, I have addressed all the final-pass checks as well and left one or two questions. Please have a look when you get the time.
I think if you agree with the changes, then cc @ArthurZucker can take over with the review round. Thank you!
All tests are green
Sorry for the delay, I left some comments on the feature extractor!
src/transformers/models/imagebind/feature_extraction_imagebind.py
                num_mel_bins=self.num_mel_bins,
            )
        else:
            waveform = np.squeeze(waveform)
(nit)
I think you also want to make sure you do it on the right dimension to avoid edge cases (empty audio), WDYT?
Yeah, edge cases are important to handle in case the waveform of the raw speech clip is empty.
Done in the recent commits 👍
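A minimal sketch of the axis-specific squeeze being discussed (a standalone numpy example; the variable names and shapes are illustrative assumptions, not the PR's final code):

```python
import numpy as np

# Plain np.squeeze removes *every* size-1 axis, so a one-sample clip of shape
# (1, 1) collapses to a 0-d array and later time-axis operations would break.
waveform = np.zeros((1, 1), dtype=np.float32)
assert np.squeeze(waveform).shape == ()

# Squeezing only the channel axis keeps the time dimension, even in edge cases.
if waveform.ndim > 1 and waveform.shape[0] == 1:
    waveform = np.squeeze(waveform, axis=0)
assert waveform.shape == (1,)
```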
            feature_extractor.max_length,
        )
        self.assertEqual(input_values.shape, expected_shape)
        self.assertTrue(torch.allclose(input_values[:, :, 0, 0, 0], expected_input, atol=1e-4))
Let's do another test like this, but focusing on a different segment of the expected output
Sure, made it more robust in the recent commits 👍
Great!
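For illustration, a hedged sketch of what a second check on a different segment of the output could look like; the shapes and expected values below are made up for the example, not taken from the PR's test:

```python
import torch

# Check a slice from the end of the features as well, so regressions that only
# affect the tail (e.g. in padding or truncation) are also caught.
input_values = torch.arange(2 * 3 * 1 * 4 * 5, dtype=torch.float32).reshape(2, 3, 1, 4, 5)
expected_tail = torch.tensor([[55.0, 56.0, 57.0, 58.0, 59.0],
                              [115.0, 116.0, 117.0, 118.0, 119.0]])
assert torch.allclose(input_values[:, -1, 0, -1, :], expected_tail, atol=1e-4)
```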
            remove_dc_offset=True,
        ).T

        fbank = torch.from_numpy(fbank)
Since the model is using torch already, I don't mind keeping this dependent on torch
But we could probably do the rest of the operations in numpy, right?
@ylacombe I tried doing the rest of the operations in numpy like this:
but keeping it in numpy during padding or truncation results in inconsistent shapes and fails the test
Update: I had to overwrite the previous comment, as the changes were wrong because numpy operations were being performed on torch tensors.
(Comment can be ignored)
But we could probably do the rest of the operations in numpy, right?

@ylacombe When is_speech_available() is True, ta_kaldi.fbank is used to perform the operation on the numpy array converted into a torch tensor, and, as you mentioned in the next review comment, the numpy-to-torch conversion should happen at the ta_kaldi.fbank call. So the output is a torch tensor, and that's why at line 267, which you referred to, the numpy-to-torch conversion has to be done when is_speech_available is False: outside the if-else block the padding operations are done with torch, because when is_speech_available is True we already have a torch tensor to proceed with.
    return result


class ImageBindFeatureExtractor(SequenceFeatureExtractor):
This looks good to me. However, there might be too much back and forth between numpy arrays and torch tensors.
IMO you should do the numpy -> torch.tensor and torch.tensor -> numpy conversions only once, i.e. when torchaudio is available and you use ta_kaldi.fbank.
Every other operation should be left as a numpy array IMO. That way, you benefit from the torchaudio fbank speedups while keeping the dependency on torch minimal.
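A minimal sketch of the convert-once pattern being suggested here; the function name and defaults are assumptions, while ta_kaldi.fbank is the actual torchaudio API:

```python
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi


def extract_fbank(waveform: np.ndarray, sampling_rate: int = 16000, num_mel_bins: int = 128) -> np.ndarray:
    # Convert to torch only for the fbank computation, then go straight back to
    # numpy so every other step (padding, chunking, normalization) stays in numpy.
    fbank = ta_kaldi.fbank(
        torch.from_numpy(waveform).unsqueeze(0),
        sample_frequency=sampling_rate,
        num_mel_bins=num_mel_bins,
    )
    return fbank.numpy()
```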
@ylacombe With reference to the reply above, the numpy -> torch.tensor and torch.tensor -> numpy conversion is only done to create the spectrogram at this one place in the file, and it is necessary because the downstream operations run on torch tensors: when is_speech_available is True, ta_kaldi.fbank() outputs a torch tensor as well. So, to match that, the spectrogram created from the numpy array has to be converted into a torch tensor.
Huge work @RUFFY-369 and @EduardoPach congrats! 🤗
Some small nits here and there but overall good for me!
def rename_encoder_layers(config, modality):
    rename_keys = []
    # fmt: off
    for layer_idx in range(config.num_hidden_layers):
        rename_keys.extend(
            [
                (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.in_proj_weight", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.qkv_proj.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.in_proj_bias", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.qkv_proj.bias"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.out_proj.weight", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.out_proj.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.out_proj.bias", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.out_proj.bias"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.norm_1.weight", f"{modality}_model.encoder.layers.{layer_idx}.layernorm_before.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.norm_1.bias", f"{modality}_model.encoder.layers.{layer_idx}.layernorm_before.bias"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc1.weight", f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc1.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc1.bias", f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc1.bias"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc2.weight", f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc2.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc2.bias", f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc2.bias"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.norm_2.weight", f"{modality}_model.encoder.layers.{layer_idx}.layernorm_after.weight"),
                (f"modality_trunks.{modality}.blocks.{layer_idx}.norm_2.bias", f"{modality}_model.encoder.layers.{layer_idx}.layernorm_after.bias"),
            ]
        )
        if config.add_kv_bias:
            rename_keys.extend(
                [
                    (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_k", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.k_bias"),
                    (f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_v", f"{modality}_model.encoder.layers.{layer_idx}.self_attn.v_bias"),
                ]
            )
    # fmt: on

    return rename_keys


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
    vision_config = config.vision_config
    text_config = config.text_config
    audio_config = config.audio_config

    rename_keys = []

    # fmt: off

    # Convert Vision
    rename_keys.extend([
        ("modality_preprocessors.vision.cls_token", "vision_model.embeddings.cls_token"),
        ("modality_preprocessors.vision.rgbt_stem.proj.1.weight", "vision_model.embeddings.patch_embedding.projection.weight"),
        ("modality_preprocessors.vision.pos_embedding_helper.pos_embed", "vision_model.embeddings.position_embeddings"),
        ("modality_heads.vision.0.weight", "vision_model.layernorm.weight"),
        ("modality_heads.vision.0.bias", "vision_model.layernorm.bias"),
        ("modality_heads.vision.2.weight", "vision_projection.weight"),
        ("modality_trunks.vision.pre_transformer_layer.0.weight", "vision_model.pre_layernorm.weight"),
        ("modality_trunks.vision.pre_transformer_layer.0.bias", "vision_model.pre_layernorm.bias"),
    ])

    rename_keys.extend(
        rename_encoder_layers(vision_config, "vision")
    )

    # Convert Text
    rename_keys.extend([
        ("modality_preprocessors.text.pos_embed", "text_model.embeddings.position_embedding.weight"),
        ("modality_preprocessors.text.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
        ("modality_heads.text.proj.0.weight", "text_model.layernorm.weight"),
        ("modality_heads.text.proj.0.bias", "text_model.layernorm.bias"),
        ("modality_heads.text.proj.1.weight", "text_projection.weight"),
        ("modality_postprocessors.text.1.log_logit_scale", "text_postprocessor.log_logit_scale"),
    ])

    rename_keys.extend(
        rename_encoder_layers(text_config, "text")
    )

    # Convert Audio
    rename_keys.extend([
        ("modality_preprocessors.audio.cls_token", "audio_model.embeddings.cls_token"),
        ("modality_preprocessors.audio.rgbt_stem.proj.weight", "audio_model.embeddings.patch_embedding.projection.weight"),
        ("modality_preprocessors.audio.rgbt_stem.norm_layer.weight", "audio_model.embeddings.patch_embedding.layernorm.weight"),
        ("modality_preprocessors.audio.rgbt_stem.norm_layer.bias", "audio_model.embeddings.patch_embedding.layernorm.bias"),
        ("modality_preprocessors.audio.pos_embedding_helper.pos_embed", "audio_model.embeddings.position_embeddings"),
        ("modality_heads.audio.0.weight", "audio_model.layernorm.weight"),
        ("modality_heads.audio.0.bias", "audio_model.layernorm.bias"),
        ("modality_heads.audio.2.weight", "audio_projection.weight"),
    ])

    rename_keys.extend(
        rename_encoder_layers(audio_config, "audio")
    )
    # fmt: on

    return rename_keys
I don't want to be a pain 😅 but a lot of this can be simplified with regexes! Would love to see something simple like what we have in mllama!
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
Made the changes using regexes (re) in the recent commits 👍. Please check and mention if any changes are necessary.
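A hedged sketch of the regex-based mapping style referred to here; the patterns below are illustrative examples, not the full table used in the PR:

```python
import re

# One table maps original checkpoint key patterns to converted names, in the
# same spirit as mllama's ORIGINAL_TO_CONVERTED_KEY_MAPPING.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.attn\.in_proj_(weight|bias)":
        r"\1_model.encoder.layers.\2.self_attn.qkv_proj.\3",
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.norm_1\.(weight|bias)":
        r"\1_model.encoder.layers.\2.layernorm_before.\3",
}


def convert_key(old_key: str) -> str:
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        new_key, n_subs = re.subn(pattern, replacement, old_key)
        if n_subs:
            return new_key
    return old_key


print(convert_key("modality_trunks.vision.blocks.0.attn.in_proj_weight"))
# -> vision_model.encoder.layers.0.self_attn.qkv_proj.weight
```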
def prepare_input():
    ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train")
    images = ds["image"]
    texts = ds["text"]
    audios = [
        torchaudio.functional.resample(
            torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000
        ).numpy()
        for audio in ds["audio"]
    ]

    return images, texts, audios
Looks like this is unused, no?
Yes, it was used for testing the converted model weights, but when @molbap reviewed the files he suggested moving the assertion tests. They have been moved, so this func can be omitted 👍
Status: Done in the recent commits
src/transformers/models/imagebind/image_processing_imagebind.py
        if self.scale_logits:
            self.logit_scale_init = config.logit_scale_init_value
            self.max_logit_scale = max_logit_scale
            self.learnable = config.learnable_logit_scale

            log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
            if self.learnable:
                self.log_logit_scale = nn.Parameter(log_logit_scale)
            else:
                self.register_buffer("log_logit_scale", log_logit_scale)
Same comment about code paths; that is something we try to avoid.
Is this used in released checkpoints?
It is used here and in the original modeling file.
It's also used in the released checkpoints, e.g. for this param: 'modality_postprocessors.audio.1.log_logit_scale'
    def _build_attention_mask(self, attention_mask, batch_size, seq_len, dtype, device=None):
        # Build causal mask
        mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype, device=device)
        mask.fill_(torch.finfo(dtype).min)
        mask.triu_(1)
        mask = mask.unsqueeze(1)  # expand mask

        # If attention_mask update causal mask
        if attention_mask is not None:
            attention_mask = AttentionMaskConverter._expand_mask(attention_mask, dtype)
            return mask + attention_mask
        return mask
AttentionMaskConverter is here to hide the sdpa masking logic, but we have more or less deprecated it otherwise; specifically, it cannot expand into a static cache.
Using _update_causal_mask / a simplified version would be better!
Are you talking about this _update_causal_mask? Because AttentionMaskConverter is also used here?!
Anyway, I have pushed a simplified version in the recent commits.
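A hedged sketch of what a simplified builder (no AttentionMaskConverter, no in-place ops) could look like; the function name and signature are assumptions, not the code that was actually pushed:

```python
import torch


def build_causal_attention_mask(attention_mask, batch_size, seq_len, dtype, device=None):
    min_value = torch.finfo(dtype).min
    # Causal part: large negative values strictly above the diagonal, zeros elsewhere.
    causal = torch.triu(
        torch.full((seq_len, seq_len), min_value, dtype=dtype, device=device), diagonal=1
    )
    mask = causal[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)

    if attention_mask is not None:
        # Expand the (batch, seq_len) padding mask and merge it into the causal mask,
        # clamping so overlapping masked positions don't overflow to -inf.
        padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_value
        mask = (mask + padding).clamp(min=min_value)
    return mask
```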
class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
    do_normalize: Optional[bool]
    mean: Optional[float]
    std: Optional[float]
    do_chunk: Optional[bool]
    chunk_duration: Optional[float]
    num_chunks: Optional[int]
kinda wondering why we don't use them in the FeatureExtractor as well!
We could! The main appeal for standardizing the processor was pipeline + API use, but there's no harm in doing it at all call levels!
@molbap @ArthurZucker If this change is to be made, should we include it in this PR?
Yep we can!
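A minimal sketch of how the same typed kwargs could also be consumed at the feature-extractor call level; the classes below are simplified stand-ins rather than the actual transformers base classes:

```python
from typing import Optional

from typing_extensions import TypedDict, Unpack


class ImageBindAudioKwargs(TypedDict, total=False):
    # Simplified stand-in for the processor's audio kwargs.
    do_normalize: Optional[bool]
    do_chunk: Optional[bool]
    chunk_duration: Optional[float]
    num_chunks: Optional[int]


class DummyFeatureExtractor:
    def __call__(self, raw_speech, **kwargs: Unpack[ImageBindAudioKwargs]):
        # The TypedDict documents (and lets type checkers verify) the accepted
        # keyword arguments at this call level too, not only in the processor.
        do_chunk = kwargs.get("do_chunk", True)
        num_chunks = kwargs.get("num_chunks", 3)
        return {"do_chunk": do_chunk, "num_chunks": num_chunks}
```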
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Super sorry @RUFFY-369, we went on a company-wide offsite for a week; getting back to it now 🤗
Well done! It's a super big piece of work, lots of parts are involved, but you pushed through!
Just added some nits / updates that we made on main, but approving as it should be fairly easy to fix.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the text config dict if we are loading from ImageBindConfig
        if config_dict.get("model_type") == "imagebind":
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
this is no longer required! 🔥
        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```"""
Suggested change:

    base_config_key = "text_config"
    model_type = "imagebind_vision_model"
Suggested change:

    model_type = "imagebind_vision_model"
    base_config_key = "vision_config"
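A hedged sketch of what the base_config_key mechanism buys here: in recent transformers versions the sub-config class only needs these two class attributes, and from_pretrained resolves the nested config dict itself, which is why the per-modality from_pretrained overrides become unnecessary (illustrative, not the merged code):

```python
from transformers import PretrainedConfig


class ImageBindVisionConfig(PretrainedConfig):
    model_type = "imagebind_vision_model"
    # Tells PretrainedConfig.from_pretrained which nested dict to pick when the
    # checkpoint ships a composite ImageBindConfig.
    base_config_key = "vision_config"
```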
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from ImageBindConfig
        if config_dict.get("model_type") == "imagebind":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
Suggested change: remove this from_pretrained override entirely; it is no longer required.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the audio config dict if we are loading from ImageBindConfig
        if config_dict.get("model_type") == "imagebind":
            config_dict = config_dict["audio_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
Suggested change: remove this from_pretrained override entirely; it is no longer required.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
this can be a lot simpler!
Suggested change:

                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                )
_gradient_checkpointing_func is defined in the superclass.
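A small standalone illustration (outside transformers) of why the closure wrapper is not needed: torch's checkpoint already forwards extra positional arguments to the wrapped callable, which is the pattern _gradient_checkpointing_func builds on:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
hidden_states = torch.randn(2, 4, requires_grad=True)

# No create_custom_forward closure: pass the module's __call__ and its
# arguments directly to checkpoint().
layer_outputs = checkpoint(layer.__call__, hidden_states, use_reentrant=False)
layer_outputs.sum().backward()
```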
        # Build causal mask
        mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype, device=device)
        mask.fill_(torch.finfo(dtype).min)
        mask.triu_(1)
inplace operations are usually not so great for accelerators!
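A one-line sketch of a non-in-place equivalent of the fill_/triu_ pair above (shapes are illustrative):

```python
import torch

dtype, seq_len = torch.float32, 8
# Functional construction instead of mutating a pre-allocated buffer in place.
mask = torch.triu(torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype), diagonal=1)
```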
"""The text model from ImageBind without any head or projection on top.""", | ||
IMAGEBIND_START_DOCSTRING, | ||
) | ||
class ImageBindTextModel(ImageBindPreTrainedModel): |
I don't remember what we said about this class, but it's a pity to have it, as it just wraps around ImageBindTextTransformer; not super useful!
)


@add_start_docstrings(
same comment here!
)


@add_start_docstrings(
same here!
What does this PR do?
This PR fixes #23240 by adding the ImageBind model. It is based on #26310, which is currently stale; the author said they would not have time to work on it (though they are welcome to help, @dg845).
Taking into consideration the points raised by @dg845 in #26310 (comment), I'll focus on adding the text/image/audio portion and try to contact the authors.
Who can review?
@amyeroberts (?)