mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Move piece/position embeddings into FlashGPT2Model
This commit is contained in:
parent
1510461d93
commit
ccdec05f7e
@ -302,6 +302,16 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.embed_positions = TensorParallelEmbedding(
|
||||||
|
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashGPT2Layer(
|
FlashGPT2Layer(
|
||||||
@ -328,7 +338,7 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
inputs_embeds: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
@ -339,6 +349,10 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
token_embeds = self.embed_tokens(input_ids)
|
||||||
|
position_embeds = self.embed_positions(position_ids)
|
||||||
|
inputs_embeds = token_embeds + position_embeds
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
@ -363,15 +377,6 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
|
||||||
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
|
||||||
weights=weights,
|
|
||||||
)
|
|
||||||
self.embed_positions = TensorParallelEmbedding(
|
|
||||||
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
|
||||||
weights=weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = FlashGPT2Model(prefix, config, weights)
|
self.model = FlashGPT2Model(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
@ -392,11 +397,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
token_embeds = self.embed_tokens(input_ids)
|
|
||||||
position_embeds = self.embed_positions(position_ids)
|
|
||||||
inputs_embeds = token_embeds + position_embeds
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
inputs_embeds,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
|
Loading…
Reference in New Issue
Block a user