Enable llava-next

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-07-28 09:05:49 +00:00
parent d3155d6f41
commit 588a014551
9 changed files with 1278 additions and 480 deletions

File: router/client/src/client.rs

@@ -110,6 +110,7 @@ impl Client {
         max_prefill_tokens: u32,
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
+        model_id: &str
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -152,25 +153,76 @@ impl Client {
         let mut batch_counter: u64 = 0;
         let mut request_counter: u64 = 0;
-        for shape in shapes.iter() {
-            let (batch_size, seq_length) = shape;
-            let mut batches: Vec<Batch> = vec![
-                self.create_warmup_batch(
-                    *shape,
-                    &mut batch_counter,
-                    &mut request_counter,
-                    max_input_length,
-                    max_total_tokens,
-                    seq_bucket_size,
-                    false,
-                    None,
-                )
-            ];
-            // if possible, create second batch in order to trigger concatenate operation
-            if *batch_size < max_decode_batch_size {
-                batches.push(
-                    self.create_warmup_batch(
-                        (1, *seq_length),
-                        &mut batch_counter,
-                        &mut request_counter,
-                        max_input_length,
+        if model_id.contains("llava") {
+            let mut n_tokens = 0;
+            let mut requests = Vec::new();
+            // Create requests
+            while n_tokens < max_prefill_tokens {
+                let truncate = cmp::min(max_input_length, max_prefill_tokens - n_tokens);
+
+                let mut inputs = String::new();
+                inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
+                inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+
+                requests.push(Request {
+                    id: 0,
+                    // We truncate the input on the server side to be sure that it has the correct size
+                    inputs,
+                    truncate,
+                    // Set sampling parameters to also take these ops into account in the max memory
+                    parameters: Some(NextTokenChooserParameters {
+                        temperature: 0.9,
+                        top_k: 10,
+                        top_p: 0.9,
+                        typical_p: 0.9,
+                        do_sample: false,
+                        seed: 0,
+                        repetition_penalty: 1.2,
+                        frequency_penalty: 0.1,
+                        watermark: true,
+                        grammar: String::new(),
+                        grammar_type: GrammarType::None as i32,
+                    }),
+                    stopping_parameters: Some(StoppingCriteriaParameters {
+                        max_new_tokens: max_total_tokens - truncate,
+                        stop_sequences: vec![],
+                        ignore_eos_token: true,
+                    }),
+                    prefill_logprobs: true,
+                    top_n_tokens: 20,
+                });
+                n_tokens += max_input_length;
+
+                // Check max_batch_size
+                if Some(requests.len()) == max_batch_size {
+                    break;
+                }
+            }
+
+            let mut batches = Vec::new();
+            batches.push(Batch {
+                id: 0,
+                size: requests.len() as u32,
+                requests,
+                max_tokens: 0,
+            });
+
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            })
+            .inject_context();
+            let response = self.stub.warmup(request).await?.into_inner();
+            Ok(response.max_supported_total_tokens)
+        }
+        else {
+            for shape in shapes.iter() {
+                let (batch_size, seq_length) = shape;
+                let mut batches: Vec<Batch> = vec![
+                    self.create_warmup_batch(
+                        *shape,
+                        &mut batch_counter,
+                        &mut request_counter,
+                        max_input_length,
@@ -179,56 +231,45 @@ impl Client {
-                        false,
-                        None,
-                    )
-                );
-            }
-            let request = tonic::Request::new(WarmupRequest {
-                batches,
-                max_input_length,
-                max_prefill_tokens,
-                max_total_tokens,
-            }).inject_context();
-            let _response = self.stub.warmup(request).await?.into_inner();
-        }
-
-        // send batches to warmup all possible decode shapes
-        if decode_batch_sizes.len() > 1 {
-            let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
-                decode_bucket_size
-            } else {
-                decode_bucket_size.div_ceil(max_prefill_batch_size)
-            };
-            let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
-            let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
-            let mut batches: Vec<Batch> = vec![
-                self.create_warmup_batch(
-                    (requests_send, seq_bucket_size),
-                    &mut batch_counter,
-                    &mut request_counter,
-                    max_input_length,
-                    max_total_tokens,
-                    seq_bucket_size,
-                    false,
-                    Some(max_new_tokens),
-                )
-            ];
-            let get_current_decode_batch_size = |num: u32| -> u32 {
-                decode_batch_sizes.iter()
-                    .filter(|&&x| x >= num)
-                    .min()
-                    .copied()
-                    .unwrap()
-            };
-            let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
-            while current_decode_batch_size < max_decode_batch_size {
-                let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
-                let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
-                batches.push(
-                    self.create_warmup_batch(
-                        (num_requests, seq_bucket_size),
-                        &mut batch_counter,
-                        &mut request_counter,
-                        max_input_length,
+                        false,
+                        None,
+                    )
+                ];
+                // if possible, create second batch in order to trigger concatenate operation
+                if *batch_size < max_decode_batch_size {
+                    batches.push(
+                        self.create_warmup_batch(
+                            (1, *seq_length),
+                            &mut batch_counter,
+                            &mut request_counter,
+                            max_input_length,
+                            max_total_tokens,
+                            seq_bucket_size,
+                            false,
+                            None,
+                        )
+                    );
+                }
+                let request = tonic::Request::new(WarmupRequest {
+                    batches,
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                }).inject_context();
+                let _response = self.stub.warmup(request).await?.into_inner();
+            }
+
+            // send batches to warmup all possible decode shapes
+            if decode_batch_sizes.len() > 1 {
+                let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
+                    decode_bucket_size
+                } else {
+                    decode_bucket_size.div_ceil(max_prefill_batch_size)
+                };
+                let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
+                let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
+                let mut batches: Vec<Batch> = vec![
+                    self.create_warmup_batch(
+                        (requests_send, seq_bucket_size),
+                        &mut batch_counter,
+                        &mut request_counter,
+                        max_input_length,
@@ -237,48 +278,74 @@ impl Client {
-                        false,
-                        Some(max_new_tokens),
-                    )
-                );
-                requests_send += num_requests;
-                current_decode_batch_size = get_current_decode_batch_size(requests_send);
-            }
-            let request = tonic::Request::new(WarmupRequest {
-                batches,
-                max_input_length,
-                max_prefill_tokens,
-                max_total_tokens,
-            }).inject_context();
-            let _response = self.stub.warmup(request).await?.into_inner();
-        }
-
-        // send batches with default params to warm up Greedy search
-        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
-        for batch_size in &prefill_batch_sizes {
-            greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
-        }
-        for greedy_shape in greedy_shapes.iter() {
-            let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(
-                    *greedy_shape,
-                    &mut batch_counter,
-                    &mut request_counter,
-                    max_input_length,
-                    max_total_tokens,
-                    seq_bucket_size,
-                    true,
-                    None,
-                )
-            ];
-            let request = tonic::Request::new(WarmupRequest {
-                batches,
-                max_input_length,
-                max_prefill_tokens,
-                max_total_tokens,
-            }).inject_context();
-            let _response = self.stub.warmup(request).await?.into_inner();
-        }
-        Ok(None) // No support for maximum total tokens
+                        false,
+                        Some(max_new_tokens),
+                    )
+                ];
+                let get_current_decode_batch_size = |num: u32| -> u32 {
+                    decode_batch_sizes.iter()
+                        .filter(|&&x| x >= num)
+                        .min()
+                        .copied()
+                        .unwrap()
+                };
+                let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
+                while current_decode_batch_size < max_decode_batch_size {
+                    let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
+                    let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
+                    batches.push(
+                        self.create_warmup_batch(
+                            (num_requests, seq_bucket_size),
+                            &mut batch_counter,
+                            &mut request_counter,
+                            max_input_length,
+                            max_total_tokens,
+                            seq_bucket_size,
+                            false,
+                            Some(max_new_tokens),
+                        )
+                    );
+                    requests_send += num_requests;
+                    current_decode_batch_size = get_current_decode_batch_size(requests_send);
+                }
+                let request = tonic::Request::new(WarmupRequest {
+                    batches,
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                }).inject_context();
+                let _response = self.stub.warmup(request).await?.into_inner();
+            }
+
+            // send batches with default params to warm up Greedy search
+            let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
+            for batch_size in &prefill_batch_sizes {
+                greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
+            }
+            for greedy_shape in greedy_shapes.iter() {
+                let batches: Vec<Batch> = vec![
+                    self.create_warmup_batch(
+                        *greedy_shape,
+                        &mut batch_counter,
+                        &mut request_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        true,
+                        None,
+                    )
+                ];
+                let request = tonic::Request::new(WarmupRequest {
+                    batches,
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                }).inject_context();
+                let _response = self.stub.warmup(request).await?.into_inner();
+            }
+            Ok(None) // No support for maximum total tokens
+        }
     }

     #[instrument(skip_all)]
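To make the llava warmup branch above easier to follow, here is a small Python sketch (illustrative only, not code from this repository; the helper names are hypothetical) of the prompt the router builds for each warmup request and of how many requests are generated before the prefill-token budget or the optional batch-size cap is reached.

```python
from typing import Optional


def build_llava_warmup_input(image_b64: str, max_input_length: int) -> str:
    # Markdown-style inline image followed by filler text, mirroring the Rust loop above.
    # Any small base64-encoded image works: warmup only needs realistic shapes,
    # and the server truncates the text side to the exact size.
    prompt = f"![](data:image/jpeg;base64,{image_b64})"
    prompt += "_test " * max_input_length
    return prompt


def count_llava_warmup_requests(
    max_prefill_tokens: int, max_input_length: int, max_batch_size: Optional[int]
) -> int:
    # One request is pushed per max_input_length tokens until the prefill budget
    # (or the optional batch-size cap) is exhausted.
    n_tokens, n_requests = 0, 0
    while n_tokens < max_prefill_tokens:
        n_requests += 1
        n_tokens += max_input_length
        if max_batch_size is not None and n_requests == max_batch_size:
            break
    return n_requests


print(count_llava_warmup_requests(4096, 1024, None))  # -> 4
```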

File: router/client/src/sharded_client.rs

@@ -100,6 +100,7 @@ impl ShardedClient {
         max_prefill_tokens: u32,
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
+        model_id: &str,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
@@ -110,6 +111,7 @@ impl ShardedClient {
                     max_prefill_tokens,
                     max_total_tokens,
                     max_batch_size,
+                    model_id
                 ))
             })
             .collect();

File: router/src/main.rs

@@ -349,6 +349,7 @@ async fn main() -> Result<(), RouterError> {
             max_batch_prefill_tokens,
             max_total_tokens as u32,
             max_batch_size,
+            &model_info.model_id
         )
         .await
         .map_err(RouterError::Warmup)?

File: server/text_generation_server/models/__init__.py

@@ -16,6 +16,12 @@ from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
+from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+from text_generation_server.models.custom_modeling.llava_next import (
+    LlavaNextForConditionalGeneration,
+)

 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -159,6 +165,18 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

+    logger.info(f"model_type = {model_type}")
+    if model_type == "llava_next":
+        logger.info(f"################model_type = {model_type}")
+        return VlmCausalLM(
+            model_class=LlavaNextForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=None,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
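A minimal sketch (not repository code) of the check that decides whether the new branch above is taken: `get_model` dispatches on the Hugging Face config's `model_type`, which is `"llava_next"` for LLaVA-NeXT checkpoints. The checkpoint name in the comment is only an example.

```python
from transformers import AutoConfig


def is_llava_next(model_id: str, revision=None, trust_remote_code: bool = False) -> bool:
    # Same signal get_model uses before choosing a model class.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    return getattr(config, "model_type", None) == "llava_next"


# Example (requires hub or local cache access):
# is_llava_next("llava-hf/llava-v1.6-mistral-7b-hf")  # -> True
```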

File: server/text_generation_server/models/causal_lm.py

@@ -369,6 +369,7 @@ class CausalLMBatch(Batch):
         input_lengths = [b.input_length for b in batches]
         max_input_length = max(input_lengths)
         offsets = [max_input_length - b.input_length for b in batches]
+
         cur_padding = [b.right_padding for b in batches]
         # For prefill there is a space allocated only for first token
         # Need to add padding to the max total tokens before first decode

File: server/text_generation_server/models/custom_modeling/llava_next.py

@@ -21,17 +21,12 @@ import torch.utils.checkpoint
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.models.llava_next.modeling_llava_next import (
+    unpad_image,
+)
+from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
-from loguru import logger
-from text_generation_server.models.custom_modeling.vlm import (
-    load_text_model,
-    load_vision_model,
-)
-from text_generation_server.layers import (
-    TensorParallelColumnLinear,
-    TensorParallelRowLinear,
-)


 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -56,100 +51,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def unpad_image(tensor, original_size):
-    """
-    Unpads a PyTorch tensor of a padded and resized image.
-
-    Args:
-        tensor (`torch.Tensor`):
-            The image tensor, assumed to be of shape (num_channels, height, width).
-        original_size (`tuple`):
-            The original size of the image (height, width).
-
-    Returns:
-        `torch.Tensor`: The unpadded image tensor.
-    """
-    original_height, original_width = original_size
-    current_height, current_width = tensor.shape[1:]
-
-    original_aspect_ratio = original_width / original_height
-    current_aspect_ratio = current_width / current_height
-
-    if original_aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / original_width
-        new_height = int(original_height * scale_factor)
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding : current_height - padding, :]
-    else:
-        scale_factor = current_height / original_height
-        new_width = int(original_width * scale_factor)
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding : current_width - padding]
-
-    return unpadded_tensor
-
-
-# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
-class LlavaNextMultiModalProjector(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-
-        self.linear_1 = TensorParallelColumnLinear.load(
-            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
-        )
-        self.act = ACT2FN[config.projector_hidden_act]
-        self.linear_2 = TensorParallelRowLinear.load(
-            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
-        )
-
-    def forward(self, image_features):
-        hidden_states = self.linear_1(image_features)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-
-
-class LlavaNextForConditionalGeneration(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        config.vision_config.quantize = config.quantize
-        vision_config = config.vision_config
-        # Instead of selecting in hidden_states[-2].
-        # Instead compute only the n -2 + 1 layers and don't pool
-        if config.vision_feature_layer < 0:
-            vision_config.num_hidden_layers += config.vision_feature_layer + 1
-        else:
-            vision_config.num_hidden_layers = config.vision_feature_layer + 1
-        self.vision_tower = load_vision_model(
-            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
-            config=config.vision_config,
-            weights=weights,
-        )
-
-        self.multi_modal_projector = LlavaNextMultiModalProjector(
-            prefix="multi_modal_projector", config=config, weights=weights
-        )
-
-        self.image_newline = weights.get_tensor("image_newline")
-
-        self.vocab_size = config.text_config.vocab_size
-        self.config = config
-        config.text_config.quantize = config.quantize
-        config.text_config.speculator = config.speculator
-        self.language_model = load_text_model(
-            prefix="language_model" if not prefix else f"{prefix}.language_model",
-            config=config.text_config,
-            weights=weights,
-        )
-        self.pad_token_id = (
-            config.pad_token_id if config.pad_token_id is not None else -1
-        )
+class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

     def _merge_input_ids_with_image_features(
         self,
-        input_ids: torch.Tensor,
         inputs_embeds: torch.Tensor,
         image_features: torch.Tensor,
+        input_ids: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
         mask = input_ids == self.config.image_token_index
@@ -164,120 +72,215 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
-        slots: torch.Tensor,
-        input_lengths: torch.Tensor,
-        max_s: int,
-        prefill_cache_indices: Optional[torch.Tensor],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-    ):
-        inputs_embeds = self.language_model.embed_tokens(input_ids)
-        if pixel_values is not None and len(pixel_values) > 0:
-            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
-            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
-            # 1. Extract the input embeddings
-            # 2. Merge text and images
-            num_images, num_patches, channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.view(
-                num_images * num_patches, channels, height, width
-            )
-            image_features = self.vision_tower(pixel_values)
-
-            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-            # Already done within the clip model
-            selected_image_feature = image_features.last_hidden_state
-
-            if self.config.vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.config.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise RuntimeError(
-                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
-                )
-
-            image_features = self.multi_modal_projector(selected_image_feature)
-
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [num_patches] * num_images
-            image_features = torch.split(image_features, split_sizes, dim=0)
-
-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
-
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                        image_sizes[image_idx],
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
-                    )
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
-                        ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
-            inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_features
-            )
-
-        hidden_states = self.language_model.model(
-            inputs_embeds=inputs_embeds,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            true_max_s=max_s,
-            prefill_cache_indices=None,
-        )
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.language_model.lm_head(hidden_states)
-        return logits, speculative_logits
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[int] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        token_idx: Optional[torch.Tensor] = None,
+    ):
+        if token_idx is not None:
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+            if inputs_embeds is None:
+                inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            outputs = self.language_model(
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                token_idx=token_idx,
+            )
+
+            logits = outputs[0]
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                return output
+
+            return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        image_sizes=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        """
+        Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
+        The only differences are:
+        - add new args token_idx
+        - add the process of merging images into inputs_embeds
+        """
+        token_idx = kwargs.get("token_idx", None)
+        if token_idx is None:
+            return super().prepare_inputs_for_generation(
+                input_ids=input_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                pixel_values=pixel_values,
+                image_sizes=image_sizes,
+                attention_mask=attention_mask,
+                **kwargs,
+            )
+        else:
+            position_ids = kwargs.get("position_ids", None)
+            labels = kwargs.get("labels", None)
+            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
+                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
+                vision_feature_layer = kwargs.get("vision_feature_layer", None)
+                vision_feature_select_strategy = (
+                    vision_feature_select_strategy
+                    if vision_feature_select_strategy is not None
+                    else self.config.vision_feature_select_strategy
+                )
+                vision_feature_layer = (
+                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+                )
+
+                # 1. Extract the input embeddings
+                inputs_embeds = self.get_input_embeddings()(input_ids)
+                # 2. Merge text and images
+                batch_size, num_patches, num_channels, height, width = pixel_values.shape
+                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
+                image_features = self.vision_tower(
+                    reshaped_pixel_values, output_hidden_states=True
+                )
+
+                selected_image_feature = image_features.hidden_states[vision_feature_layer]
+
+                if vision_feature_select_strategy == "default":
+                    selected_image_feature = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    selected_image_feature = selected_image_feature
+
+                image_features = self.multi_modal_projector(selected_image_feature)
+
+                # split up image_features for each of the individual images
+                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+                # if we assume each image has 5 image features (base image + 4 patches)
+                split_sizes = [image.shape[0] for image in pixel_values]
+                image_features = torch.split(image_features, split_sizes, dim=0)
+
+                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+
+                        if height * width != base_image_feature.shape[0]:
+                            raise ValueError("The number of patches is not consistent with the image size.")
+
+                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                            image_sizes[image_idx].tolist(),
+                            self.config.image_grid_pinpoints,
+                            self.config.vision_config.image_size,
+                        )
+
+                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                        image_feature = torch.cat(
+                            (
+                                image_feature,
+                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
+                            ),
+                            dim=-1,
+                        )
+                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    else:
+                        image_feature = image_feature[0]
+                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
+                    new_image_features.append(image_feature)
+                image_features = torch.stack(new_image_features, dim=0)
+                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
+                self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
+            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+            # generation with cache
+            elif past_key_values is not None:
+                seq_len = input_ids.shape[1]
+                pad_len = seq_len - token_idx.item()
+                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                past_length = first_layer_past_key_value.shape[-1]
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = extended_attention_mask
+                attention_mask[:, -pad_len:] = 0
+
+            if attention_mask is not None and position_ids is None:
+                # create position_ids on the fly for batch generation
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                if past_key_values:
+                    if token_idx is not None:
+                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                    else:
+                        position_ids = position_ids[:, -input_ids.shape[1] :]
+
+            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+            if inputs_embeds is not None and past_key_values is None:
+                model_inputs = {"inputs_embeds": inputs_embeds}
+            else:
+                model_inputs = {"input_ids": input_ids}
+
+            model_inputs.update(
+                {
+                    "position_ids": position_ids,
+                    "past_key_values": past_key_values,
+                    "use_cache": kwargs.get("use_cache"),
+                    "attention_mask": attention_mask,
+                    "token_idx": token_idx,
+                    "labels": labels,
+                }
+            )
+
+            return model_inputs
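For readers unfamiliar with the any-resolution patch layout handled above, the sketch below (an illustration, not repository code; it assumes transformers' select_best_resolution and the standard llava-v1.6 grid pinpoints) shows how the patch grid returned by get_anyres_image_grid_shape is derived: the best pinpoint resolution is chosen for the original image, then divided by the vision tower's input size.

```python
from transformers.image_processing_utils import select_best_resolution

# Standard llava-v1.6 grid pinpoints and CLIP-ViT input size (both come from the model config).
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
vision_input_size = 336

image_size = (600, 900)  # (height, width) of the original image
best_height, best_width = select_best_resolution(image_size, grid_pinpoints)
num_patch_height = best_height // vision_input_size
num_patch_width = best_width // vision_input_size
print((best_height, best_width), num_patch_height, num_patch_width)  # (672, 672) 2 2
# The base image feature plus this 2 x 2 grid gives the "base image + 4 patches" layout
# mentioned in the comments above; unpad_image then strips the padding rows/columns.
```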

File: server/text_generation_server/models/flash_causal_lm.py

@@ -10,7 +10,12 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import (
+    PreTrainedTokenizerBase,
+    AutoConfig,
+    AutoTokenizer,
+    GenerationConfig,
+)
 from typing import Optional, Tuple, List, Type, Dict

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -19,6 +24,11 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -686,20 +696,97 @@ class FlashCausalLMBatch(Batch):
 class FlashCausalLM(Model):
     def __init__(
         self,
-        model: torch.nn.Module,
-        tokenizer: PreTrainedTokenizerBase,
-        num_layers: int,
-        num_kv_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        rank: int = 0,
-        world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        model_id: str,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
+        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        config_class: PreTrainedTokenizerBase = AutoConfig,
+        default_dtype=torch.bfloat16,
+        aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads: Optional[int] = None,
+        # Deepseek V2 uses different QK and V dims.
+        head_size: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ):
-        self.num_layers = num_layers
-        self.num_kv_heads = num_kv_heads
-        self.head_size = head_size
+        # Create model
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        rank = int(os.getenv("RANK", "0"))
+        dtype = torch.bfloat16 if dtype is None else dtype
+        device = torch.device("hpu")
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
+
+        config = config_class.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype)
+
+        prefix = ""
+        model = model_class(prefix, config, weights)
+
+        # VLM models define the config we care about in their text_config
+        text_config = getattr(config, "text_config", None)
+        if text_config is not None:
+            config = text_config
+
+        self.num_layers = config.num_hidden_layers
+        # Validation is done in the model itself
+        if num_kv_heads is None:
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            # GPT-2 workaround
+            if num_kv_heads is None:
+                num_kv_heads = getattr(config, "n_head", None)
+        if num_kv_heads is None:
+            raise ValueError("Cannot get the number of key/value heads")
+        self.num_kv_heads = num_kv_heads
+        (
+            num_kv_heads // self.process_group.size()
+            if num_kv_heads > 1
+            else num_kv_heads
+        )
+        assert self.num_kv_heads > 0
+
+        if head_size is None:
+            # Some models use GQA and different sizes for o_proj
+            # and q_proj, that allows for that.
+            if hasattr(config, "head_dim"):
+                self.head_size = config.head_dim
+            else:
+                self.head_size = config.hidden_size // config.num_attention_heads
+        else:
+            self.head_size = head_size
+
+        self.cuda_graphs = {}
+        self.kv_cache = []

         self.cuda_graphs = {}
@@ -711,7 +798,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=sliding_window,
+            sliding_window=None,
         )

     @property
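As a rough companion to the rewritten constructor above, the sketch below (assumptions: a plain transformers config and a hypothetical helper name; not repository code) shows how the number of layers, KV heads, and head size can be derived from the config, including the text_config indirection used for VLM checkpoints.

```python
from transformers import AutoConfig


def kv_cache_geometry(model_id: str, world_size: int = 1):
    config = AutoConfig.from_pretrained(model_id)
    # VLM configs (e.g. llava_next) nest the language-model settings under text_config.
    config = getattr(config, "text_config", None) or config
    num_kv_heads = getattr(config, "num_key_value_heads", None)
    if num_kv_heads is None:  # GPT-2 style configs
        num_kv_heads = getattr(config, "n_head", None)
    if num_kv_heads is None:
        raise ValueError("Cannot get the number of key/value heads")
    if num_kv_heads > 1:
        num_kv_heads //= world_size
    head_size = getattr(config, "head_dim", None) or (
        config.hidden_size // config.num_attention_heads
    )
    return config.num_hidden_layers, num_kv_heads, head_size


# Example (requires hub or local cache access):
# kv_cache_geometry("llava-hf/llava-v1.6-mistral-7b-hf")  # -> (32, 8, 128) for the Mistral-7B text tower
```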

(One file's diff is suppressed because it is too large.)

File: server/text_generation_server/server.py

@@ -96,8 +96,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
-        batches = [batch_from_pb(batch) for batch in request.batches]
-        self.model.warmup(batches)
+        if self.model.batch_type in VLM_BATCH_TYPES :
+            self.model.warmup(request)
+        else:
+            batches = [batch_from_pb(batch) for batch in request.batches]
+            self.model.warmup(batches)

         return generate_pb2.WarmupResponse()
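A minimal sketch of the dispatch added above (VLM_BATCH_TYPES and the function name here are stand-ins; the real set is defined elsewhere in the server code): VLM models receive the raw WarmupRequest, which carries the batches together with max_input_length, max_prefill_tokens, and max_total_tokens, while text-only models keep the previous batch-based warmup path.

```python
VLM_BATCH_TYPES = set()  # in the real server this holds the VLM batch classes, e.g. VlmCausalLMBatch


def dispatch_warmup(model, request, batch_from_pb):
    # Vision-language models warm up from the full request (their batch class also
    # needs the image inputs and the token budgets); other models keep the old behaviour.
    if model.batch_type in VLM_BATCH_TYPES:
        return model.warmup(request)
    batches = [batch_from_pb(batch) for batch in request.batches]
    return model.warmup(batches)
```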