Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix: adjust sharding and lm head logic
This commit is contained in:
parent 7d97ee82a1
commit bfa16a5857
@@ -35,6 +35,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
     FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +70,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
 
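The sharding change above keeps the vision attention consistent with tensor parallelism: both embed_dim and num_heads are divided by the process-group size, so each rank only holds its own slice of the heads. A minimal sketch of that arithmetic, using hypothetical config values rather than the repository's Weights/config objects:

import os  # not required; plain Python arithmetic only

# Hypothetical Qwen2-VL vision-config values, for illustration only.
embed_dim = 1280
num_heads = 16
hidden_size = 1280
world_size = 4  # number of tensor-parallel ranks

head_dim = hidden_size // num_heads          # per-head width, unchanged by sharding
local_embed_dim = embed_dim // world_size    # slice of the embedding dim owned by one rank
local_num_heads = num_heads // world_size    # heads handled by one rank

# Each rank reshapes its qkv output with its local head count,
# e.g. (seq_len, 3, local_num_heads, head_dim).
assert local_num_heads * head_dim == local_embed_dim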
@@ -82,7 +83,7 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
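Switching proj from TensorParallelColumnLinear to TensorParallelRowLinear follows the usual Megatron-style pairing: qkv is split along its output dimension (column parallel), so each rank's attention output is already sharded along the feature dimension, and the projection that follows must be split along its input dimension (row parallel), with partial results summed across ranks. A toy sketch of why the row-parallel partial products add up to the full matmul (plain torch, no actual distributed setup, all shapes hypothetical):

import torch

torch.manual_seed(0)
hidden, seq, world_size = 8, 4, 2

x = torch.randn(seq, hidden)      # attention output; sharded along the last dim in practice
w = torch.randn(hidden, hidden)   # proj weight, laid out as (input dim, output dim) for simplicity

full = x @ w

# Row parallel: split the weight (and the matching input slice) along the input dimension.
chunk = hidden // world_size
partials = [
    x[:, r * chunk:(r + 1) * chunk] @ w[r * chunk:(r + 1) * chunk, :]
    for r in range(world_size)
]
# In a real run each rank computes one partial and an all-reduce sums them.
assert torch.allclose(sum(partials), full, atol=1e-5)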
@@ -364,8 +365,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
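The new head logic covers checkpoints with tie_word_embeddings=True, where the output projection reuses the token-embedding matrix instead of shipping a separate lm_head tensor, so in that case the weights are loaded from model.embed_tokens. It also routes through SpeculativeHead rather than a bare FastLinear. A rough sketch of what weight tying means, independent of the repository's loader classes (toy sizes, not the real config):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, hidden_size = 1000, 64
embed_tokens = nn.Embedding(vocab_size, hidden_size)

# Untied: a separate lm_head matrix would be loaded from the checkpoint.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tied: the head simply reuses the embedding weights, so loading
# "model.embed_tokens" instead of "lm_head" yields the same tensor.
hidden_states = torch.randn(2, hidden_size)
tied_logits = F.linear(hidden_states, embed_tokens.weight)
assert tied_logits.shape == (2, vocab_size)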
@@ -508,5 +516,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        print("lm_head_indices", lm_head_indices)
+        logits, speculative_logits = self.lm_head(hidden_states)
+        # import ipdb; ipdb.set_trace()
+        return logits, speculative_logits
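Because SpeculativeHead returns a pair, the forward pass now unpacks (logits, speculative_logits) and returns both instead of (logits, None). A minimal stand-in illustrating that calling convention (a hypothetical class, not the server's implementation):

import torch
import torch.nn as nn
from typing import Optional, Tuple

class TupleHead(nn.Module):
    """Toy head that mimics returning (logits, speculative_logits)."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # No speculative module configured here, so the second element stays None.
        return self.head(hidden_states), None

head = TupleHead(64, 1000)
logits, speculative_logits = head(torch.randn(2, 64))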