From 613dc9361764510f0b86490f04adcea008669c2b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 19 Apr 2024 16:30:16 +0000
Subject: [PATCH] Idefics2 in working state.

---
 router/client/src/client.rs                   |  2 +-
 router/src/config.rs                          | 13 +++-
 router/src/validation.rs                      | 29 ++++++++-
 .../custom_modeling/flash_mistral_modeling.py |  2 +-
 .../models/custom_modeling/idefics2.py        | 60 ++++++++++++-------
 .../text_generation_server/models/idefics2.py |  5 +-
 .../models/vlm_causal_lm.py                   | 19 +++---
 server/text_generation_server/server.py       | 14 +----
 8 files changed, 93 insertions(+), 51 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 545cddd0..24ecd2ad 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -114,8 +114,8 @@ impl Client {
             let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
 
             let mut inputs = String::new();
-            inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
             inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
 
             requests.push(Request {
                 id: 0,
diff --git a/router/src/config.rs b/router/src/config.rs
index 0de0a56c..4ee4704f 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -84,6 +84,17 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics2 {}
+
+impl Idefics2 {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        320
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -92,7 +103,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
-    Idefics2,
+    Idefics2(Idefics2),
     Ssm,
     GptBigcode,
     Santacoder,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 94e59b2d..be4bef00 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -540,7 +540,34 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
-        Some(Config::Idefics | Config::Idefics2) => {
+        Some(Config::Idefics2(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str("<fake_token_around_image>");
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str("<fake_token_around_image>");
+
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
+        Some(Config::Idefics) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index e78260fc..c2445cda 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -430,7 +430,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             config,
             # TODO dirty hack for idefics2.
             prefix=(
-                "lm_head" if not prefix or name is not "model" else f"{prefix}.lm_head"
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
             ),
             weights=weights,
         )
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index dc51fbcd..95a4d476 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -36,6 +36,20 @@ from text_generation_server.utils.layers import (
 )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class Idefics2VisionEmbeddings(nn.Module):
     """
     This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
@@ -390,14 +404,15 @@ class Idefics2MLP(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.intermediate_size = (
-            config.text_config.intermediate_size // weights.process_group.size()
-        )
 
     def forward(self, hidden_states):
+        start_shape = hidden_states.shape[:-1]
         gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        intermediate_size = gate_up_states.shape[-1] // 2
+        gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
+        ).view(*start_shape, -1)
 
 
 class Idefics2RMSNorm(nn.Module):
@@ -432,17 +447,23 @@ class Idefics2PerceiverAttention(nn.Module):
         self.attention_dropout = config.perceiver_config.attention_dropout
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
-            config.text_config.num_key_value_heads // weights.process_group.size()
+            self.num_key_value_heads // weights.process_group.size()
         )
 
-        self.qkv = TensorParallelColumnLinear.load_multi(
+        self.q_proj = TensorParallelColumnLinear.load(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefix=f"{prefix}.q_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.kv = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
         )
-        self.out_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
         )
 
@@ -457,19 +478,13 @@
         bsz, q_len, _ = latents.size()
         kv_seq_len = q_len + context.size()[1]
 
-        try:
-            hidden_states = torch.concat([context, latents], dim=-2)
-        except Exception as e:
-            print(e)
-            import ipdb
-
-            ipdb.set_trace()
-
-        qkv = self.qkv(hidden_states)
-        query_states, key_states, value_states = qkv.split(
+        hidden_states = torch.concat([context, latents], dim=-2)
+        query_states = self.q_proj(latents)
+        kv = self.kv(hidden_states)
+        key_states, value_states = kv.split(
             [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=2,
         )
@@ -704,7 +719,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_features: torch.Tensor,
     ):
         """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
         # Let's pray we have enabled enough slots !
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds
diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py
index d88ff574..f759300d 100644
--- a/server/text_generation_server/models/idefics2.py
+++ b/server/text_generation_server/models/idefics2.py
@@ -23,7 +23,10 @@ class Idefics2(VlmCausalLM):
         trust_remote_code: bool = False,
     ):
         self.processor = AutoProcessor.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            size={"longest_edge": 448, "shortest_edge": 378},
         )
         super().__init__(
             model_cls=Idefics2ForConditionalGeneration,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 5161a970..ab87c5c7 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -150,7 +150,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                 # import ipdb;ipdb.set_trace()
                 # height, width = image_input["image_sizes"][0]
                 # num_features = get_number_of_features(height, width, config)
-                num_features = 1
+                num_features = 320
                 full_text += "<image>" * num_features
                 image_inputs.append(image_input)
             else:
@@ -269,17 +269,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 2c2a6566..495c2c0c 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -154,19 +154,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
         concat_ns = None
 
-        torch.profiler._utils._init_for_cuda_graphs()
-        # prof = torch.profiler.profile()
-        # if self.model.rank != 0:
-        if True:
-            import contextlib
-
-            prof = contextlib.nullcontext()
-        else:
-            prof = torch.profiler.profile()
-        with prof:
-            generations, next_batch, timings = self.model.generate_token(batch)
-        # if self.model.rank == 0:
-        #     prof.export_chrome_trace(f"out_rank_0.json")
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(