mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fixes for VLM.
This commit is contained in:
parent
b2fb845923
commit
43ef5268fd
@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
self.language_model = load_text_model(
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -880,8 +880,11 @@ class FlashCausalLM(Model):
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
text_config = getattr(config, "text_config", None)
|
||||
if text_config is not None:
|
||||
config = text_config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
# Validation is done in the model itself
|
||||
self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
|
||||
self.head_size = config.hidden_size // config.num_attention_heads
|
||||
|
@ -261,7 +261,7 @@ class VlmCausalLM(FlashMistral):
|
||||
**processor_kwargs,
|
||||
)
|
||||
self.batch_class = batch_class
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(model_id=model_id, **kwargs)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
|
Loading…
Reference in New Issue
Block a user