mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
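Fix crash in multi-modal models (PaliGemma, Idefics2, LLaVA-NeXT): wrap input_lengths in Seqlen before calling the text model, and pass adapter_data through in Idefics2 and LLaVA-NeXT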
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
da82c63a4f
commit
d8e3a27648
@ -24,6 +24,7 @@ from text_generation_server.models.custom_modeling.vlm import (
|
|||||||
load_text_model,
|
load_text_model,
|
||||||
load_vision_model,
|
load_vision_model,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaForConditionalGeneration(nn.Module):
|
class PaliGemmaForConditionalGeneration(nn.Module):
|
||||||
@ -92,7 +93,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
# insert image features into input embeddings
|
# insert image features into input embeddings
|
||||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
hidden_states = self.text_model.model(
|
hidden_states = self.text_model.model(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -35,6 +35,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||||
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
@ -824,7 +825,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
input_ids, inputs_embeds, image_hidden_states
|
input_ids, inputs_embeds, image_hidden_states
|
||||||
)
|
)
|
||||||
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
hidden_states = self.text_model.model(
|
hidden_states = self.text_model.model(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@ -836,6 +837,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
true_max_s=max_s,
|
true_max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -31,6 +31,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
|
||||||
|
|
||||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||||
@ -268,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
input_ids, inputs_embeds, image_features
|
input_ids, inputs_embeds, image_features
|
||||||
)
|
)
|
||||||
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
hidden_states = self.text_model.model(
|
hidden_states = self.text_model.model(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@ -280,6 +281,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
true_max_s=max_s,
|
true_max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
Loading…
Reference in New Issue
Block a user