Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-28 21:42:06 +00:00
Fixing seqlen with the new vlms.
commit e0069a3a26
parent 9dacac3b15
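The diff below replaces the bare `input_lengths` tensor in the VLM forward paths with a single `Seqlen` object imported from `text_generation_server.layers.attention`. For orientation only, here is a minimal sketch of what such a container could look like, with field names taken from the constructor call in the `VlmCausalLM` hunk at the end of this commit; the real class may carry more fields and different defaults.

# Hedged sketch of the Seqlen container assumed by this commit; the actual
# implementation in text_generation_server.layers.attention may differ.
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class Seqlen:
    # Per-request number of new (query) tokens in the batch.
    input_lengths: torch.Tensor
    # Per-request number of tokens already cached as a shared prefix.
    prefix_lengths: torch.Tensor
    # Cumulative query sequence lengths, only set during prefill.
    cu_seqlen_q: Optional[torch.Tensor] = None
    # Longest query length and longest key length (new + cached tokens).
    max_q: int = 0
    max_k: int = 0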
@@ -19,6 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple
 
 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
         )
 
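The change to `PaliGemmaForConditionalGeneration.forward` is mechanical: the `input_lengths: torch.Tensor` parameter becomes `seqlen: Seqlen`, and the value is forwarded unchanged to the wrapped text model. The Idefics2 and LlavaNext hunks below apply the same pattern. As a rough illustration of the new calling convention (reusing the hypothetical Seqlen sketch above; the tensors and cumulative lengths are invented for the example):

# Illustrative only; values are made up for a small prefill batch.
import torch

# Two prefill requests with 4 and 6 new tokens and no cached prefix.
input_lengths = torch.tensor([4, 6], dtype=torch.int32)
prefix_lengths = torch.tensor([0, 0], dtype=torch.int32)
cu_seqlen_prefill = torch.tensor([0, 4, 10], dtype=torch.int32)  # cumulative query lengths
max_s = 6

seqlen = Seqlen(
    input_lengths=input_lengths,
    prefix_lengths=prefix_lengths,
    cu_seqlen_q=cu_seqlen_prefill,
    max_q=max_s,
    max_k=(input_lengths + prefix_lengths).max().item(),
)

# Before this commit the model was called with input_lengths=input_lengths;
# after it, the same call site passes seqlen=seqlen instead:
# model.forward(..., slots=slots, seqlen=seqlen, max_s=max_s, ...)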
@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
+from text_generation_server.layers.attention import Seqlen
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 
 from text_generation_server.layers import (
@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
@@ -23,6 +23,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
 
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
@@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
             prefix_lens=batch.prefix_lens,
             prefix_lens_tensor=prefix_lens_tensor,
         ):
-            input_lengths = Seqlen(input_lengths=input_lengths)
+            max_k = (input_lengths + prefix_lens_tensor).max().item()
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                prefix_lengths=prefix_lens_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+                max_q=max_s,
+                max_k=max_k,
+            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
                 kv_cache=kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
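The substantive change is in this last hunk: with prefix caching, the keys a query attends to include both the new tokens and the cached prefix, so the key-length bound has to come from `input_lengths + prefix_lens_tensor` rather than from the input lengths alone. A small numeric illustration of that line, with invented values:

import torch

# Three requests in a prefill batch: 5, 3 and 8 new tokens each,
# with 0, 16 and 4 tokens already cached as a shared prefix.
input_lengths = torch.tensor([5, 3, 8], dtype=torch.int32)
prefix_lens_tensor = torch.tensor([0, 16, 4], dtype=torch.int32)

# Longest key range: 3 + 16 = 19 tokens, which would be under-estimated
# if the cached prefixes were ignored.
max_k = (input_lengths + prefix_lens_tensor).max().item()
assert max_k == 19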