From ccdec05f7e91f85db7625127ad7adc8bad949338 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Mon, 13 May 2024 12:26:34 +0000
Subject: [PATCH] Move piece/position embeddings into `FlashGPT2Model`

---
 .../custom_modeling/flash_gpt2_modeling.py    | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 828d7d14..49534a3a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -302,6 +302,16 @@ class FlashGPT2Model(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=("wte" if not prefix else f"{prefix}.wte"),
+            weights=weights,
+        )
+        self.embed_positions = TensorParallelEmbedding(
+            prefix=("wpe" if not prefix else f"{prefix}.wpe"),
+            weights=weights,
+        )
+
         self.layers = nn.ModuleList(
             [
                 FlashGPT2Layer(
@@ -328,7 +338,7 @@ class FlashGPT2Model(torch.nn.Module):
 
     def forward(
         self,
-        inputs_embeds: torch.Tensor,
+        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -339,6 +349,10 @@ class FlashGPT2Model(torch.nn.Module):
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        token_embeds = self.embed_tokens(input_ids)
+        position_embeds = self.embed_positions(position_ids)
+        inputs_embeds = token_embeds + position_embeds
+
         hidden_states = inputs_embeds
 
         residual = None
@@ -363,15 +377,6 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
 
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=("wte" if not prefix else f"{prefix}.wte"),
-            weights=weights,
-        )
-        self.embed_positions = TensorParallelEmbedding(
-            prefix=("wpe" if not prefix else f"{prefix}.wpe"),
-            weights=weights,
-        )
-
         self.model = FlashGPT2Model(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
@@ -392,11 +397,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        token_embeds = self.embed_tokens(input_ids)
-        position_embeds = self.embed_positions(position_ids)
-        inputs_embeds = token_embeds + position_embeds
         hidden_states = self.model(
-            inputs_embeds,
+            input_ids,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
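
The net effect of the patch is that the embedding lookup now lives inside `FlashGPT2Model.forward`, so callers hand the model raw `input_ids` instead of precomputed `inputs_embeds`. Below is a minimal, self-contained sketch of that data flow, using plain `torch.nn.Embedding` as a stand-in for `TensorParallelEmbedding`; the `Toy*` class names, sizes, and reduced argument list are illustrative assumptions, not the actual TGI API.

# Sketch of the post-patch data flow: the model owns the wte/wpe lookups
# and the causal-LM wrapper only forwards raw input_ids.
import torch
from torch import nn


class ToyGPT2Model(nn.Module):
    def __init__(self, vocab_size=50257, max_positions=1024, hidden_size=768):
        super().__init__()
        # Stand-ins for TensorParallelEmbedding(prefix="wte"/"wpe", ...).
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.embed_positions = nn.Embedding(max_positions, hidden_size)

    def forward(self, input_ids, position_ids):
        # Token and position embeddings are summed inside the model,
        # mirroring the lines added to FlashGPT2Model.forward above.
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(position_ids)
        return token_embeds + position_embeds


class ToyGPT2ForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyGPT2Model()

    def forward(self, input_ids, position_ids):
        # No embedding work here any more; pass input_ids straight through.
        return self.model(input_ids, position_ids)


if __name__ == "__main__":
    input_ids = torch.tensor([1, 2, 3])
    position_ids = torch.arange(input_ids.shape[0])
    hidden_states = ToyGPT2ForCausalLM()(input_ids, position_ids)
    print(hidden_states.shape)  # torch.Size([3, 768])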