fix: adjust sharding and lm head logic

David Holtz 2024-10-31 15:33:36 +00:00
parent 7d97ee82a1
commit bfa16a5857


@@ -35,6 +35,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
     FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +70,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
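The vision attention now divides embed_dim by the tensor-parallel world size, because the column-sharded qkv projection leaves each rank with only its slice of the features, so downstream reshapes must use the per-rank width. A minimal sketch of the shape arithmetic (the numbers are illustrative, not read from the actual Qwen2-VL config):

```python
# Illustrative shape arithmetic for the sharded vision attention; the numbers
# are examples, not values taken from the real config.
embed_dim = 1280           # full vision hidden size
num_heads = 16
world_size = 4             # weights.process_group.size()

head_dim = embed_dim // num_heads           # 80, unchanged by sharding
local_heads = num_heads // world_size       # 4 heads live on each rank
local_embed_dim = embed_dim // world_size   # 320 features produced per rank

# The column-sharded qkv projection yields [seq_len, 3 * local_embed_dim] on
# each rank, so the module must track the sharded value, not the full 1280.
assert local_embed_dim == local_heads * head_dim
```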
@@ -82,7 +83,7 @@
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
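Switching proj from TensorParallelColumnLinear to TensorParallelRowLinear follows the usual column-then-row pairing: the qkv projection splits the feature dimension across ranks, so the output projection must consume its matching row shard and sum the partial results across ranks. A self-contained sketch in plain PyTorch (not the TGI layers; it simulates the ranks in one process and skips the attention step between the two projections):

```python
import torch

def column_then_row(x, qkv_shards, proj_shards):
    """Simulate the two projections of one tensor-parallel attention block."""
    partials = []
    for w_qkv, w_proj in zip(qkv_shards, proj_shards):
        local = x @ w_qkv                 # column-parallel: this rank's feature slice
        # (per-rank attention over the local heads would run here)
        partials.append(local @ w_proj)   # row-parallel: partial output
    # TensorParallelRowLinear finishes with an all-reduce; summing the
    # per-rank partial outputs is the single-process equivalent.
    return sum(partials)

hidden, ranks = 8, 2
x = torch.randn(3, hidden)
w_qkv = torch.randn(hidden, hidden)
w_proj = torch.randn(hidden, hidden)

sharded = column_then_row(x, w_qkv.chunk(ranks, dim=1), w_proj.chunk(ranks, dim=0))
dense = x @ w_qkv @ w_proj
assert torch.allclose(sharded, dense, atol=1e-4)
```

Loading proj column-parallel instead would shard the wrong dimension and skip the reduction that the row-parallel layer performs.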
@@ -364,8 +365,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
@@ -508,5 +516,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        print("lm_head_indices", lm_head_indices)
+        logits, speculative_logits = self.lm_head(hidden_states)
+        # import ipdb; ipdb.set_trace()
+        return logits, speculative_logits
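With the head returning a pair, the forward pass now propagates (logits, speculative_logits) instead of (logits, None). A hedged sketch of how a caller might unpack that pair; the helper name and the way lm_head_indices is applied are illustrative, not taken from the server code:

```python
import torch

def pick_next_tokens(logits, speculative_logits, lm_head_indices=None):
    # One plausible way the printed indices get applied: keep only the
    # positions being sampled, then greedily pick a token per position.
    if lm_head_indices is not None:
        logits = logits[lm_head_indices]
        if speculative_logits is not None:
            speculative_logits = speculative_logits[lm_head_indices]
    return logits.argmax(dim=-1), speculative_logits

next_ids, spec = pick_next_tokens(torch.randn(5, 32000), None, torch.tensor([4]))
assert next_ids.shape == (1,) and spec is None
```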