Enable llava-next

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
yuanwu 2024-07-28 09:05:49 +00:00
parent d3155d6f41
commit 588a014551
9 changed files with 1278 additions and 480 deletions

View File

@ -110,6 +110,7 @@ impl Client {
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
model_id: &str
) -> Result<Option<u32>> {
let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
if !warmup_enabled {
@ -152,25 +153,76 @@ impl Client {
let mut batch_counter: u64 = 0;
let mut request_counter: u64 = 0;
for shape in shapes.iter() {
let (batch_size, seq_length) = shape;
let mut batches: Vec<Batch> = vec![
self.create_warmup_batch(
*shape,
&mut batch_counter,
&mut request_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
None,
)
];
// if possible, create second batch in order to trigger concatenate operation
if *batch_size < max_decode_batch_size {
batches.push(
if model_id.contains("llava") {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = cmp::min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![]()");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs,
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let mut batches = Vec::new();
batches.push(Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
});
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
else {
for shape in shapes.iter() {
let (batch_size, seq_length) = shape;
let mut batches: Vec<Batch> = vec![
self.create_warmup_batch(
(1, *seq_length),
*shape,
&mut batch_counter,
&mut request_counter,
max_input_length,
@ -179,56 +231,45 @@ impl Client {
false,
None,
)
);
];
// if possible, create second batch in order to trigger concatenate operation
if *batch_size < max_decode_batch_size {
batches.push(
self.create_warmup_batch(
(1, *seq_length),
&mut batch_counter,
&mut request_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
None,
)
);
}
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
// send batches to warmup all possible decode shapes
if decode_batch_sizes.len() > 1 {
let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
decode_bucket_size
} else {
decode_bucket_size.div_ceil(max_prefill_batch_size)
};
let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
// send batches to warmup all possible decode shapes
if decode_batch_sizes.len() > 1 {
let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
decode_bucket_size
} else {
decode_bucket_size.div_ceil(max_prefill_batch_size)
};
let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
let mut batches: Vec<Batch> = vec![
self.create_warmup_batch(
(requests_send, seq_bucket_size),
&mut batch_counter,
&mut request_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
Some(max_new_tokens),
)
];
let get_current_decode_batch_size = |num: u32| -> u32 {
decode_batch_sizes.iter()
.filter(|&&x| x >= num)
.min()
.copied()
.unwrap()
};
let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
while current_decode_batch_size < max_decode_batch_size {
let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
batches.push(
let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
let mut batches: Vec<Batch> = vec![
self.create_warmup_batch(
(num_requests, seq_bucket_size),
(requests_send, seq_bucket_size),
&mut batch_counter,
&mut request_counter,
max_input_length,
@ -237,48 +278,74 @@ impl Client {
false,
Some(max_new_tokens),
)
);
];
requests_send += num_requests;
current_decode_batch_size = get_current_decode_batch_size(requests_send);
let get_current_decode_batch_size = |num: u32| -> u32 {
decode_batch_sizes.iter()
.filter(|&&x| x >= num)
.min()
.copied()
.unwrap()
};
let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
while current_decode_batch_size < max_decode_batch_size {
let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
batches.push(
self.create_warmup_batch(
(num_requests, seq_bucket_size),
&mut batch_counter,
&mut request_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
Some(max_new_tokens),
)
);
requests_send += num_requests;
current_decode_batch_size = get_current_decode_batch_size(requests_send);
}
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
// send batches with default params to warm up Greedy search
let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
for batch_size in &prefill_batch_sizes {
greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
}
for greedy_shape in greedy_shapes.iter() {
let batches: Vec<Batch> = vec![
self.create_warmup_batch(
*greedy_shape,
&mut batch_counter,
&mut request_counter,
// send batches with default params to warm up Greedy search
let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
for batch_size in &prefill_batch_sizes {
greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
}
for greedy_shape in greedy_shapes.iter() {
let batches: Vec<Batch> = vec![
self.create_warmup_batch(
*greedy_shape,
&mut batch_counter,
&mut request_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
true,
None,
)
];
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
seq_bucket_size,
true,
None,
)
];
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
Ok(None) // No support for maximum total tokens
}
Ok(None) // No support for maximum total tokens
}
#[instrument(skip_all)]

View File

@ -100,6 +100,7 @@ impl ShardedClient {
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
model_id: &str,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
@ -110,6 +111,7 @@ impl ShardedClient {
max_prefill_tokens,
max_total_tokens,
max_batch_size,
model_id
))
})
.collect();

View File

@ -349,6 +349,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
&model_info.model_id
)
.await
.map_err(RouterError::Warmup)?

View File

@ -16,6 +16,12 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@ -159,6 +165,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
logger.info(f"model_type = {model_type}")
if model_type == "llava_next":
logger.info(f"################model_type = {model_type}")
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(

View File

@ -369,6 +369,7 @@ class CausalLMBatch(Batch):
input_lengths = [b.input_length for b in batches]
max_input_length = max(input_lengths)
offsets = [max_input_length - b.input_length for b in batches]
cur_padding = [b.right_padding for b in batches]
# For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode

View File

@ -21,17 +21,12 @@ import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from loguru import logger
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
@ -56,100 +51,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.language_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
@ -164,120 +72,215 @@ class LlavaNextForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
if token_idx is not None:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
image_features = self.vision_tower(pixel_values)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
)
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
The only differences are:
- add new args token_idx
- add the process of merging images into inputs_embeds
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
**kwargs,
)
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
image_features = self.vision_tower(
reshaped_pixel_values, output_hidden_states=True
)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx].tolist(),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position.
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
seq_len = input_ids.shape[1]
pad_len = seq_len - token_idx.item()
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = extended_attention_mask
attention_mask[:, -pad_len:] = 0
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
}
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
return logits, speculative_logits
return model_inputs

View File

@ -10,7 +10,12 @@ import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Optional, Tuple, List, Type, Dict
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@ -19,6 +24,11 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models.types import (
Batch,
Tokens,
@ -686,20 +696,97 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
config_class: PreTrainedTokenizerBase = AutoConfig,
default_dtype=torch.bfloat16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads: Optional[int] = None,
# Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size
# Create model
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
dtype = torch.bfloat16 if dtype is None else dtype
device = torch.device("hpu")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype)
prefix = ""
model = model_class(prefix, config, weights)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = num_kv_heads (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
else:
self.head_size = head_size
self.cuda_graphs = {}
self.kv_cache = []
self.cuda_graphs = {}
@ -711,7 +798,7 @@ class FlashCausalLM(Model):
device=device,
rank=rank,
world_size=world_size,
sliding_window=sliding_window,
sliding_window=None,
)
@property

File diff suppressed because it is too large Load Diff

View File

@ -96,8 +96,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch, self.model.tokenizer, self.model.dtype, self.model.device
)
batches = [batch_from_pb(batch) for batch in request.batches]
self.model.warmup(batches)
if self.model.batch_type in VLM_BATCH_TYPES :
self.model.warmup(request)
else:
batches = [batch_from_pb(batch) for batch in request.batches]
self.model.warmup(batches)
return generate_pb2.WarmupResponse()