mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Move piece/position embeddings into FlashGPT2Model
This commit is contained in:
parent
1510461d93
commit
ccdec05f7e
@ -302,6 +302,16 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
||||
weights=weights,
|
||||
)
|
||||
self.embed_positions = TensorParallelEmbedding(
|
||||
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGPT2Layer(
|
||||
@ -328,7 +338,7 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -339,6 +349,10 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
inputs_embeds = token_embeds + position_embeds
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
@ -363,15 +377,6 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
||||
weights=weights,
|
||||
)
|
||||
self.embed_positions = TensorParallelEmbedding(
|
||||
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.model = FlashGPT2Model(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
@ -392,11 +397,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
inputs_embeds = token_embeds + position_embeds
|
||||
hidden_states = self.model(
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
Loading…
Reference in New Issue
Block a user