mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Fix crash issue
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent 8c182415c2
commit b32b78e74e
@@ -46,9 +46,12 @@ class Qwen3Attention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
-        self.num_key_value_groups = (
-            config.num_attention_heads // config.num_key_value_heads
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
         )
+        # self.num_key_value_groups = (
+        #     config.num_attention_heads // config.num_key_value_heads
+        # )
         self.num_heads = config.num_attention_heads
         self.attention_dropout = config.attention_dropout
         self.softmax_scale = self.head_dim**-0.5
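Context for the hunk above: configs that omit num_key_value_heads (plain multi-head attention checkpoints) would raise an AttributeError in the old num_key_value_groups computation; the getattr fallback defaults the field to num_attention_heads. A minimal standalone sketch of that pattern, where SimpleNamespace stands in for the real HF config object and resolve_kv_heads is a hypothetical helper:

# Sketch of the fallback pattern the hunk introduces; illustrative names only.
from types import SimpleNamespace

def resolve_kv_heads(config):
    # Default to one KV head per query head (plain MHA) when the field is absent.
    return getattr(config, "num_key_value_heads", config.num_attention_heads)

mha_config = SimpleNamespace(num_attention_heads=32)  # no num_key_value_heads
gqa_config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)

assert resolve_kv_heads(mha_config) == 32  # falls back to num_attention_heads
assert resolve_kv_heads(gqa_config) == 8   # explicit GQA value is kept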
@@ -66,9 +69,13 @@ class Qwen3Attention(nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = (
-            config.num_key_value_heads // weights.process_group.size()
-        )
+        # self.num_key_value_heads = config.num_key_value_heads
+        if config.num_key_value_heads > weights.process_group.size():
+            self.num_key_value_heads = (
+                config.num_key_value_heads // weights.process_group.size()
+            )
+        else:
+            self.num_key_value_heads = config.num_key_value_heads
 
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
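The new branch guards tensor-parallel sharding of the KV heads: unconditional integer division by the world size floors to zero heads per rank whenever there are fewer KV heads than ranks. A sketch under that reading, where kv_heads_per_rank is a hypothetical helper and world_size stands in for weights.process_group.size():

# Sketch of the per-rank KV-head count computed by the new branch.
def kv_heads_per_rank(num_key_value_heads: int, world_size: int) -> int:
    if num_key_value_heads > world_size:
        # Enough KV heads to shard across ranks.
        return num_key_value_heads // world_size
    # Fewer KV heads than ranks: dividing would floor to zero, so each rank
    # keeps the full (replicated) KV head count instead.
    return num_key_value_heads

assert kv_heads_per_rank(8, 2) == 4  # 8 KV heads sharded over 2 ranks
assert kv_heads_per_rank(4, 8) == 4  # 4 KV heads on 8 ranks: keep all 4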
@@ -86,10 +93,10 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
 
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+        ).repeat_interleave(self.num_key_value_groups)
 
         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
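The rename from num_groups to num_key_value_groups does not change behavior: kv_head_mapping still assigns each query head the KV head it reads in grouped-query attention. A small sketch of the resulting mapping, with illustrative head counts:

# Illustrative head counts: 8 query heads reading from 2 KV heads.
import torch

num_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

# Query head i attends to KV head kv_head_mapping[i].
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32)
kv_head_mapping = kv_head_mapping.repeat_interleave(num_key_value_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)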
@@ -127,6 +134,10 @@ class Qwen3Attention(nn.Module):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states)
+        print(f"qkv shape: {qkv.shape}")
+        print(f"self.head_dim: {self.head_dim}")
+        print(f"self.num_heads: {self.num_heads}")
+        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
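The added prints inspect the quantities that must be mutually consistent for the split that follows: the fused projection's last dimension has to equal head_dim * (num_heads + 2 * num_key_value_heads), otherwise qkv.split raises. A sketch with illustrative shapes:

# Illustrative shapes; real values come from the model config.
import torch

head_dim, num_heads, num_key_value_heads = 128, 8, 2
qkv = torch.randn(1, 4, head_dim * (num_heads + 2 * num_key_value_heads))

query_states, key_states, value_states = qkv.split(
    [
        head_dim * num_heads,            # 1024
        head_dim * num_key_value_heads,  # 256
        head_dim * num_key_value_heads,  # 256
    ],
    dim=-1,
)
print(query_states.shape, key_states.shape, value_states.shape)
# torch.Size([1, 4, 1024]) torch.Size([1, 4, 256]) torch.Size([1, 4, 256])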
@@ -153,7 +153,7 @@ def prepare_for_decode(
         block_groups_device, num_classes=batch_size
     )
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
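This hunk swaps block_usage for block_usage_device in the mask comparison. Since the arange is created on device, comparing it against a host-side tensor is a plausible source of the crash this commit targets (an inference; the commit message does not say). A sketch of the mask construction with illustrative values:

# Illustrative values: 3 blocks of 4 slots with 4, 2, and 1 slots in use.
import math
import torch

BLOCK_SIZE = 4
block_usage_device = torch.tensor([4, 2, 1], dtype=torch.int32)

# Slot j of a block is masked out once j >= that block's usage count.
mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage_device.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)
print(attn_bias)
# tensor([[0., 0., 0., 0.],
#         [0., 0., -inf, -inf],
#         [0., -inf, -inf, -inf]])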