diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
index f7c272e0..0c3af1ed 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -46,9 +46,12 @@ class Qwen3Attention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
-        self.num_key_value_groups = (
-            config.num_attention_heads // config.num_key_value_heads
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
         )
+        # self.num_key_value_groups = (
+        #     config.num_attention_heads // config.num_key_value_heads
+        # )
         self.num_heads = config.num_attention_heads
         self.attention_dropout = config.attention_dropout
         self.softmax_scale = self.head_dim**-0.5
@@ -66,9 +69,13 @@ class Qwen3Attention(nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         # self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_heads = (
-            config.num_key_value_heads // weights.process_group.size()
-        )
+        if config.num_key_value_heads > weights.process_group.size():
+            self.num_key_value_heads = (
+                config.num_key_value_heads // weights.process_group.size()
+            )
+        else:
+            self.num_key_value_heads = config.num_key_value_heads
+
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -86,10 +93,10 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
 
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+        ).repeat_interleave(self.num_key_value_groups)
 
         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
@@ -127,6 +134,10 @@ class Qwen3Attention(nn.Module):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states)
+        print(f"qkv shape: {qkv.shape}")
+        print(f"self.head_dim: {self.head_dim}")
+        print(f"self.num_heads: {self.num_heads}")
+        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 976e1a65..b3a843dc 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -153,7 +153,7 @@ def prepare_for_decode(
         block_groups_device, num_classes=batch_size
     )
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
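
For reference, a minimal standalone sketch of the tensor-parallel KV-head sharding rule that the second Qwen3Attention hunk above introduces. The helper name shard_kv_heads and the world_size parameter are hypothetical, used only for illustration; they are not part of the codebase.

# Sketch of the sharding rule from the diff above: split KV heads across
# tensor-parallel ranks only when there are enough heads to divide;
# otherwise keep them replicated on every rank (avoids dividing down to 0).
def shard_kv_heads(num_key_value_heads: int, world_size: int) -> int:
    if num_key_value_heads > world_size:
        # e.g. 8 KV heads on 2 ranks -> 4 KV heads per rank
        return num_key_value_heads // world_size
    # e.g. 4 KV heads on 8 ranks -> keep all 4 heads on each rank
    return num_key_value_heads


if __name__ == "__main__":
    assert shard_kv_heads(8, 2) == 4  # enough heads to split
    assert shard_kv_heads(4, 8) == 4  # fewer heads than ranks: replicate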