huggingface/text-generation-inference (mirror of https://github.com/huggingface/text-generation-inference.git)

Idefics2 in working state.

commit 613dc93617
parent f68ccfd023
@@ -114,8 +114,8 @@ impl Client {
        let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

        let mut inputs = String::new();
        inputs.push_str("");
        inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
        inputs.push_str("");

        requests.push(Request {
            id: 0,
@@ -84,6 +84,17 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics2 {}
+
+impl Idefics2 {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        320
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
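For context on the hard-coded 320: Idefics2's perceiver resampler emits 64 tokens per image, and the default processor splits each image into 4 crops plus the full image, which plausibly gives 64 x 5 = 320 slots. A minimal sketch of that arithmetic (the constant names are illustrative, not from this diff):

```python
# Illustrative only: where the hard-coded 320 plausibly comes from for Idefics2
# with default image splitting (64 perceiver latents per image, 4 crops + original).
PERCEIVER_LATENTS = 64
IMAGES_PER_INPUT = 4 + 1

def get_number_of_features(_height: int, _width: int) -> int:
    # Height and width are ignored, mirroring the hard-coded value in the diff.
    return PERCEIVER_LATENTS * IMAGES_PER_INPUT  # 320

assert get_number_of_features(980, 980) == 320
```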
@@ -92,7 +103,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
-    Idefics2,
+    Idefics2(Idefics2),
     Ssm,
     GptBigcode,
     Santacoder,
@@ -540,7 +540,34 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
-        Some(Config::Idefics | Config::Idefics2) => {
+        Some(Config::Idefics2(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str("<fake_token_around_image>");
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str("<fake_token_around_image>");
+
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
+        Some(Config::Idefics) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
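To make the new branch easier to follow, here is a rough Python rendering of what it does to a prompt; the regex and helper below are stand-ins, not the router's actual code. Every image chunk matched by RE stays in the model-side input as its URI, while the tokenizer-side query gets `<fake_token_around_image>` plus one `<image>` token per feature slot.

```python
import re

# Stand-in for the router's RE; assumed to match markdown-style image chunks.
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")

def prepare_input(inputs: str, slots: int) -> tuple[str, str]:
    """Rough Python rendering of the Rust branch above (illustrative only)."""
    modified_inputs, tokenizer_query = [], []
    start = 0
    for m in IMAGE_RE.finditer(inputs):
        # Text before the image chunk is copied to both outputs verbatim.
        modified_inputs.append(inputs[start:m.start()])
        tokenizer_query.append(inputs[start:m.start()])
        # The tokenizer sees placeholders: one <image> per feature slot.
        tokenizer_query.append(
            "<fake_token_around_image>" + "<image>" * slots + "<fake_token_around_image>"
        )
        # The model-side input keeps the image URI itself.
        modified_inputs.append(m.group(0))
        start = m.end()
    modified_inputs.append(inputs[start:])
    tokenizer_query.append(inputs[start:])
    return "".join(modified_inputs), "".join(tokenizer_query)

_, query = prepare_input("Describe ![](https://example.com/cat.png) please", slots=3)
# query == "Describe <fake_token_around_image><image><image><image><fake_token_around_image> please"
```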
@@ -430,7 +430,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             config,
             # TODO dirty hack for idefics2.
             prefix=(
-                "lm_head" if not prefix or name is not "model" else f"{prefix}.lm_head"
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
             ),
             weights=weights,
         )
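The one-line change above is a real bug fix, not style: `is not` compares object identity, so `name is not "model"` can be true even when the strings are equal, and CPython warns about comparing literals with `is`. A small demonstration, independent of the diff:

```python
# Identity vs. equality for strings: whether two equal strings share one object
# depends on interning, so `is` is the wrong operator for this check.
name = "".join(["mo", "del"])   # value "model", built at runtime
print(name == "model")          # True  -> what the prefix logic actually wants
print(name is "model")          # typically False in CPython (and a SyntaxWarning)
```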
@@ -36,6 +36,20 @@ from text_generation_server.utils.layers import (
 )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class Idefics2VisionEmbeddings(nn.Module):
     """
     This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
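A quick shape sanity check of the `repeat_kv` helper added above (illustrative sizes, assuming the function from the hunk is in scope):

```python
import torch

# 2 KV heads repeated 4x become 8 attention heads; copies of each KV head are
# laid out contiguously along the head dimension.
x = torch.randn(1, 2, 5, 64)        # (batch, num_key_value_heads, seqlen, head_dim)
y = repeat_kv(x, n_rep=4)
assert y.shape == (1, 8, 5, 64)     # (batch, num_attention_heads, seqlen, head_dim)
assert torch.equal(y[:, 0], y[:, 3]) and torch.equal(y[:, 4], y[:, 7])
```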
@@ -390,14 +404,15 @@ class Idefics2MLP(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.intermediate_size = (
-            config.text_config.intermediate_size // weights.process_group.size()
-        )
 
     def forward(self, hidden_states):
+        start_shape = hidden_states.shape[:-1]
         gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        intermediate_size = gate_up_states.shape[-1] // 2
+        gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
+        ).view(*start_shape, -1)
 
 
 class Idefics2RMSNorm(nn.Module):
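Why `intermediate_size` can be read off the projection output: `gate_up_proj` packs the gate and up projections, each already sharded across tensor-parallel ranks, so half of its last dimension is the local intermediate size. Rough arithmetic with assumed sizes:

```python
# Assumed sizes, for illustration only.
intermediate_size = 8192      # full (unsharded) intermediate size
tp_world_size = 2

gate_up_cols = 2 * intermediate_size // tp_world_size    # columns of gate_up_proj on this rank
local_intermediate = gate_up_cols // 2                   # what `shape[-1] // 2` recovers
assert local_intermediate == intermediate_size // tp_world_size
```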
@@ -432,17 +447,23 @@ class Idefics2PerceiverAttention(nn.Module):
         self.attention_dropout = config.perceiver_config.attention_dropout
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
-            config.text_config.num_key_value_heads // weights.process_group.size()
+            self.num_key_value_heads // weights.process_group.size()
         )
 
-        self.qkv = TensorParallelColumnLinear.load_multi(
+        self.q_proj = TensorParallelColumnLinear.load(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefix=f"{prefix}.q_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.kv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
         )
-        self.out_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
         )
 
@@ -457,19 +478,13 @@ class Idefics2PerceiverAttention(nn.Module):
         bsz, q_len, _ = latents.size()
         kv_seq_len = q_len + context.size()[1]
 
-        try:
-            hidden_states = torch.concat([context, latents], dim=-2)
-        except Exception as e:
-            print(e)
-            import ipdb
-
-            ipdb.set_trace()
-
-        qkv = self.qkv(hidden_states)
-        query_states, key_states, value_states = qkv.split(
+        hidden_states = torch.concat([context, latents], dim=-2)
+        query_states = self.q_proj(latents)
+        kv = self.kv(hidden_states)
+        key_states, value_states = kv.split(
            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=2,
         )
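Shape-wise, the rewritten attention queries only the latents, while keys and values come from the image hidden states concatenated with those latents. A small sketch with made-up sizes:

```python
import torch

# Made-up sizes: 64 perceiver latents attending over 1024 image hidden states.
bsz, n_latents, n_ctx, hidden = 1, 64, 1024, 4096
latents = torch.randn(bsz, n_latents, hidden)
context = torch.randn(bsz, n_ctx, hidden)

hidden_states = torch.concat([context, latents], dim=-2)  # (1, 1088, 4096)
kv_seq_len = n_latents + context.size(1)                  # 1088
# q_proj acts on `latents` only (64 queries); kv acts on `hidden_states` (1088 keys/values).
```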
@@ -704,7 +719,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_features: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
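The merge relies on the number of masked positions matching the number of image feature rows, which is why the prompt must carry exactly as many `<image>` tokens as the vision tower produces. A toy illustration with made-up sizes and token id:

```python
import torch

image_token_id = 32001                        # made-up id for <image>
input_ids = torch.tensor([[1, 32001, 32001, 2]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.ones(1, 2, 8)          # must flatten to exactly 2 rows here

mask = input_ids == image_token_id
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
# Positions 1 and 2 now hold image features; positions 0 and 3 keep their text embeddings.
```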
@@ -23,7 +23,10 @@ class Idefics2(VlmCausalLM):
         trust_remote_code: bool = False,
     ):
         self.processor = AutoProcessor.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            size={"longest_edge": 448, "shortest_edge": 378},
         )
         super().__init__(
             model_cls=Idefics2ForConditionalGeneration,
@@ -150,7 +150,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                 # import ipdb;ipdb.set_trace()
                 # height, width = image_input["image_sizes"][0]
                 # num_features = get_number_of_features(height, width, config)
-                num_features = 1
+                num_features = 320
                 full_text += "<image>" * num_features
                 image_inputs.append(image_input)
             else:
@@ -269,17 +269,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
 
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
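The new lookup replaces the hard-coded padding buckets with "smallest captured graph that still fits the batch, else fall back to a regular forward". A compact sketch of the selection logic (graph objects are placeholders):

```python
cuda_graphs = {1: "g1", 2: "g2", 4: "g4", 8: "g8", 16: "g16"}  # placeholder graphs

def pick_graph(bs: int):
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

assert pick_graph(3) == "g4"     # old code special-cased bs == 3 -> padded_bs = 4
assert pick_graph(9) == "g16"
assert pick_graph(32) is None    # no graph large enough -> eager forward path
```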
@@ -154,19 +154,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = batches[0]
         concat_ns = None
 
-        torch.profiler._utils._init_for_cuda_graphs()
-        # prof = torch.profiler.profile()
-        # if self.model.rank != 0:
-        if True:
-            import contextlib
-
-            prof = contextlib.nullcontext()
-        else:
-            prof = torch.profiler.profile()
-        with prof:
-            generations, next_batch, timings = self.model.generate_token(batch)
-        # if self.model.rank == 0:
-        # prof.export_chrome_trace(f"out_rank_0.json")
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(