
Adding imagebind #30690

Open
wants to merge 176 commits into base: main

Conversation

Contributor

@EduardoPach EduardoPach commented May 7, 2024

What does this PR do?

This PR fixes #23240 by adding the ImageBind model.

This is based on #26310, which is currently stale; its author said they would not have time to work on it (though you're welcome to help, @dg845).

Taking into consideration the points raised by @dg845 in #26310 (comment), I'll focus on adding the text/image/audio portions and will try to contact the authors.

Who can review?

@amyeroberts (?)

…MU) and update config classes for text and image modalities.
…ImageBind follows Audio Spectrogram Transformer audio processing).
…s/image processors to ImageBind's __init__.py file.
Contributor

RUFFY-369 commented Oct 9, 2024

Hi @molbap, I have addressed all the final-pass checks as well and left one or two questions. Please have a look when you get the time.

Final pass I think, before passing it to @ArthurZucker (pinging so it's on his radar). Added some comments on the qkv biases, which are nonstandard, and a few other things; overall it looks really good! I did run the slow tests locally, left a comment on one, but all seems in line.

I think if you agree with the changes, then cc @ArthurZucker can take over with the review round.

Thank you

@RUFFY-369
Contributor

All tests are green

Contributor

@ylacombe ylacombe left a comment

Sorry for the delay, I left some comments on the feature extractor!

num_mel_bins=self.num_mel_bins,
)
else:
waveform = np.squeeze(waveform)
Contributor

(nit)
I think you also want to make sure you do it on the right dimension to avoid edge cases (e.g. empty audio), WDYT?

Contributor

Yeah, edge cases are important to handle in case the waveform of the raw speech clip is empty.
Done in the recent commits 👍
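For reference, a minimal sketch of an axis-specific squeeze (variable names here are illustrative, not the PR's exact code):

import numpy as np

# Hypothetical mono clip stored with a leading channel axis, e.g. shape (1, num_samples).
raw_speech = np.zeros((1, 16000), dtype=np.float32)

# Squeezing only the leading axis keeps the time dimension intact even for
# degenerate inputs (e.g. an empty or single-sample clip), unlike a bare np.squeeze.
waveform = np.squeeze(raw_speech, axis=0)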

feature_extractor.max_length,
)
self.assertEqual(input_values.shape, expected_shape)
self.assertTrue(torch.allclose(input_values[:, :, 0, 0, 0], expected_input, atol=1e-4))
Contributor

Let's do another test like this, but focusing on a different segment of the expected output

Contributor

Sure, made it more robust in the recent commits 👍
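As a rough sketch, a second check on a different slice could look like this (the shapes and expected values below are placeholders, not the PR's real reference data):

import torch

# Placeholders standing in for the feature extractor output and a hand-verified
# reference segment taken from a different region of the spectrogram.
input_values = torch.zeros(1, 3, 1, 128, 204)
expected_segment = torch.zeros(128)

# Checking a slice away from the origin also catches padding/truncation bugs
# that only show up in later frames.
assert torch.allclose(input_values[0, 0, 0, :, -1], expected_segment, atol=1e-4)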

Contributor

Great!

remove_dc_offset=True,
).T

fbank = torch.from_numpy(fbank)
Contributor

Since the model is using torch already, I don't mind keeping this dependent on torch

Contributor

But we could probably do the rest of the operations in numpy, right?

Contributor

@RUFFY-369 RUFFY-369 Oct 11, 2024

@ylacombe I tried doing the rest of the operations in numpy like this:

but keeping it in numpy during padding or truncation results in inconsistent shapes and fails the test.

Contributor

@RUFFY-369 RUFFY-369 Oct 11, 2024

Update: I had to overwrite the previous comment, as the changes were wrong because numpy operations were being performed on torch tensors.
(This comment can be ignored.)

Contributor

@RUFFY-369 RUFFY-369 Oct 11, 2024

But we could probably do the rest of the operations in numpy, right?

@ylacombe When is_speech_available() is True, ta_kaldi.fbank is used to perform the operation on the numpy array converted into a torch tensor, as you mentioned in the next review comment (the numpy-to-torch conversion should happen around ta_kaldi.fbank), so the output is a torch tensor. That's why, at the line 267 you referred to, a numpy-to-torch conversion also has to be done when is_speech_available is False: outside the if/else block, the padding operations are done with torch, because when is_speech_available is True we already have a torch tensor to proceed with.

return result


class ImageBindFeatureExtractor(SequenceFeatureExtractor):
Contributor

This looks good to me. However, there might be too many back-and-forths from numpy arrays to torch tensors.

IMO you should do the numpy->torch.tensor and torch.tensor->numpy conversions only once, i.e. when torchaudio is available and you use ta_kaldi.fbank.
Every other operation should be left as a numpy array IMO. That way, you benefit from the torchaudio fbank speedups while keeping the dependency on torch minimal.
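Concretely, the pattern being suggested would look roughly like this (function and argument names are illustrative; ta_kaldi is torchaudio.compliance.kaldi):

import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi

def extract_fbank(waveform: np.ndarray, sampling_rate: int = 16000, num_mel_bins: int = 128) -> np.ndarray:
    # Convert to torch only for the fbank computation, then go straight back to
    # numpy so padding, truncation and normalization can all stay in numpy.
    fbank = ta_kaldi.fbank(
        torch.from_numpy(waveform).unsqueeze(0),
        sample_frequency=sampling_rate,
        num_mel_bins=num_mel_bins,
    )
    return fbank.numpy()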

Contributor

@ylacombe With reference to the reply above, the numpy->torch.tensor and torch.tensor->numpy conversion is only done to create the spectrogram at this one place in the code file, and it is necessary because the operations that follow are done on a torch tensor: when is_speech_available is True, ta_kaldi.fbank() outputs a torch tensor as well. So, to keep both paths consistent, the spectrogram created from the numpy array has to be converted into a torch tensor.

Collaborator

@ArthurZucker ArthurZucker left a comment

Huge work @RUFFY-369 and @EduardoPach congrats! 🤗
Some small nits here and there but overall good for me!

Comment on lines 36 to 125
def rename_encoder_layers(config, modality):
rename_keys = []
# fmt: off
for layer_idx in range(config.num_hidden_layers):
rename_keys.extend(
[
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.in_proj_weight",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.qkv_proj.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.in_proj_bias",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.qkv_proj.bias"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.out_proj.weight",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.out_proj.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.out_proj.bias",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.out_proj.bias"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.norm_1.weight",f"{modality}_model.encoder.layers.{layer_idx}.layernorm_before.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.norm_1.bias",f"{modality}_model.encoder.layers.{layer_idx}.layernorm_before.bias"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc1.weight",f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc1.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc1.bias",f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc1.bias"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc2.weight",f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc2.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.mlp.fc2.bias",f"{modality}_model.encoder.layers.{layer_idx}.mlp.fc2.bias"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.norm_2.weight",f"{modality}_model.encoder.layers.{layer_idx}.layernorm_after.weight"),
(f"modality_trunks.{modality}.blocks.{layer_idx}.norm_2.bias",f"{modality}_model.encoder.layers.{layer_idx}.layernorm_after.bias"),
]
)
if config.add_kv_bias:
rename_keys.extend(
[
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_k",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.k_bias",),
(f"modality_trunks.{modality}.blocks.{layer_idx}.attn.bias_v",f"{modality}_model.encoder.layers.{layer_idx}.self_attn.v_bias",),
]
)
# fmt: on

return rename_keys


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
vision_config = config.vision_config
text_config = config.text_config
audio_config = config.audio_config

rename_keys = []

# fmt: off

# Convert Vision
rename_keys.extend([
("modality_preprocessors.vision.cls_token", "vision_model.embeddings.cls_token"),
("modality_preprocessors.vision.rgbt_stem.proj.1.weight", "vision_model.embeddings.patch_embedding.projection.weight"),
("modality_preprocessors.vision.pos_embedding_helper.pos_embed", "vision_model.embeddings.position_embeddings"),
("modality_heads.vision.0.weight", "vision_model.layernorm.weight"),
("modality_heads.vision.0.bias", "vision_model.layernorm.bias"),
("modality_heads.vision.2.weight", "vision_projection.weight"),
("modality_trunks.vision.pre_transformer_layer.0.weight", "vision_model.pre_layernorm.weight"),
("modality_trunks.vision.pre_transformer_layer.0.bias", "vision_model.pre_layernorm.bias"),
])

rename_keys.extend(
rename_encoder_layers(vision_config, "vision")
)

# Convert Text
rename_keys.extend([
("modality_preprocessors.text.pos_embed", "text_model.embeddings.position_embedding.weight"),
("modality_preprocessors.text.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
("modality_heads.text.proj.0.weight", "text_model.layernorm.weight"),
("modality_heads.text.proj.0.bias", "text_model.layernorm.bias"),
("modality_heads.text.proj.1.weight", "text_projection.weight"),
("modality_postprocessors.text.1.log_logit_scale", "text_postprocessor.log_logit_scale"),
])

rename_keys.extend(
rename_encoder_layers(text_config, "text")
)

# Convert Audio
rename_keys.extend([
("modality_preprocessors.audio.cls_token", "audio_model.embeddings.cls_token"),
("modality_preprocessors.audio.rgbt_stem.proj.weight", "audio_model.embeddings.patch_embedding.projection.weight"),
("modality_preprocessors.audio.rgbt_stem.norm_layer.weight", "audio_model.embeddings.patch_embedding.layernorm.weight"),
("modality_preprocessors.audio.rgbt_stem.norm_layer.bias", "audio_model.embeddings.patch_embedding.layernorm.bias"),
("modality_preprocessors.audio.pos_embedding_helper.pos_embed", "audio_model.embeddings.position_embeddings"),
("modality_heads.audio.0.weight", "audio_model.layernorm.weight"),
("modality_heads.audio.0.bias", "audio_model.layernorm.bias"),
("modality_heads.audio.2.weight", "audio_projection.weight"),
])

rename_keys.extend(
rename_encoder_layers(audio_config, "audio")
)
# fmt: on

return rename_keys
Collaborator

I don't want to be a pain 😅 but a lot of this can be simplified with regexes! Would love to see something simple like what we have in mllama!

Contributor

Made the changes using regex (the re module) in the recent commits 👍. Please check and mention if any further changes are necessary.
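For reference, a regex-driven mapping in the spirit of mllama's conversion script might look like this (the patterns below are illustrative, not necessarily the exact ones used in the final script):

import re

# Each entry maps an original-checkpoint key pattern to its transformers-style name.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.attn\.in_proj_(weight|bias)": r"\1_model.encoder.layers.\2.self_attn.qkv_proj.\3",
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.norm_1\.(weight|bias)": r"\1_model.encoder.layers.\2.layernorm_before.\3",
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.norm_2\.(weight|bias)": r"\1_model.encoder.layers.\2.layernorm_after.\3",
    r"modality_trunks\.(\w+)\.blocks\.(\d+)\.mlp\.fc(\d)\.(weight|bias)": r"\1_model.encoder.layers.\2.mlp.fc\3.\4",
}

def convert_key(old_key: str) -> str:
    # Apply the first pattern that matches; anything unmatched keeps its name.
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        new_key, num_subs = re.subn(pattern, replacement, old_key)
        if num_subs:
            return new_key
    return old_key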

Comment on lines 141 to 152
def prepare_input():
    ds = load_dataset("EduardoPacheco/imagebind-example-data", split="train")
    images = ds["image"]
    texts = ds["text"]
    audios = [
        torchaudio.functional.resample(
            torch.from_numpy(audio["array"]), orig_freq=audio["sampling_rate"], new_freq=16000
        ).numpy()
        for audio in ds["audio"]
    ]

    return images, texts, audios
Collaborator

Looks like this is unused, no?

Contributor

Yes, it was used for testing the converted model weights, but when @molbap reviewed the files he suggested moving the assertion tests. They have been moved, so this func can be omitted 👍

Status: Done in the recent commits

Comment on lines +713 to +722
if self.scale_logits:
self.logit_scale_init = config.logit_scale_init_value
self.max_logit_scale = max_logit_scale
self.learnable = config.learnable_logit_scale

log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
if self.learnable:
self.log_logit_scale = nn.Parameter(log_logit_scale)
else:
self.register_buffer("log_logit_scale", log_logit_scale)
Collaborator

Same comment about code paths; that is something we try to avoid.
Is this used in released checkpoints?

Contributor

It is used here, and likewise in the original modeling file.
It's also used in the released checkpoints, e.g. this param: 'modality_postprocessors.audio.1.log_logit_scale'
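For context, the post-processor applies that scale roughly like this (a sketch modeled on the original ImageBind module; the exact forward in the PR may differ):

import numpy as np
import torch
from torch import nn

class LearnableLogitScaling(nn.Module):
    def __init__(self, logit_scale_init: float = 20.0, max_logit_scale: float = 100.0, learnable: bool = True):
        super().__init__()
        self.max_logit_scale = max_logit_scale
        log_logit_scale = torch.ones([]) * np.log(logit_scale_init)
        if learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Scale the normalized embeddings, capping the effective scale.
        return torch.clamp(self.log_logit_scale.exp(), max=self.max_logit_scale) * embeddings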

Comment on lines 1092 to 1104
def _build_attention_mask(self, attention_mask, batch_size, seq_len, dtype, device=None):
    # Build causal mask
    mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype, device=device)
    mask.fill_(torch.finfo(dtype).min)
    mask.triu_(1)
    mask = mask.unsqueeze(1)  # expand mask

    # If attention_mask update causal mask
    if attention_mask is not None:
        attention_mask = AttentionMaskConverter._expand_mask(attention_mask, dtype)
        return mask + attention_mask
    return mask

Collaborator

AttentionMaskConverter is there to hide the SDPA masking logic, but we have kind of deprecated it otherwise; specifically, it cannot expand into a static cache.
Using _update_causal_mask / a simplified version of it would be better!

Contributor

@RUFFY-369 RUFFY-369 Oct 17, 2024

Are you talking about this _update_causal_mask? Because AttentionMaskConverter is also used here?!
In any case, I have pushed a simplified version in the recent commits.
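For reference, a minimal stand-in for the _expand_mask call, so the converter dependency can be dropped (a sketch; names and shapes assumed):

import torch

def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Turn a (batch, seq_len) padding mask of 1s/0s into an additive
    # (batch, 1, 1, seq_len) mask of 0 / large-negative values,
    # ready to be added to the causal mask.
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted * torch.finfo(dtype).min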

Comment on lines +43 to +49
class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
    do_normalize: Optional[bool]
    mean: Optional[float]
    std: Optional[float]
    do_chunk: Optional[bool]
    chunk_duration: Optional[float]
    num_chunks: Optional[int]
Collaborator

kinda wondering why we don't use them in the FeatureExtractor as well!

Contributor

We could; the main appeal of standardizing the processor was pipeline + API use, but there's no harm in doing it at all call levels!

Contributor

@molbap @ArthurZucker If this change is to be done, should we include it in this PR?

Collaborator

Yep we can!
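A rough sketch of what reusing the typed kwargs at the feature-extractor level could look like (class and method names here are assumptions, not the PR's final API):

from typing import Optional

from typing_extensions import TypedDict, Unpack

class ImageBindAudioCallKwargs(TypedDict, total=False):
    # Mirrors the processor-level audio kwargs so both entry points share one schema.
    do_normalize: Optional[bool]
    mean: Optional[float]
    std: Optional[float]
    do_chunk: Optional[bool]
    chunk_duration: Optional[float]
    num_chunks: Optional[int]

class ImageBindFeatureExtractorSketch:
    def __call__(self, raw_speech, **kwargs: Unpack[ImageBindAudioCallKwargs]):
        # The typed dict documents (and type-checks) the accepted audio options
        # directly on the feature extractor call, not only on the processor.
        do_chunk = kwargs.get("do_chunk", True)
        num_chunks = kwargs.get("num_chunks", 3)
        ...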

tests/models/imagebind/test_processor_imagebind.py (outdated, resolved)
@ArthurZucker ArthurZucker self-requested a review October 22, 2024 15:08
@ArthurZucker
Collaborator

Super sorry @RUFFY-369, we went on a company-wide offsite for a week, getting back to it now 🤗

Collaborator

@ArthurZucker ArthurZucker left a comment

Well done! It's a really big piece of work and lots of parts are involved, but you pushed through!

Just added some nits / updates that we made on main, but approving as it should be fairly easy to fix

Comment on lines +146 to +160
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the text config dict if we are loading from ImageBindConfig
if config_dict.get("model_type") == "imagebind":
config_dict = config_dict["text_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)
Collaborator

this is no longer required! 🔥

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

Collaborator

Suggested change (addition):
    base_config_key = "text_config"
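In other words, with the newer sub-config loading mechanism the class body only needs something like this (a sketch; the model_type value is assumed to mirror the vision config shown below):

from transformers import PretrainedConfig

class ImageBindTextConfig(PretrainedConfig):
    model_type = "imagebind_text_model"
    # PretrainedConfig.from_pretrained uses this key to pick the nested dict out of a
    # composite ImageBindConfig, replacing the hand-written from_pretrained override.
    base_config_key = "text_config"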

Comment on lines +235 to +236
model_type = "imagebind_vision_model"

Collaborator

Suggested change:
    model_type = "imagebind_vision_model"
    base_config_key = "vision_config"

Comment on lines +282 to +296
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the vision config dict if we are loading from ImageBindConfig
if config_dict.get("model_type") == "imagebind":
config_dict = config_dict["vision_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)
Collaborator

Suggested change: remove this from_pretrained override entirely.

Comment on lines +419 to +434
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# get the audio config dict if we are loading from ImageBindConfig
if config_dict.get("model_type") == "imagebind":
config_dict = config_dict["audio_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)

Collaborator

Suggested change: remove this from_pretrained override entirely.

Comment on lines +979 to +989
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(encoder_layer),
    hidden_states,
    attention_mask,
)
Collaborator

this can be a lot simpler!

Suggested change (replace the create_custom_forward wrapper with a direct call to the checkpointing hook):
    layer_outputs = self._gradient_checkpointing_func(
        encoder_layer.__call__,
        hidden_states,
        attention_mask,
    )

_gradient_checkpointing_func is defined in the superclass.
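A toy sketch of how that hook is used inside an encoder forward (here the checkpointing function is emulated directly with torch.utils.checkpoint; in transformers it is installed by PreTrainedModel):

import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.gradient_checkpointing = True
        # Stand-in for the hook PreTrainedModel installs on its subclasses.
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Checkpoint the layer call instead of wrapping it in a custom closure.
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states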

# Build causal mask
mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype, device=device)
mask.fill_(torch.finfo(dtype).min)
mask.triu_(1)
Collaborator

In-place operations are usually not so great for accelerators!
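An out-of-place equivalent of the fill_/triu_ pair above, for illustration (shapes chosen arbitrarily):

import torch

batch_size, seq_len, dtype = 2, 8, torch.float32

# Same upper-triangular block of large-negative values, built without mutating a tensor in place.
mask = torch.triu(
    torch.full((batch_size, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype),
    diagonal=1,
)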

"""The text model from ImageBind without any head or projection on top.""",
IMAGEBIND_START_DOCSTRING,
)
class ImageBindTextModel(ImageBindPreTrainedModel):
Collaborator

I don't remember what we said about this class, but it's a pity to have it, as it just wraps around ImageBindTextTransformer; not super useful!

)


@add_start_docstrings(
Collaborator

same comment here!

)


@add_start_docstrings(
Collaborator

same here!


Successfully merging this pull request may close these issues.

[New model] ImageBind: One Embedding Space To Bind Them All
10 participants