Idefics2 in working state.

This commit is contained in:
Nicolas Patry 2024-04-19 16:30:16 +00:00
parent f68ccfd023
commit 613dc93617
8 changed files with 93 additions and 51 deletions

View File

@ -114,8 +114,8 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new(); let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
requests.push(Request { requests.push(Request {
id: 0, id: 0,

View File

@ -84,6 +84,17 @@ pub struct ClipVisionModel {
patch_size: usize, patch_size: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
320
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")] #[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -92,7 +103,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel), ClipVisionModel(ClipVisionModel),
Mistral, Mistral,
Idefics, Idefics,
Idefics2, Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,
Santacoder, Santacoder,

View File

@ -540,7 +540,34 @@ fn prepare_input(
inputs = modified_inputs; inputs = modified_inputs;
tokenizer_query tokenizer_query
} }
Some(Config::Idefics | Config::Idefics2) => { Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len()); let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;

View File

@ -430,7 +430,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
config, config,
# TODO dirty hack for idefics2. # TODO dirty hack for idefics2.
prefix=( prefix=(
"lm_head" if not prefix or name is not "model" else f"{prefix}.lm_head" "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
), ),
weights=weights, weights=weights,
) )

View File

@ -36,6 +36,20 @@ from text_generation_server.utils.layers import (
) )
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics2VisionEmbeddings(nn.Module): class Idefics2VisionEmbeddings(nn.Module):
""" """
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
@ -390,14 +404,15 @@ class Idefics2MLP(nn.Module):
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.intermediate_size = (
config.text_config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states): def forward(self, hidden_states):
start_shape = hidden_states.shape[:-1]
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) intermediate_size = gate_up_states.shape[-1] // 2
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
).view(*start_shape, -1)
class Idefics2RMSNorm(nn.Module): class Idefics2RMSNorm(nn.Module):
@ -432,17 +447,23 @@ class Idefics2PerceiverAttention(nn.Module):
self.attention_dropout = config.perceiver_config.attention_dropout self.attention_dropout = config.perceiver_config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = ( self.num_key_value_heads = (
config.text_config.num_key_value_heads // weights.process_group.size() self.num_key_value_heads // weights.process_group.size()
) )
self.qkv = TensorParallelColumnLinear.load_multi( self.q_proj = TensorParallelColumnLinear.load(
config, config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.kv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0, dim=0,
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.out_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
) )
@ -457,19 +478,13 @@ class Idefics2PerceiverAttention(nn.Module):
bsz, q_len, _ = latents.size() bsz, q_len, _ = latents.size()
kv_seq_len = q_len + context.size()[1] kv_seq_len = q_len + context.size()[1]
try: hidden_states = torch.concat([context, latents], dim=-2)
hidden_states = torch.concat([context, latents], dim=-2) query_states = self.q_proj(latents)
except Exception as e: kv = self.kv(hidden_states)
print(e) key_states, value_states = kv.split(
import ipdb
ipdb.set_trace()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_key_value_heads,
2 * self.head_size * self.num_key_value_heads, self.head_size * self.num_key_value_heads,
], ],
dim=2, dim=2,
) )
@ -704,7 +719,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
image_features: torch.Tensor, image_features: torch.Tensor,
): ):
"""In place merges in vision_embeddings with inputs_embeds.""" """In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index # mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots ! # Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds return inputs_embeds

View File

@ -23,7 +23,10 @@ class Idefics2(VlmCausalLM):
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id,
revision=revision,
trust_remote_code=trust_remote_code,
size={"longest_edge": 448, "shortest_edge": 378},
) )
super().__init__( super().__init__(
model_cls=Idefics2ForConditionalGeneration, model_cls=Idefics2ForConditionalGeneration,

View File

@ -150,7 +150,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
# import ipdb;ipdb.set_trace() # import ipdb;ipdb.set_trace()
# height, width = image_input["image_sizes"][0] # height, width = image_input["image_sizes"][0]
# num_features = get_number_of_features(height, width, config) # num_features = get_number_of_features(height, width, config)
num_features = 1 num_features = 320
full_text += "<image>" * num_features full_text += "<image>" * num_features
image_inputs.append(image_input) image_inputs.append(image_input)
else: else:
@ -269,17 +269,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0] bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None) bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,

View File

@ -154,19 +154,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = batches[0] batch = batches[0]
concat_ns = None concat_ns = None
torch.profiler._utils._init_for_cuda_graphs() generations, next_batch, timings = self.model.generate_token(batch)
# prof = torch.profiler.profile()
# if self.model.rank != 0:
if True:
import contextlib
prof = contextlib.nullcontext()
else:
prof = torch.profiler.profile()
with prof:
generations, next_batch, timings = self.model.generate_token(batch)
# if self.model.rank == 0:
# prof.export_chrome_trace(f"out_rank_0.json")
self.cache.set(next_batch) self.cache.set(next_batch)
return generate_pb2.DecodeResponse( return generate_pb2.DecodeResponse(