mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: update docs, image token logic and weight names
This commit is contained in:
parent
ebef284b3d
commit
dbe1666bc7
@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio
|
|||||||
|
|
||||||
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
|
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
|
||||||
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
|
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
|
||||||
|
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
|
||||||
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
||||||
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||||
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||||
|
@ -614,6 +614,73 @@ fn image_tokens(
|
|||||||
|
|
||||||
image_string
|
image_string
|
||||||
}
|
}
|
||||||
|
Idefics3(config) => {
|
||||||
|
const FAKE: &str = "<fake_token_around_image>";
|
||||||
|
const IMAGE: &str = "<image>";
|
||||||
|
const GLOBAL_IMG: &str = "<global-img>";
|
||||||
|
|
||||||
|
let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
|
||||||
|
|
||||||
|
// resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
|
||||||
|
let (height, width) = if height > max_longest_edge_for_image_resize
|
||||||
|
|| width > max_longest_edge_for_image_resize
|
||||||
|
{
|
||||||
|
let aspect_ratio = height as f32 / width as f32;
|
||||||
|
if height > width {
|
||||||
|
(
|
||||||
|
max_longest_edge_for_image_resize,
|
||||||
|
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
|
||||||
|
max_longest_edge_for_image_resize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(height, width)
|
||||||
|
};
|
||||||
|
|
||||||
|
let image_seq_len = config.get_number_of_features();
|
||||||
|
let max_edge = config.get_max_longest_edge();
|
||||||
|
|
||||||
|
let (image_rows, image_cols) = if height > max_edge || width > max_edge {
|
||||||
|
(
|
||||||
|
(height as f32 / max_edge as f32).ceil() as usize,
|
||||||
|
(width as f32 / max_edge as f32).ceil() as usize,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(0, 0)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut image_string = String::new();
|
||||||
|
|
||||||
|
if image_rows == 0 && image_cols == 0 {
|
||||||
|
// Single image case
|
||||||
|
image_string.push_str(FAKE);
|
||||||
|
image_string.push_str(GLOBAL_IMG);
|
||||||
|
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||||
|
image_string.push_str(FAKE);
|
||||||
|
} else {
|
||||||
|
// Split image case
|
||||||
|
for n_h in 0..image_rows {
|
||||||
|
for n_w in 0..image_cols {
|
||||||
|
image_string.push_str(FAKE);
|
||||||
|
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
|
||||||
|
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||||
|
}
|
||||||
|
image_string.push('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
image_string.push('\n');
|
||||||
|
image_string.push_str(FAKE);
|
||||||
|
image_string.push_str(GLOBAL_IMG);
|
||||||
|
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||||
|
image_string.push_str(FAKE);
|
||||||
|
}
|
||||||
|
|
||||||
|
image_string
|
||||||
|
}
|
||||||
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||||
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||||
Qwen2Vl(config) => format!(
|
Qwen2Vl(config) => format!(
|
||||||
@ -647,7 +714,8 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(
|
Some(
|
||||||
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
|
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
|
||||||
|
| Qwen2Vl(_)),
|
||||||
) => {
|
) => {
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
|
@ -534,7 +534,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
prefix=(
|
prefix=(
|
||||||
f"model.layers.{layer_id}"
|
f"model.layers.{layer_id}"
|
||||||
if not prefix
|
if not prefix
|
||||||
else f"{prefix}.model.layers.{layer_id}"
|
else f"{prefix}.layers.{layer_id}"
|
||||||
),
|
),
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@ -547,7 +547,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
prefix=(
|
prefix=(
|
||||||
f"model.layers.{layer_id}"
|
f"model.layers.{layer_id}"
|
||||||
if not prefix
|
if not prefix
|
||||||
else f"{prefix}.model.layers.{layer_id}"
|
else f"{prefix}.layers.{layer_id}"
|
||||||
),
|
),
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
@ -774,7 +774,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
|||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
@ -783,6 +783,10 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
|||||||
# Unused here
|
# Unused here
|
||||||
image_sizes: Optional[torch.Tensor] = None,
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
image_indices=None,
|
||||||
):
|
):
|
||||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
@ -872,7 +876,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
|||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
input_lengths=input_lengths,
|
seqlen=seqlen,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
true_max_s=max_s,
|
true_max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
|
Loading…
Reference in New Issue
Block a user