Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
Fix llava-next and mllama crash issue
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in: parent e497bc09f6, commit bb55318f81
@@ -20,13 +20,13 @@ import torch
 import torch.utils.checkpoint
 import numpy as np
 
+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
 
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -49,7 +49,6 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     height, width = select_best_resolution(image_size, grid_pinpoints)
     return height // patch_size, width // patch_size
 
-
 # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
 def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     """
@@ -73,9 +72,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
     if not isinstance(image_size, (list, tuple)):
         if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(
-                f"image_size invalid type {type(image_size)} with value {image_size}"
-            )
+            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
         image_size = image_size.tolist()
 
     best_resolution = select_best_resolution(image_size, grid_pinpoints)
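For orientation, a small worked example of the grid arithmetic used by the two helpers above, assuming select_best_resolution has already chosen a 672x672 target resolution that is a multiple of the patch size (the concrete numbers are illustrative, not taken from the commit):

# Suppose select_best_resolution(image_size, grid_pinpoints) returned (672, 672)
# for a vision tower that uses 336x336 patches.
height, width = 672, 672
patch_size = 336

# get_anyres_image_grid_shape: the high-resolution patch grid is 2 x 2.
grid_shape = (height // patch_size, width // patch_size)

# image_size_to_num_patches: one patch per grid cell plus the base image patch.
num_patches = grid_shape[0] * grid_shape[1] + 1

print(grid_shape, num_patches)  # (2, 2) 5

Per the comment in the hunk above, passing a tensor image_size straight through would make this calculation go wrong, which is why it is converted with .tolist() first.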
@@ -89,26 +86,8 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     num_patches += 1
     return num_patches
 
 
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
-
-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -126,24 +105,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
-        use_flash_attention: Optional[bool] = True,
-        flash_attention_recompute: Optional[bool] = True,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
     ):
 
         if token_idx is not None:
-            output_attentions = (
-                output_attentions
-                if output_attentions is not None
-                else self.config.output_attentions
-            )
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
             output_hidden_states = (
-                output_hidden_states
-                if output_hidden_states is not None
-                else self.config.output_hidden_states
-            )
-            return_dict = (
-                return_dict if return_dict is not None else self.config.use_return_dict
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
             )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)
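The signature change above flips the Gaudi flash-attention flags from on-by-default to off-by-default, matching the kwargs defaults changed later in this diff. A minimal sketch of the new opt-in behaviour, using a hypothetical helper name (resolve_attention_flags is not part of the commit):

def resolve_attention_flags(kwargs: dict) -> tuple:
    # Mirrors the new defaults: both flags are False unless the caller opts in.
    use_flash_attention = kwargs.get("use_flash_attention", False)
    flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
    return use_flash_attention, flash_attention_recompute


print(resolve_attention_flags({}))                              # (False, False)
print(resolve_attention_flags({"use_flash_attention": True}))   # (True, False)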
@@ -169,6 +140,75 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
 
         return outputs
 
+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensor, each contains all the visual feature of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                token length of each image in image_features
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
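The pack_image_features method added above reshapes each image's patch features into a spatial grid, unpads them, and concatenates everything behind the base-image features. A self-contained sketch of the core reshape with toy dimensions (a 2x2 patch grid, 4x4 tokens per patch, embed_dim 8); unpad_image and the image_newline column are skipped here to keep the sketch short:

import torch

num_patch_height, num_patch_width = 2, 2   # high-resolution patch grid
height = width = 4                          # tokens per patch side (toy value)
embed_dim = 8

# One image: (1 base patch + 4 grid patches, tokens per patch, embed_dim).
image_feature = torch.randn(1 + num_patch_height * num_patch_width, height * width, embed_dim)
base_image_feature, image_feature = image_feature[0], image_feature[1:]

# Same reshape chain as in the method: lay the grid-patch tokens out spatially,
# then flatten back into a single (tokens, embed_dim) sequence.
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
# (unpad_image and the image_newline concatenation would happen here)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)

packed = torch.cat((base_image_feature, image_feature), dim=0)
print(packed.shape)  # torch.Size([80, 8]): 16 base tokens + 64 grid tokens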
@@ -207,16 +247,11 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         ]
         if pixel_values.dim() == 5:
             # stacked if input is (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [
-                pix_val[:num_patch]
-                for pix_val, num_patch in zip(pixel_values, image_num_patches)
-            ]
+            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
             pixel_values = torch.cat(_pixel_values_list, dim=0)
         elif pixel_values.dim() != 4:
             # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-            raise ValueError(
-                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
-            )
+            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
 
         image_features = self.vision_tower(pixel_values, output_hidden_states=True)
         # If we have one vision feature layer, return the corresponding hidden states,
@@ -224,10 +259,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         if isinstance(vision_feature_layer, int):
             selected_image_feature = image_features.hidden_states[vision_feature_layer]
         else:
-            hs_pool = [
-                image_features.hidden_states[layer_idx]
-                for layer_idx in vision_feature_layer
-            ]
+            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
             selected_image_feature = torch.cat(hs_pool, dim=-1)
 
         if vision_feature_select_strategy == "default":
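The reformatted block above handles vision_feature_layer being either a single layer index or a list of indices; in the list case the selected hidden states are concatenated along the feature dimension. A small sketch with dummy hidden states (the layer count and shapes are illustrative):

import torch

# Stand-in for the vision tower's per-layer hidden states: (batch, seq, dim) each.
hidden_states = [torch.randn(1, 10, 16) for _ in range(5)]

vision_feature_layer = [-2, -1]            # could also be a single int such as -2
if isinstance(vision_feature_layer, int):
    selected_image_feature = hidden_states[vision_feature_layer]
else:
    hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    selected_image_feature = torch.cat(hs_pool, dim=-1)

print(selected_image_feature.shape)  # torch.Size([1, 10, 32]) for the two-layer case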
@@ -267,19 +299,13 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 **kwargs,
             )
         else:
-            use_flash_attention = kwargs.get("use_flash_attention", True)
-            flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
+            use_flash_attention = kwargs.get("use_flash_attention", False)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
 
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
-            if (
-                past_key_values is None
-                and pixel_values is not None
-                and input_ids.shape[1] != 1
-            ):
-                vision_feature_select_strategy = kwargs.get(
-                    "vision_feature_select_strategy", None
-                )
+            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
+                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
                 vision_feature_layer = kwargs.get("vision_feature_layer", None)
                 vision_feature_select_strategy = (
                     vision_feature_select_strategy
@@ -287,9 +313,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     else self.config.vision_feature_select_strategy
                 )
                 vision_feature_layer = (
-                    vision_feature_layer
-                    if vision_feature_layer is not None
-                    else self.config.vision_feature_layer
+                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                 )
 
                 # 1. Extract the input embeddings
@@ -303,61 +327,25 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 )
 
                 # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = (
-                    self.config.vision_config.image_size
-                    // self.config.vision_config.patch_size
-                )
-
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-
-                        if height * width != base_image_feature.shape[0]:
-                            raise ValueError(
-                                "The number of patches is not consistent with the image size."
-                            )
-
-                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                            image_sizes[image_idx].tolist(),
-                            self.config.image_grid_pinpoints,
-                            self.config.vision_config.image_size,
-                        )
-
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
-                        image_feature = image_feature.permute(
-                            4, 0, 2, 1, 3
-                        ).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
-                        )
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1
-                                ),
-                            ),
-                            dim=-1,
-                        )
-                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat(
-                            (base_image_feature, image_feature), dim=0
-                        )
-                    else:
-                        image_feature = image_feature[0]
-                        image_feature = torch.cat(
-                            (image_feature, self.image_newline[None]), dim=0
-                        )
-                    new_image_features.append(image_feature)
-                image_features = torch.cat(new_image_features, dim=0)
-                inputs_embeds = self._merge_input_ids_with_image_features(
-                    inputs_embeds, image_features, input_ids
-                )
+                image_features, feature_lens = self.pack_image_features(
+                    image_features,
+                    image_sizes,
+                    vision_feature_select_strategy=vision_feature_select_strategy,
+                    image_newline=self.image_newline,
+                )
+
+                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                    n_image_features = image_features.shape[0]
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
             elif past_key_values is not None:
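The rewritten merge path above no longer assigns into inputs_embeds through a boolean mask; it expands the placeholder mask to the embedding shape, validates that the number of placeholder positions matches the number of packed image features, and then uses masked_scatter. A toy sketch of that mechanism (the token ids, shapes and image token index are made up for illustration):

import torch

image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 5]])   # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, 8)                # (batch, seq, embed_dim)
image_features = torch.ones(2, 8)                   # (num_image_tokens, embed_dim)

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

# Mismatch check, as in the hunk above.
if inputs_embeds[special_image_mask].numel() != image_features.numel():
    raise ValueError("Image features and image tokens do not match")

inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.]) -> only the image slots are filled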
@@ -368,9 +356,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 # that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(
-                    first_layer_past_key_value.float().sum(-2) == 0
-                )
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
                 # Get the target length
                 past_length = first_layer_past_key_value.shape[-1]
                 extended_attention_mask = torch.ones(
@@ -397,9 +383,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 if token_idx is not None:
-                    position_ids = (
-                        torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    )
+                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                 else:
                     position_ids = position_ids[:, -input_ids.shape[1] :]
 
@@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
             else:
                 images.append(curr_image)
 
+        if is_warmup is True:
+            images += [images[0]] * (len(texts) - len(images))
+
         missing_inputs = 0
         dummy_images = None
         if is_warmup is False:
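The warmup branch added above pads the image list by repeating the first image so that every text in the warmup batch has an image to pair with. A tiny sketch of that padding (the texts and images here are placeholders):

texts = ["prompt 1", "prompt 2", "prompt 3"]
images = ["image_0"]        # only one real image is available during warmup
is_warmup = True

if is_warmup is True:
    # Repeat the first image until both lists have the same length, as in the hunk.
    images += [images[0]] * (len(texts) - len(images))

print(images)  # ['image_0', 'image_0', 'image_0']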
@@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
-
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 max_prefill_batch_size,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 2,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True