feat: refactor model, improve startup and re-enable tests

drbh 2025-01-21 22:31:22 +00:00
parent bb69c5b199
commit 77ef543061
9 changed files with 244 additions and 139 deletions

View File

@@ -5,7 +5,7 @@
"index": 0,
"logprobs": null,
"message": {
"content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
"content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance.",
"name": null,
"role": "assistant",
"tool_calls": null
@@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1730164250,
"created": 1737498164,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"system_fingerprint": "3.0.2-dev0-native",
"usage": {
"completion_tokens": 58,
"prompt_tokens": 349,
"total_tokens": 407
"completion_tokens": 68,
"prompt_tokens": 1364,
"total_tokens": 1432
}
}

View File

@@ -11,10 +11,10 @@
"logprobs": null
}
],
"created": 1730416361,
"created": 1737498227,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native",
"system_fingerprint": "3.0.2-dev0-native",
"usage": null
}

View File

@@ -1,81 +1,78 @@
# Disabled because it's broken.
# import pytest
#
#
# @pytest.fixture(scope="module")
# def flash_qwen2_vl_handle(launcher):
# with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
# yield handle
#
#
# @pytest.fixture(scope="module")
# async def flash_qwen2(flash_qwen2_vl_handle):
# await flash_qwen2_vl_handle.health(300)
# return flash_qwen2_vl_handle.client
#
#
# @pytest.mark.private
# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
# response = await flash_qwen2.chat(
# max_tokens=100,
# seed=42,
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "image_url",
# "image_url": {
# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
# },
# },
# {"type": "text", "text": "Describe this image."},
# ],
# },
# ],
# )
#
# assert (
# response.choices[0].message.content
# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
# )
#
# assert response == response_snapshot
#
#
# @pytest.mark.private
# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
# responses = await flash_qwen2.chat(
# max_tokens=100,
# seed=42,
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "image_url",
# "image_url": {
# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
# },
# },
# {"type": "text", "text": "Describe this image."},
# ],
# },
# ],
# stream=True,
# )
#
# count = 0
# generated = ""
# last_response = None
# async for response in responses:
# count += 1
# generated += response.choices[0].delta.content
# last_response = response
#
# assert (
# generated
# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
# )
# assert count == 58
# assert last_response == response_snapshot
import pytest
@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
await flash_qwen2_vl_handle.health(300)
return flash_qwen2_vl_handle.client
@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
response = await flash_qwen2.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
{"type": "text", "text": "Describe this image."},
],
},
],
)
assert (
response.choices[0].message.content
== "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance."
)
assert response == response_snapshot
@pytest.mark.private
async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
responses = await flash_qwen2.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
{"type": "text", "text": "Describe this image."},
],
},
],
stream=True,
)
count = 0
generated = ""
last_response = None
async for response in responses:
count += 1
generated += response.choices[0].delta.content
last_response = response
assert (
generated
== "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance."
)
assert count == 68
assert last_response == response_snapshot

View File

@@ -230,7 +230,14 @@ struct QuantizationConfig {
}
#[derive(Debug, Deserialize)]
struct VisionConfig {}
struct VisionConfig {
depth: usize,
embed_dim: usize,
mlp_ratio: usize,
in_chans: usize,
patch_size: usize,
temporal_patch_size: usize,
}
#[derive(Debug, Deserialize)]
struct Config {
@@ -253,11 +260,6 @@ struct Config {
impl Config {
fn flop(&self) -> Option<u64> {
if self.vision_config.is_some() {
// VLM are much harder to predict and VRAM requirements
// Are more complex.
return None;
}
let num_heads = self.num_heads? as u64;
let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64;
@@ -277,8 +279,38 @@ impl Config {
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
let layer_flops = attn_layer_flops + gate_up_down_flops;
let total = layer_flops * num_layers;
Some(total)
let text_flops = layer_flops * num_layers;
tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));
if let Some(vision_config) = self.vision_config.as_ref() {
let in_chans = vision_config.in_chans as u64;
let patch_size = vision_config.patch_size as u64;
let embed_dim = vision_config.embed_dim as u64;
let vision_depth = vision_config.depth as u64;
let mlp_ratio = vision_config.mlp_ratio as u64;
let temporal_patch_size = vision_config.temporal_patch_size as u64;
// 1. patch embedding:
// - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
// where the 2 accounts for multiply-add
let patch_flops = 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
// 2. self-attention + mlp:
// - qkv projections: 3 * d_model * d_model * 2
// - attention: d_model * d_model * 2
// - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
// simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
// 3. add with layer norm flops for total vision layer flops
let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
let vision_flops = layer_flops * vision_depth;
tracing::debug!(
"Vision flops: {}",
human_size(vision_flops as usize, "flop")
);
Some(text_flops + vision_flops)
} else {
Some(text_flops)
}
}
fn kv_vram_per_tok(&self) -> Option<usize> {
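For a quick sanity check of the vision branch of `flop()` above, here is a minimal Python sketch of the same arithmetic. The config values are illustrative assumptions in the spirit of a Qwen2-VL-style vision config, not numbers read from the model's actual config.json.

```python
# Rough Python mirror of the vision FLOP estimate above.
# The config values below are illustrative assumptions, not taken from a real config.json.
depth, embed_dim, mlp_ratio = 32, 1280, 4
in_chans, patch_size, temporal_patch_size = 3, 14, 2

# 1. patch embedding: conv3d-style projection, x2 for multiply-add
patch_flops = 2 * temporal_patch_size * patch_size**2 * embed_dim * in_chans
# 2. self-attention + mlp, using the same simplification as the launcher code
attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim)
# 3. add layer-norm flops for the per-layer total, then scale by depth
layer_flops = patch_flops + attn_flops + 2 * embed_dim
vision_flops = layer_flops * depth
print(f"{vision_flops / 1e9:.2f} Gflop")  # roughly 0.5 Gflop with these numbers
```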
@@ -2012,6 +2044,10 @@ fn main() -> Result<(), LauncherError> {
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
let model_type = config
.as_ref()
.and_then(|c| c.model_type.as_deref())
.map(|s| s.to_owned());
// Quantization usually means you're even more RAM constrained.
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
@@ -2100,8 +2136,20 @@ fn main() -> Result<(), LauncherError> {
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
let default_cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default CUDA graphs: {:?}", default_cuda_graphs);
let cuda_graphs = match model_type.as_deref() {
Some("qwen2_vl") => {
tracing::warn!(
"Qwen VL model detected - restricting CUDA graphs to values >= 3"
);
default_cuda_graphs
.into_iter()
.filter(|&c| c >= 3)
.collect()
}
_ => default_cuda_graphs,
};
cuda_graphs
}
};
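The effect of the new filter on the default list is easy to see in isolation; a one-line Python sketch of the same filtering (values copied from the defaults above):

```python
default_cuda_graphs = [1, 2, 4, 8, 16, 32]
# For qwen2_vl, batch sizes below 3 are dropped before capturing CUDA graphs.
cuda_graphs = [c for c in default_cuda_graphs if c >= 3]
print(cuda_graphs)  # [4, 8, 16, 32]
```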

View File

@@ -90,7 +90,11 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "default":
pass
if rope_scaling.get("mrope_section", False):
mrope_section = rope_scaling.get("mrope_section")
return RotaryPositionEmbeddingMultimodalSections(
inv_freq, scaling_factor, mrope_section
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
@@ -548,3 +552,74 @@ def apply_llama3_scaling(
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(self, inv_freq, scaling_factor, sections):
super().__init__(inv_freq, scaling_factor)
self.sections = sections * 2
self._cos_cached = None
self._sin_cached = None
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
mrope_section = self.sections
unsqueeze_dim = 1
split_cos = cos.split(mrope_section, dim=-1)
split_sin = sin.split(mrope_section, dim=-1)
cos = []
for i, m in enumerate(split_cos):
cos.append(m[i % 3])
cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim)
sin = []
for i, m in enumerate(split_sin):
sin.append(m[i % 3])
sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim)
q = query.transpose(0, 1).unsqueeze(0)
k = key.transpose(0, 1).unsqueeze(0)
rotary_dim = cos.shape[-1]
q1 = q[..., :rotary_dim]
q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
k1 = k[..., :rotary_dim]
k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
def get_cos_sin(
self,
position_ids: torch.Tensor,
max_s: int,
dtype: torch.dtype,
):
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
inv_freq_expanded = (
self.inv_freq[None, None, :, None]
.float()
.expand(3, position_ids.shape[1], -1, 1)
)
position_ids_expanded = position_ids[
:, :, None, :
].float() # shape (3, bs, 1, positions)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
2, 3
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype), sin.to(dtype)
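The section-wise interleaving done in `forward()` can be illustrated without the fused rotary kernel. A minimal sketch, assuming an illustrative `mrope_section` of `[16, 24, 24]` (the kind of value that would come from the model's `rope_scaling` config) and dummy cos/sin tensors whose leading axis of size 3 holds the temporal/height/width components:

```python
import torch

# Illustrative values; the real mrope_section comes from the model's rope_scaling config.
mrope_section = [16, 24, 24]
sections = [s * 2 for s in mrope_section]  # doubled, as in __init__ above

seq_len = 5
dim = sum(sections)  # 128
# Dummy cos tensor: one slice per position component (temporal, height, width).
cos = torch.randn(3, seq_len, dim)

# Split along the last dim and, for each chunk, pick the component it belongs to
# (chunk 0 -> temporal, chunk 1 -> height, chunk 2 -> width, wrapping with i % 3).
interleaved = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(sections, dim=-1))],
    dim=-1,
)
print(interleaved.shape)  # torch.Size([5, 128])
```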

View File

@@ -1362,7 +1362,7 @@ def get_model(
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,

View File

@@ -61,11 +61,6 @@ class Qwen2Attention(torch.nn.Module):
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.mrope_section = (
config.rope_scaling.get("mrope_section", None)
if config.rope_scaling is not None
else None
)
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
@@ -127,17 +122,6 @@
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
if self.mrope_section is not None:
# if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order
cos = torch.cat(
[m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
@@ -251,7 +235,8 @@ class Qwen2Layer(nn.Module):
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
residual = hidden_states
normed_hidden_states, _ = self.input_layernorm(hidden_states)
# Self Attention
attn_output = self.self_attn(
@@ -266,15 +251,14 @@
max_s,
prefill_cache_indices,
)
hidden_states = attn_output + residual
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
residual = hidden_states
hidden_states, _ = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
hidden_states = mlp_output + residual
return hidden_states
class Qwen2Model(torch.nn.Module):
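The `Qwen2Layer.forward` change above replaces the fused add-and-norm path with explicit residual adds. A simplified sketch of the resulting pre-norm pattern (the callables stand in for the real norm/attention/MLP modules, which also return extra values):

```python
def layer_forward(x, norm1, attn, norm2, mlp):
    # Pre-norm residual block: normalize, transform, then add the saved input back.
    residual = x
    x = attn(norm1(x)) + residual

    residual = x
    return mlp(norm2(x)) + residual
```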
@@ -322,18 +306,15 @@ class Qwen2Model(torch.nn.Module):
) -> torch.Tensor:
hidden_states = inputs_embeds
# flatten position ids from 2D to 1D
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids.flatten(), true_max_s, hidden_states.dtype
position_ids,
true_max_s,
hidden_states.dtype,
)
# reshape back to 2D if the position_ids were 2D
if position_ids.size(0) != cos.size(0):
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states = layer(
hidden_states,
residual,
cos,
@@ -347,7 +328,7 @@ class Qwen2Model(torch.nn.Module):
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
hidden_states, _ = self.norm(hidden_states)
return hidden_states

View File

@@ -222,12 +222,11 @@ class Qwen2VLVisionBlock(nn.Module):
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
hidden_states_post_norm1, res = self.norm1(hidden_states)
hidden_states = hidden_states + self.attn(
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
)
hidden_states_post_norm2, res = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(norm2_out)
return hidden_states
@@ -527,6 +526,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
pixel_values = pixel_values.to(inputs_embeds.dtype)
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw

View File

@@ -1486,6 +1486,10 @@ class FlashCausalLM(Model):
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
# in the case of N dimensional position ids we need to slice the
# position ids to match the input_ids size for cuda graphs warmup
position_ids = position_ids[..., : input_ids.shape[0]]
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
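A small sketch of why the slice added above matters during warmup, assuming mrope-style position ids with a leading component axis (shapes are illustrative):

```python
import torch

# During CUDA-graph warmup the dummy input_ids can be shorter than the
# pre-allocated position ids, and with mrope the position ids are (3, seq)
# rather than (seq,). Slicing the trailing dim keeps the two aligned.
input_ids = torch.zeros(8, dtype=torch.long)
position_ids = torch.arange(16).repeat(3, 1)  # shape (3, 16)

position_ids = position_ids[..., : input_ids.shape[0]]
print(position_ids.shape)  # torch.Size([3, 8])
```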