mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix crash in multi-modal
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
da82c63a4f
commit
d8e3a27648
@ -24,6 +24,7 @@ from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
|
||||
class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
@ -92,7 +93,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
|
||||
# insert image features into input embeddings
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
|
@ -35,6 +35,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
@ -824,7 +825,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
@ -836,6 +837,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -31,6 +31,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
@ -268,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
@ -280,6 +281,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
Loading…
Reference in New Issue
Block a user