diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 63dcff72..3553b3f9 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -35,6 +35,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
     FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +70,8 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        # shard embed_dim: each tensor-parallel rank holds 1/world_size of the heads
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
 
@@ -82,7 +84,8 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        # the output projection consumes the sharded head dimension, so it is row-parallel
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
@@ -364,8 +367,16 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        # with tied word embeddings the output head reuses the input embedding weights
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
@@ -508,5 +519,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits