Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
commit 2ac607a215
parent ef4fa3ea7c

Working state.
@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
         FlashDeepseekV2ForCausalLM,
         DeepseekV2Config,
@@ -1138,7 +1139,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
-            return VlmCausalLM(
+            return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
                 batch_class=MllamaCausalLMBatch,
@@ -451,7 +451,6 @@ class FlashLlamaLayer(nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,
-        image_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -576,7 +575,6 @@ class FlashLlamaModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         adapter_data,
         cross_attention_states=None,
-        image_indices=None,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

@@ -601,7 +599,6 @@ class FlashLlamaModel(torch.nn.Module):
                 max_s,
                 adapter_data,
                 cross_attention_states,
-                image_indices,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -649,7 +646,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
-        image_indices=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -665,7 +661,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefill_cache_indices=prefill_cache_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            image_indices=image_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -856,20 +856,17 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
-        image_indices,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
         if residual is not None:
             hidden_states += residual

-        # indices = cross_attention_states[-1]
-        if cu_seqlen_prefill is not None:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[:]
-        else:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[image_indices]
+        indices = cross_attention_states[-1]
+        out_hidden_states = hidden_states[:]
+        if len(indices) > 0:
+            assert max(indices) < hidden_states.shape[0]
+            hidden_states = hidden_states[indices]

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -885,10 +882,7 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

-        if cu_seqlen_prefill is not None:
-            out_hidden_states[:] = hidden_states
-        else:
-            out_hidden_states[image_indices] = hidden_states
+        out_hidden_states[indices] = hidden_states
         hidden_states = out_hidden_states

         return hidden_states, None
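
The rewritten FlashLlamaCrossLayer body boils down to a gather/compute/scatter pattern: keep a full-size copy of the hidden states, pull out only the rows named by indices, run the cross-attention path on those, and write the results back while every other row passes through untouched. A minimal, self-contained sketch of that pattern (the process callable and the clone are illustrative, not the TGI code, which operates on views of the flash-attention layout):

    import torch

    def gather_process_scatter(hidden_states: torch.Tensor, indices: torch.Tensor, process) -> torch.Tensor:
        # Keep an untouched copy; rows that carry no image tokens pass through as-is.
        out_hidden_states = hidden_states.clone()
        if indices.numel() > 0:
            assert int(indices.max()) < hidden_states.shape[0]
            selected = hidden_states[indices]               # gather rows belonging to image tokens
            out_hidden_states[indices] = process(selected)  # scatter processed rows back
        return out_hidden_states

    # Usage: pretend rows 1 and 3 belong to sequences with images.
    hidden = torch.randn(5, 8)
    idx = torch.tensor([1, 3])
    result = gather_process_scatter(hidden, idx, lambda x: x * 2.0)
    assert torch.equal(result[0], hidden[0])  # a non-image row is unchanged
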
@@ -971,40 +965,24 @@ class MllamaForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        adapter_data: Optional[torch.Tensor],
-        cross_attention_states: Optional[torch.Tensor],
-        image_indices,
+        adapter_data: Optional[torch.Tensor] = None,
+        # XXX: Putting these as optional so that the cuda warmup calls can go through.
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
     ):
-        # if cross_attention_mask is not None:
-        #     cross_attention_mask, full_text_row_masked_out_mask = (
-        #         _prepare_cross_attention_mask(
-        #             cross_attention_mask,
-        #             num_vision_tokens=self.vision_model.num_patches,
-        #             dtype=self.dtype,
-        #         )
-        #     )
-        # else:
-        #     full_text_row_masked_out_mask = None
-
-        # if cross_attention_mask is not None and cache_position is not None:
-        #     cross_attention_mask = cross_attention_mask[:, :, cache_position]
-        #     full_text_row_masked_out_mask = full_text_row_masked_out_mask[
-        #         :, :, cache_position
-        #     ]
-
         if cross_attention_states is not None:
             seqlen_q = len(image_indices)
             n_images = cross_attention_states.shape[0]
             seqlen_k = cross_attention_states.shape[1]
             device = cross_attention_states.device
             if cu_seqlen_prefill is not None:
-                # raise RuntimeError("TODO")
                 offset = 0
                 cu_q = []
                 indices = []
                 for index in image_indices:
                     cu_q.append(offset)
-                    length = seqlen.input_lengths[index]
+                    length = seqlen.input_lengths[index].item()
+                    assert index < seqlen.cu_seqlen_q.shape[0]
                     input_ids_offset = seqlen.cu_seqlen_q[index]
                     indices.extend(range(input_ids_offset, input_ids_offset + length))
                     offset += length
@@ -1061,8 +1039,6 @@ class MllamaForConditionalGeneration(nn.Module):
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            # cross_attention_mask=cross_attention_mask,
-            image_indices=image_indices,
         )

         return outputs
@@ -1,7 +1,7 @@
 from io import BytesIO
 from PIL import Image
 import torch
-from typing import Iterable, List
+from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request

 from dataclasses import dataclass
@@ -9,10 +9,13 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
-from typing import Optional
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM

-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+    block_tables_to_ragged,
+)
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from text_generation_server.layers.attention import Seqlen


 tracer = trace.get_tracer(__name__)
@@ -36,11 +39,17 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         image_indices = []
         attention_states = []
         for b in batches:
-            attention_states.append(b.cross_attention_states)
+            if b.cross_attention_states is not None:
+                attention_states.append(b.cross_attention_states)
             image_indices.extend([i + offset for i in b.image_indices])
             offset += len(b.image_indices)
-        batch.cross_attention_states = torch.cat(attention_states, dim=0)
-        batch.image_indices = image_indices
+        if len(attention_states) > 0:
+            assert len(image_indices) > 0
+            batch.cross_attention_states = torch.cat(attention_states, dim=0)
+            batch.image_indices = image_indices
+        else:
+            batch.cross_attention_states = None
+            batch.image_indices = []
         return batch

     @tracer.start_as_current_span("filter")
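
The concatenation rule this hunk adds: members with no cross_attention_states contribute nothing, and the merged batch only gets a tensor when at least one member carried one. A rough sketch under simplified assumptions, with SimpleNamespace standing in for the real batch class and field names taken from the hunk:

    from types import SimpleNamespace
    import torch

    def merge_cross_attention(batches):
        image_indices, attention_states, offset = [], [], 0
        for b in batches:
            if b.cross_attention_states is not None:
                attention_states.append(b.cross_attention_states)
            image_indices.extend(i + offset for i in b.image_indices)
            offset += len(b.image_indices)
        merged = SimpleNamespace()
        if attention_states:
            assert len(image_indices) > 0
            merged.cross_attention_states = torch.cat(attention_states, dim=0)
            merged.image_indices = image_indices
        else:
            merged.cross_attention_states = None
            merged.image_indices = []
        return merged

    # A text-only batch merged with a batch carrying a single image.
    b1 = SimpleNamespace(cross_attention_states=None, image_indices=[])
    b2 = SimpleNamespace(cross_attention_states=torch.zeros(1, 4, 8), image_indices=[0])
    merged = merge_cross_attention([b1, b2])
    assert merged.image_indices == [0] and merged.cross_attention_states.shape[0] == 1
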
@@ -64,8 +73,14 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             prev_i = i

         batch.image_indices = new_image_indices
-        batch.cross_attention_states = self.cross_attention_states[indices]
-        assert offset <= batch.cross_attention_states.shape[0]
+        if len(new_image_indices) > 0:
+            assert max(new_image_indices) < self.cross_attention_states.shape[0]
+            assert offset <= self.cross_attention_states.shape[0]
+            batch.cross_attention_states = self.cross_attention_states[
+                new_image_indices
+            ]
+        else:
+            batch.cross_attention_states = None
         return batch

     @classmethod
@@ -79,23 +94,20 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         for i, r in enumerate(requests):
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             curr_text = ""
-            has_image = False
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     curr_text += chunk.text
                 elif chunk_type == "image":
-                    has_image = True
                     image = Image.open(BytesIO(chunk.image.data))
                     # TODO unsure about BOS
                     curr_text += "<|image|>"
                     image_input = processor.image_processor(image, return_tensors="pt")
                     image_inputs.append(image_input)
+                    image_indices.append(i)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
             texts.append(curr_text)
-            if has_image:
-                image_indices.append(i)

             input_ids = tokenizer(
                 curr_text,
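
The request index is now recorded once per image chunk rather than once per request, so image_indices stays aligned row-for-row with the stacked pixel_values (the next hunk asserts exactly this). A tiny illustration with a made-up chunk layout:

    # Two requests; request 0 has two images, request 1 has one.
    requests = [["text", "image", "image"], ["image", "text"]]

    image_indices = []
    num_pixel_rows = 0
    for i, chunks in enumerate(requests):
        for chunk in chunks:
            if chunk == "image":
                image_indices.append(i)   # one entry per image, not per request
                num_pixel_rows += 1       # each image becomes one pixel_values row

    assert image_indices == [0, 0, 1]
    assert len(image_indices) == num_pixel_rows
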
@@ -124,6 +136,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         else:
             image_inputs = None

+        if image_inputs is not None:
+            assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
         return batch_tokenized_inputs, image_inputs

     @classmethod
@@ -162,3 +177,173 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             batch.image_indices = []
         assert batch.image_indices is not None
         return batch
+
+
+class MllamaCausalLM(VlmCausalLM):
+    def forward(
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            # Only run cuda graphs when there's no images.
+            or batch.cross_attention_states is not None
+        ):
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
+
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
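
To make the speculative-decoding expansion at the top of MllamaCausalLM.forward concrete: each sequence's current token is concatenated with its speculative ids and flattened, while position ids are stretched by the same factor. A minimal sketch with made-up values, mirroring the torch.cat / expand arithmetic above:

    import torch

    B, speculative_length = 2, 3
    new_length = speculative_length + 1

    input_ids = torch.tensor([10, 20])                       # one "real" token per sequence
    speculative_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])
    position_ids = torch.tensor([5, 9])

    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    arange = torch.arange(new_length).unsqueeze(0)
    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

    assert new_input_ids.tolist() == [10, 11, 12, 13, 20, 21, 22, 23]
    assert new_position_ids.tolist() == [5, 6, 7, 8, 9, 10, 11, 12]
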