Mirror of https://github.com/huggingface/text-generation-inference.git
Enable llava-next
Signed-off-by: yuanwu <yuan.wu@intel.com>
Commit 588a014551 (parent d3155d6f41)
@@ -110,6 +110,7 @@ impl Client {
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
        model_id: &str
    ) -> Result<Option<u32>> {
        let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
        if !warmup_enabled {
@@ -152,25 +153,76 @@ impl Client {
        let mut batch_counter: u64 = 0;
        let mut request_counter: u64 = 0;
        if model_id.contains("llava") {
            let mut n_tokens = 0;
            let mut requests = Vec::new();
            // Create requests
            while n_tokens < max_prefill_tokens {
                let truncate = cmp::min(max_input_length, max_prefill_tokens - n_tokens);

                let mut inputs = String::new();
                inputs.push_str("");
                inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));

                requests.push(Request {
                    id: 0,
                    // We truncate the input on the server side to be sure that it has the correct size
                    inputs,
                    truncate,
                    // Set sampling parameters to also take these ops into account in the max memory
                    parameters: Some(NextTokenChooserParameters {
                        temperature: 0.9,
                        top_k: 10,
                        top_p: 0.9,
                        typical_p: 0.9,
                        do_sample: false,
                        seed: 0,
                        repetition_penalty: 1.2,
                        frequency_penalty: 0.1,
                        watermark: true,
                        grammar: String::new(),
                        grammar_type: GrammarType::None as i32,
                    }),
                    stopping_parameters: Some(StoppingCriteriaParameters {
                        max_new_tokens: max_total_tokens - truncate,
                        stop_sequences: vec![],
                        ignore_eos_token: true,
                    }),
                    prefill_logprobs: true,
                    top_n_tokens: 20,
                });
                n_tokens += max_input_length;

                // Check max_batch_size
                if Some(requests.len()) == max_batch_size {
                    break;
                }
            }

            let mut batches = Vec::new();
            batches.push(Batch {
                id: 0,
                size: requests.len() as u32,
                requests,
                max_tokens: 0,
            });

            let request = tonic::Request::new(WarmupRequest {
                batches,
                max_input_length,
                max_prefill_tokens,
                max_total_tokens,
            })
            .inject_context();
            let response = self.stub.warmup(request).await?.into_inner();
            Ok(response.max_supported_total_tokens)
        }
        else {
            for shape in shapes.iter() {
                let (batch_size, seq_length) = shape;
                let mut batches: Vec<Batch> = vec![
                    self.create_warmup_batch(
                        *shape,
                        &mut batch_counter,
                        &mut request_counter,
                        max_input_length,
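A quick arithmetic sketch (Python, illustrative values only; none of these names come from the codebase) of how many dummy requests the llava warmup loop above packs into its single batch: one request is added per max_input_length tokens until the prefill budget or max_batch_size is reached.

    max_prefill_tokens, max_input_length, max_batch_size = 4096, 1024, 8

    n_tokens, n_requests = 0, 0
    while n_tokens < max_prefill_tokens:
        n_tokens += max_input_length
        n_requests += 1
        if n_requests == max_batch_size:
            break

    print(n_requests)  # 4 dummy requests end up in the llava warmup batch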
@@ -179,56 +231,45 @@ impl Client {
                        false,
                        None,
                    )
                ];
                // if possible, create second batch in order to trigger concatenate operation
                if *batch_size < max_decode_batch_size {
                    batches.push(
                        self.create_warmup_batch(
                            (1, *seq_length),
                            &mut batch_counter,
                            &mut request_counter,
                            max_input_length,
                            max_total_tokens,
                            seq_bucket_size,
                            false,
                            None,
                        )
                    );
                }

                let request = tonic::Request::new(WarmupRequest {
                    batches,
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                }).inject_context();
                let _response = self.stub.warmup(request).await?.into_inner();
            }

            // send batches to warmup all possible decode shapes
            if decode_batch_sizes.len() > 1 {
                let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
                    decode_bucket_size
                } else {
                    decode_bucket_size.div_ceil(max_prefill_batch_size)
                };
                let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;

                let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
                let mut batches: Vec<Batch> = vec![
                    self.create_warmup_batch(
                        (requests_send, seq_bucket_size),
                        &mut batch_counter,
                        &mut request_counter,
                        max_input_length,
                        max_total_tokens,
                        seq_bucket_size,
                        false,
                        Some(max_new_tokens),
                    )
                ];

                let get_current_decode_batch_size = |num: u32| -> u32 {
                    decode_batch_sizes.iter()
                        .filter(|&&x| x >= num)
                        .min()
                        .copied()
                        .unwrap()
                };

                let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
                while current_decode_batch_size < max_decode_batch_size {
                    let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
                    let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
                    batches.push(
                        self.create_warmup_batch(
                            (num_requests, seq_bucket_size),
                            &mut batch_counter,
                            &mut request_counter,
                            max_input_length,
@@ -237,48 +278,74 @@ impl Client {
                            max_total_tokens,
                            seq_bucket_size,
                            false,
                            Some(max_new_tokens),
                        )
                    );

                    requests_send += num_requests;
                    current_decode_batch_size = get_current_decode_batch_size(requests_send);
                }

                let request = tonic::Request::new(WarmupRequest {
                    batches,
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                }).inject_context();
                let _response = self.stub.warmup(request).await?.into_inner();
            }

            // send batches with default params to warm up Greedy search
            let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
            for batch_size in &prefill_batch_sizes {
                greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
            }
            for greedy_shape in greedy_shapes.iter() {
                let batches: Vec<Batch> = vec![
                    self.create_warmup_batch(
                        *greedy_shape,
                        &mut batch_counter,
                        &mut request_counter,
                        max_input_length,
                        max_total_tokens,
                        seq_bucket_size,
                        true,
                        None,
                    )
                ];
                let request = tonic::Request::new(WarmupRequest {
                    batches,
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                }).inject_context();
                let _response = self.stub.warmup(request).await?.into_inner();
            }
            Ok(None) // No support for maximum total tokens
        }
    }

    #[instrument(skip_all)]
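For readers following the decode-shape warmup above, here is a small illustrative walk-through (Python, with assumed toy values; the bucket list, sizes, and helper names are stand-ins, not code from this repository) of how requests are added in chunks of at most max_prefill_batch_size until every decode bucket up to max_decode_batch_size has been exercised.

    decode_batch_sizes = [8, 16, 24, 32]   # assumed bucket boundaries
    max_prefill_batch_size = 4
    decode_bucket_size = 8
    max_decode_batch_size = decode_batch_sizes[-1]

    def current_decode_batch_size(num):
        # smallest bucket that can hold `num` requests
        return min(b for b in decode_batch_sizes if b >= num)

    requests_send = min(max_prefill_batch_size, decode_bucket_size)
    extra_batch_sizes = [requests_send]
    current = current_decode_batch_size(requests_send)
    while current < max_decode_batch_size:
        distance_to_next_bucket = current + decode_bucket_size - requests_send
        num_requests = min(distance_to_next_bucket, max_prefill_batch_size)
        extra_batch_sizes.append(num_requests)
        requests_send += num_requests
        current = current_decode_batch_size(requests_send)

    print(extra_batch_sizes)  # batch sizes of the warmup batches that step through each bucket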
@@ -100,6 +100,7 @@ impl ShardedClient {
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
        model_id: &str,
    ) -> Result<Option<u32>> {
        let futures: Vec<_> = self
            .clients
@@ -110,6 +111,7 @@ impl ShardedClient {
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                    model_id
                ))
            })
            .collect();
@@ -349,6 +349,7 @@ async fn main() -> Result<(), RouterError> {
            max_batch_prefill_tokens,
            max_total_tokens as u32,
            max_batch_size,
            &model_info.model_id
        )
        .await
        .map_err(RouterError::Warmup)?
@@ -16,6 +16,12 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)


from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -159,6 +165,18 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    logger.info(f"model_type = {model_type}")
    if model_type == "llava_next":
        logger.info(f"################model_type = {model_type}")
        return VlmCausalLM(
            model_class=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
@@ -369,6 +369,7 @@ class CausalLMBatch(Batch):
        input_lengths = [b.input_length for b in batches]
        max_input_length = max(input_lengths)
        offsets = [max_input_length - b.input_length for b in batches]

        cur_padding = [b.right_padding for b in batches]
        # For prefill there is a space allocated only for first token
        # Need to add padding to the max total tokens before first decode
@@ -21,17 +21,12 @@ import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)

from loguru import logger

def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
@@ -56,100 +51,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    return height // patch_size, width // patch_size

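As a hedged illustration of the helper above (values are assumptions, not taken from a real config): grid_pinpoints below are the default LLaVA-NeXT anyres resolutions, and the third argument is the 336-pixel tile size that this file passes in.

    grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_size = (900, 600)  # original (height, width)

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size, grid_pinpoints, 336
    )
    # select_best_resolution favors 672x672 for a 900x600 input, giving a (2, 2) tile grid.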
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor

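A short usage sketch for unpad_image (tensor sizes are assumptions): a 900x600 (height x width) image packed into a padded 672x672 feature map is padded along the width, so (672 - 448) // 2 = 112 columns are sliced off each side.

    import torch

    padded = torch.zeros(256, 672, 672)         # (num_channels, height, width)
    unpadded = unpad_image(padded, (900, 600))  # original (height, width)
    print(unpadded.shape)                       # torch.Size([256, 672, 448])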
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.linear_1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class LlavaNextForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        vision_config = config.vision_config
        # Instead of selecting in hidden_states[-2].
        # Instead compute only the n -2 + 1 layers and don't pool
        if config.vision_feature_layer < 0:
            vision_config.num_hidden_layers += config.vision_feature_layer + 1
        else:
            vision_config.num_hidden_layers = config.vision_feature_layer + 1
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        self.multi_modal_projector = LlavaNextMultiModalProjector(
            prefix="multi_modal_projector", config=config, weights=weights
        )

        self.image_newline = weights.get_tensor("image_newline")

        self.vocab_size = config.text_config.vocab_size
        self.config = config
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        self.language_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
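A minimal sketch of what the merge step does, under the assumption that image positions are marked by config.image_token_index (toy shapes, not this file's implementation):

    import torch

    image_token_index = 32000                              # assumed id of the <image> token
    input_ids = torch.tensor([[1, 32000, 32000, 15043]])   # (batch=1, seq_len=4)
    inputs_embeds = torch.zeros(1, 4, 8)                   # (batch, seq_len, hidden)
    image_features = torch.randn(2, 8)                     # one row per image position

    mask = input_ids == image_token_index                  # marks the image-token slots
    inputs_embeds[mask] = image_features                   # overwrite those rows in place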
@@ -164,120 +72,215 @@ class LlavaNextForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        # Unused for this model
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
    ):
        inputs_embeds = self.language_model.embed_tokens(input_ids)
        if pixel_values is not None and len(pixel_values) > 0:
            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
            # 1. Extract the input embeddings

            # 2. Merge text and images
            num_images, num_patches, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(
                num_images * num_patches, channels, height, width
        if token_idx is not None:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            image_features = self.vision_tower(pixel_values)
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)

            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
            # Already done within the clip model
            selected_image_feature = image_features.last_hidden_state
            outputs = self.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                token_idx=token_idx,
            )

            if self.config.vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif self.config.vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            logits = outputs[0]

            if not return_dict:
                output = (logits,) + outputs[1:]
                return output

            return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
        The only differences are:
        - add new args token_idx
        - add the process of merging images into inputs_embeds
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            raise RuntimeError(
                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."

            position_ids = kwargs.get("position_ids", None)
            labels = kwargs.get("labels", None)
            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
                vision_feature_layer = kwargs.get("vision_feature_layer", None)
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )
                vision_feature_layer = (
                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                )

                # 1. Extract the input embeddings
                inputs_embeds = self.get_input_embeddings()(input_ids)
                # 2. Merge text and images
                batch_size, num_patches, num_channels, height, width = pixel_values.shape
                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                image_features = self.vision_tower(
                    reshaped_pixel_values, output_hidden_states=True
                )

                selected_image_feature = image_features.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature

                image_features = self.multi_modal_projector(selected_image_feature)

                # split up image_features for each of the individual images
                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
                # if we assume each image has 5 image features (base image + 4 patches)
                split_sizes = [image.shape[0] for image in pixel_values]
                image_features = torch.split(image_features, split_sizes, dim=0)

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]

                        if height * width != base_image_feature.shape[0]:
                            raise ValueError("The number of patches is not consistent with the image size.")

                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                            image_sizes[image_idx].tolist(),
                            self.config.image_grid_pinpoints,
                            self.config.vision_config.image_size,
                        )

                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
                    new_image_features.append(image_feature)
                image_features = torch.stack(new_image_features, dim=0)
                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
                self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif past_key_values is not None:
                seq_len = input_ids.shape[1]
                pad_len = seq_len - token_idx.item()
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                past_length = first_layer_past_key_value.shape[-1]
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = extended_attention_mask
                attention_mask[:, -pad_len:] = 0

            if attention_mask is not None and position_ids is None:
                # create position_ids on the fly for batch generation
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                if past_key_values:
                    if token_idx is not None:
                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    else:
                        position_ids = position_ids[:, -input_ids.shape[1] :]

            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
            if inputs_embeds is not None and past_key_values is None:
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                model_inputs = {"input_ids": input_ids}

            model_inputs.update(
                {
                    "position_ids": position_ids,
                    "past_key_values": past_key_values,
                    "use_cache": kwargs.get("use_cache"),
                    "attention_mask": attention_mask,
                    "token_idx": token_idx,
                    "labels": labels,
                }
            )

            image_features = self.multi_modal_projector(selected_image_feature)

            # split up image_features for each of the individual images
            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
            # if we assume each image has 5 image features (base image + 4 patches)
            split_sizes = [num_patches] * num_images
            image_features = torch.split(image_features, split_sizes, dim=0)

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )

            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]

                    if height * width != base_image_feature.shape[0]:
                        raise ValueError(
                            "The number of patches is not consistent with the image size."
                        )
                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                        image_sizes[image_idx],
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
                    )
                    image_feature = image_feature.view(
                        num_patch_height, num_patch_width, height, width, -1
                    )
                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None].expand(
                                *image_feature.shape[:-1], 1
                            ),
                        ),
                        dim=-1,
                    )
                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    image_feature = torch.cat(
                        (image_feature, self.image_newline[None]), dim=0
                    )
                new_image_features.append(image_feature)
            image_features = torch.stack(new_image_features, dim=0)

            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_features
            )

            hidden_states = self.language_model.model(
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                true_max_s=max_s,
                prefill_cache_indices=None,
            )
            if lm_head_indices is not None:
                hidden_states = hidden_states[lm_head_indices]
            logits, speculative_logits = self.language_model.lm_head(hidden_states)
            return logits, speculative_logits
        return model_inputs
@@ -10,7 +10,12 @@ import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -19,6 +24,11 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models.types import (
    Batch,
    Tokens,
@@ -686,20 +696,97 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        config_class: PreTrainedTokenizerBase = AutoConfig,
        default_dtype=torch.bfloat16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        # Create model
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))
        dtype = torch.bfloat16 if dtype is None else dtype
        device = torch.device("hpu")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype)

        prefix = ""
        model = model_class(prefix, config, weights)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        self.num_layers = config.num_hidden_layers
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj, that allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []

        self.cuda_graphs = {}
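Worked example of the head arithmetic above (all numbers are assumptions, not read from a real checkpoint): 8 KV heads sharded over a 2-rank process group, hidden_size 4096, 32 attention heads, and no explicit head_dim.

    num_key_value_heads, world_size = 8, 2
    hidden_size, num_attention_heads = 4096, 32

    num_kv_heads_per_rank = (
        num_key_value_heads // world_size if num_key_value_heads > 1 else num_key_value_heads
    )                                                 # -> 4 KV heads per rank
    head_size = hidden_size // num_attention_heads    # -> 128, the fallback when head_dim is absent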
@@ -711,7 +798,7 @@ class FlashCausalLM(Model):
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
            sliding_window=None,
        )

    @property
(File diff suppressed because it is too large.)
@@ -96,8 +96,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            batch, self.model.tokenizer, self.model.dtype, self.model.device
        )

        batches = [batch_from_pb(batch) for batch in request.batches]
        self.model.warmup(batches)
        if self.model.batch_type in VLM_BATCH_TYPES:
            self.model.warmup(request)
        else:
            batches = [batch_from_pb(batch) for batch in request.batches]
            self.model.warmup(batches)

        return generate_pb2.WarmupResponse()