fix: adjust sharding and lm head logic

commit bfa16a5857
parent 7d97ee82a1
Author: David Holtz
Date:   2024-10-31 15:33:36 +00:00


@@ -35,6 +35,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
     FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +70,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
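
The hunk above makes the vision attention's embed_dim a per-rank value, consistent with num_heads already being divided by the tensor-parallel world size. A minimal sketch of that arithmetic, using hypothetical Qwen2-VL vision-config numbers rather than values read from a real checkpoint:

# Per-rank attention dimensions under tensor parallelism (illustrative numbers,
# not read from an actual Qwen2-VL config).
from dataclasses import dataclass


@dataclass
class VisionConfig:
    embed_dim: int = 1280    # hypothetical full embedding width
    hidden_size: int = 1280  # hypothetical hidden size
    num_heads: int = 16      # hypothetical total attention heads


def per_rank_dims(config: VisionConfig, world_size: int):
    head_dim = config.hidden_size // config.num_heads  # not sharded
    num_heads = config.num_heads // world_size         # heads split across ranks
    embed_dim = config.embed_dim // world_size         # width split to match
    return head_dim, num_heads, embed_dim


for ws in (1, 2, 4):
    print(ws, per_rank_dims(VisionConfig(), ws))
# world_size=2 -> head_dim=80, num_heads=8, embed_dim=640 per rank
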
@@ -82,7 +83,7 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
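
Switching the output projection from column- to row-parallel matches the sharded attention output: with the QKV column-parallel, each rank's attention output only has embed_dim // world_size features, so the proj weight has to be sharded along its input dimension and the partial products summed, which is the all-reduce a row-parallel linear performs. A single-process sketch of that identity in plain torch, with hypothetical shapes:

# Why the proj is row-parallel: each rank multiplies its slice of the attention
# output by the matching slice of the weight's *input* dimension, and the
# partial results sum to the full matmul (what the all-reduce computes).
# Single-process illustration; shapes are hypothetical.
import torch

torch.manual_seed(0)
embed_dim, world_size = 1280, 2
shard = embed_dim // world_size

x = torch.randn(3, embed_dim)          # full attention output (3 tokens)
w = torch.randn(embed_dim, embed_dim)  # full proj weight, laid out (out, in)

full = x @ w.T
partials = [
    x[:, r * shard:(r + 1) * shard] @ w[:, r * shard:(r + 1) * shard].T
    for r in range(world_size)
]
assert torch.allclose(sum(partials), full, atol=1e-3)
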
@@ -364,8 +365,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
@@ -508,5 +516,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        print("lm_head_indices", lm_head_indices)
+        logits, speculative_logits = self.lm_head(hidden_states)
+        # import ipdb; ipdb.set_trace()
+        return logits, speculative_logits
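
With SpeculativeHead in place, forward() now returns the pair it produces instead of hard-coding None for the second element. A sketch of how a caller might consume that pair; the assumption that speculative_logits is None when no speculator is configured is mine, not stated in the diff, and the shapes are made up:

# Consuming the (logits, speculative_logits) pair returned by forward().
# speculative_logits is assumed to be None when no speculator is configured.
import torch


def consume(outputs: tuple[torch.Tensor, torch.Tensor | None]) -> torch.Tensor:
    logits, speculative_logits = outputs
    if speculative_logits is not None:
        # Extra candidate logits come back alongside the regular ones.
        print("speculative logits shape:", tuple(speculative_logits.shape))
    return logits


consume((torch.randn(4, 152064), None))                       # no speculator
consume((torch.randn(4, 152064), torch.randn(4, 2, 152064)))  # hypothetical speculator output
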