Move piece/position embeddings into FlashGPT2Model

Daniël de Kok 2024-05-13 12:26:34 +00:00
parent 1510461d93
commit ccdec05f7e


@@ -302,6 +302,16 @@ class FlashGPT2Model(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=("wte" if not prefix else f"{prefix}.wte"),
+            weights=weights,
+        )
+        self.embed_positions = TensorParallelEmbedding(
+            prefix=("wpe" if not prefix else f"{prefix}.wpe"),
+            weights=weights,
+        )
+
         self.layers = nn.ModuleList(
             [
                 FlashGPT2Layer(
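
The prefix expressions above select the checkpoint tensor names: the bare GPT-2 layout stores the embeddings as wte/wpe, while wrapped checkpoints nest them under a parent module. A minimal sketch of the resolution rule (resolve_prefix is a hypothetical helper, not part of the codebase):

```python
def resolve_prefix(prefix: str, name: str) -> str:
    # Same rule as the inline expressions above: use the bare tensor
    # name unless a parent prefix was supplied.
    return name if not prefix else f"{prefix}.{name}"

assert resolve_prefix("", "wte") == "wte"
assert resolve_prefix("transformer", "wpe") == "transformer.wpe"
```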
@@ -328,7 +338,7 @@ class FlashGPT2Model(torch.nn.Module):
 
     def forward(
         self,
-        inputs_embeds: torch.Tensor,
+        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -339,6 +349,10 @@ class FlashGPT2Model(torch.nn.Module):
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        token_embeds = self.embed_tokens(input_ids)
+        position_embeds = self.embed_positions(position_ids)
+        inputs_embeds = token_embeds + position_embeds
+
         hidden_states = inputs_embeds
         residual = None
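
FlashGPT2Model.forward now takes raw input_ids and builds the embeddings itself. On a single device the added lines reduce to two nn.Embedding lookups and a sum; TensorParallelEmbedding differs only in sharding the table across the process group. A sketch with hypothetical GPT-2-small sizes:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size, max_positions = 768, 50257, 1024   # GPT-2 small
embed_tokens = nn.Embedding(vocab_size, hidden_size)        # plays the role of wte
embed_positions = nn.Embedding(max_positions, hidden_size)  # plays the role of wpe

# Flash-attention models pass flattened token/position ids (no batch dim).
input_ids = torch.tensor([15496, 11, 995])
position_ids = torch.tensor([0, 1, 2])

inputs_embeds = embed_tokens(input_ids) + embed_positions(position_ids)
assert inputs_embeds.shape == (3, hidden_size)
```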
@@ -363,15 +377,6 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
 
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=("wte" if not prefix else f"{prefix}.wte"),
-            weights=weights,
-        )
-        self.embed_positions = TensorParallelEmbedding(
-            prefix=("wpe" if not prefix else f"{prefix}.wpe"),
-            weights=weights,
-        )
-
         self.model = FlashGPT2Model(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
@@ -392,11 +397,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        token_embeds = self.embed_tokens(input_ids)
-        position_embeds = self.embed_positions(position_ids)
-        inputs_embeds = token_embeds + position_embeds
         hidden_states = self.model(
-            inputs_embeds,
+            input_ids,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
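
Net effect: the embedding tables and the lookup both live in FlashGPT2Model, so FlashGPT2ForCausalLM shrinks to the transformer stack plus the LM head and simply forwards input_ids. A toy reproduction of the new ownership split (all names and sizes hypothetical):

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Mirrors FlashGPT2Model after the change: owns both embedding
    # tables and accepts raw token ids.
    def __init__(self, vocab=128, n_pos=64, dim=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.embed_positions = nn.Embedding(n_pos, dim)
        self.layers = nn.Linear(dim, dim)  # stand-in for the layer stack

    def forward(self, input_ids, position_ids):
        h = self.embed_tokens(input_ids) + self.embed_positions(position_ids)
        return self.layers(h)

class ToyForCausalLM(nn.Module):
    # Mirrors FlashGPT2ForCausalLM after the change: no embeddings here,
    # just the inner model and the vocabulary projection.
    def __init__(self):
        super().__init__()
        self.model = ToyModel()
        self.lm_head = nn.Linear(16, 128, bias=False)

    def forward(self, input_ids, position_ids):
        return self.lm_head(self.model(input_ids, position_ids))

ids = torch.arange(4)
logits = ToyForCausalLM()(ids, ids)
assert logits.shape == (4, 128)
```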