fix: improve get_position_ids, add lift embed_tokens

David Holtz 2024-10-28 02:15:48 +00:00 committed by drbh
parent 09ac4fb6eb
commit 22fdf9344f
3 changed files with 127 additions and 184 deletions

View File

@@ -144,7 +144,7 @@ class TensorParallelColumnLinear(SuperLayer):
             num_key_value_heads=num_key_value_heads,
         )
         if bias:
-            bias = weights.get_tensor(f"{prefix}.bias")
+            raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
             bias = None
         linear = get_linear(weight, bias)
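
Note on the hunk above: `weights.get_tensor` would load the whole packed q/k/v bias on every rank, which does not line up with a column-parallel packed weight where each rank only owns a slice of the q, k and v heads. The generic packed path now raises, and the Qwen2-VL vision attention (third file in this commit) loads its bias explicitly with `weights.get_sharded(..., dim=0)`. A standalone sketch of the sharding this implies (made-up sizes and helper names, not TGI code):

import torch

# Hypothetical illustration: slice a packed [q | k | v] bias per rank so it
# lines up with the column-parallel split of the packed qkv weight.
hidden_size = 8
world_size = 2
packed_bias = torch.arange(3 * hidden_size, dtype=torch.float32)

def shard_packed_bias(bias: torch.Tensor, rank: int) -> torch.Tensor:
    q, k, v = bias.split(hidden_size, dim=0)
    block = hidden_size // world_size

    def pick(t: torch.Tensor) -> torch.Tensor:
        return t[rank * block : (rank + 1) * block]

    return torch.cat([pick(q), pick(k), pick(v)], dim=0)

# rank 0 keeps the first half of each of q, k and v; rank 1 the second half
print(shard_packed_bias(packed_bias, rank=0))
print(shard_packed_bias(packed_bias, rank=1))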

View File

@@ -130,28 +130,23 @@ class Qwen2Attention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # TODO: correctly handle the multimodal case
-        if False:
-            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        else:
-            # multimodal rotary
-            unsqueeze_dim = 1
-            mrope_section = self.mrope_section * 2
-            cos = torch.cat(
-                [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
-                dim=-1,
-            ).unsqueeze(unsqueeze_dim)
-            sin = torch.cat(
-                [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
-                dim=-1,
-            ).unsqueeze(unsqueeze_dim)
-            _query = query.transpose(0, 1).unsqueeze(0)
-            _key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0)
-            q_embed = (_query * cos) + (rotate_half(_query) * sin)
-            k_embed = (_key * cos) + (rotate_half(_key) * sin)
-            query = q_embed.squeeze(0).transpose(0, 1)
-            kv[:, 0] = k_embed.squeeze(0).transpose(0, 1)
+        _query = query.clone()
+        _cos = cos.clone()
+        _sin = sin.clone()
+        self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin)
+        _cos = torch.cat((_cos, _cos), dim=-1)
+        _sin = torch.cat((_sin, _sin), dim=-1)
+        q_emb = (_query * _cos).reshape(2, 1, -1) + (
+            rotate_half(_query) * _sin
+        ).reshape(2, 1, -1)
+        k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + (
+            rotate_half(torch.select(kv, dim=1, index=0)) * _sin
+        ).reshape(2, 1, -1)
+        query = q_emb.reshape(-1, self.num_heads, self.head_size)
+        kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size)
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
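
For reference, the removed branch above was the multimodal ("mrope") rotary application: cos/sin carry separate temporal, height and width components, and `mrope_section` says how many rotary dimensions belong to each axis. A self-contained sketch of that scheme with simplified shapes (cos/sin assumed to be [3, seq_len, head_dim], q/k [seq_len, num_heads, head_dim]; illustrative only, not the TGI API):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # standard rotary helper: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_mrope(q, k, cos, sin, mrope_section):
    # Repeating the section list makes the six chunks cover both rotary halves
    # of head_dim, cycling through the t, h, w axes via i % 3.
    section = mrope_section * 2
    cos = torch.cat(
        [m[i % 3] for i, m in enumerate(cos.split(section, dim=-1))], dim=-1
    ).unsqueeze(1)
    sin = torch.cat(
        [m[i % 3] for i, m in enumerate(sin.split(section, dim=-1))], dim=-1
    ).unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# toy shapes: head_dim = 8 = 2 * sum(mrope_section)
q, k = torch.randn(5, 4, 8), torch.randn(5, 4, 8)
cos, sin = torch.randn(3, 5, 8), torch.randn(3, 5, 8)
q_embed, k_embed = apply_mrope(q, k, cos, sin, mrope_section=[2, 1, 1])
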
@@ -299,9 +294,6 @@ class Qwen2Model(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.embed_tokens", weights=weights
-        )
         self.layers = nn.ModuleList(
             [
                 Qwen2Layer(
@@ -325,7 +317,7 @@ class Qwen2Model(torch.nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -335,25 +327,12 @@ class Qwen2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds.squeeze(0)
-        else:
-            hidden_states = self.embed_tokens(input_ids)
-
-        # Get rotary cos and sin for this forward
-        # Avoid to index in each layer
-        # TODO: fix how N-D position_ids are handled
-        if position_ids.dim() == 2:
-            position_ids = position_ids.unsqueeze(0)
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
-            position_ids, true_max_s, hidden_states.dtype
+        hidden_states = inputs_embeds
+
+        # TODO: ensure we are getting the correct positional embeddings
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids[0, 0, :], true_max_s, hidden_states.dtype
         )
 
         residual = None
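
With this change the text model consumes pre-computed embeddings and builds its rotary cache from `position_ids[0, 0, :]`, i.e. the temporal row of the first sequence in the 3 x batch x seq_len mrope tensor produced by `get_position_ids` in the third file. A small shape illustration using standard rotary math (a sketch, not the actual `get_cos_sin` implementation):

import torch

# For a text-only sequence all three mrope axes carry the same ids.
seq_len, rotary_dim = 6, 8
position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, 1, -1)  # [3, 1, seq]
flat_ids = position_ids[0, 0, :]                                      # [seq]

# rotary cos/sin cache from the flattened 1-D ids (base 10000 assumed)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(flat_ids.float(), inv_freq)   # [seq, rotary_dim // 2]
cos, sin = freqs.cos(), freqs.sin()
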
@@ -393,6 +372,11 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
         self.max_past = config.sliding_window
         self.max_past_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
@@ -423,8 +407,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             # kernel requires the true values
             seqlen = seqlen.clamp(max=self.max_past_tensor)
 
+        inputs_embeds = self.embed_tokens(input_ids)
+
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
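The embedding lookup now happens in the causal-LM wrapper instead of inside Qwen2Model, so a caller such as the Qwen2-VL wrapper can pass the text model embeddings it has already merged image features into. A minimal sketch of the pattern in plain PyTorch (illustrative class names only, not TGI code):

import torch
from torch import nn

class TextModel(nn.Module):
    # inner model that no longer owns the token embedding
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.proj(inputs_embeds)

class CausalLM(nn.Module):
    def __init__(self, vocab: int, hidden: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)  # lifted out of TextModel
        self.model = TextModel(hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)  # embed here ...
        return self.model(inputs_embeds)              # ... hand embeddings down

out = CausalLM(vocab=10, hidden=4)(torch.tensor([[1, 2, 3]]))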

View File

@@ -35,6 +35,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
+    TensorParallelEmbedding,
     FastLinear,
 )
 from text_generation_server.layers.attention import (
@@ -78,11 +79,11 @@ class Qwen2VLSdpaAttention(nn.Module):
             config,
             prefix=f"{prefix}.qkv",
             weights=weights,
-            bias=True,
+            bias=False,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_heads,
         )
+        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
         self.proj = TensorParallelColumnLinear.load(
             config,
             prefix=f"{prefix}.proj",
@@ -285,7 +286,6 @@ class Qwen2VisionModel(nn.Module):
         self,
         pixel_values: torch.Tensor,
         aspect_ratio_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         # reshape the input tensor for processing
@@ -361,7 +361,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.image_token_id = config.image_token_id
         self.video_token_id = config.video_token_id
         self.spatial_merge_size = config.vision_config.spatial_merge_size
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"model.embed_tokens", weights=weights
+        )
         self.visual = Qwen2VisionModel(
             prefix="visual", config=config.vision_config, weights=weights
         )
@@ -371,6 +373,93 @@
         )
         self.device = weights.device
 
+    def get_position_ids(
+        self,
+        batch_input_ids: torch.Tensor,
+        image_grid_thw: Optional[torch.LongTensor],
+        # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        position_ids = torch.ones(
+            3,
+            batch_input_ids.shape[0],
+            batch_input_ids.shape[1],
+            dtype=batch_input_ids.dtype,
+            device=batch_input_ids.device,
+        )
+        d = batch_input_ids.device
+        if image_grid_thw is not None:
+            image_index = 0
+            llm_pos_ids_list = []
+
+            for i, input_ids in enumerate(batch_input_ids):
+                vision_start_indices = torch.argwhere(
+                    input_ids == self.vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                # only copy the sum of the image tokens GPU<->CPU
+                image_count = (vision_tokens == self.image_token_id).sum().item()
+
+                current_pos = 0
+                for _ in range(image_count):
+                    # copy the value position of the next image token from GPU<->CPU
+                    next_image_pos = (
+                        (input_ids[current_pos:] == self.image_token_id)
+                        .nonzero()[0]
+                        .item()
+                    )
+                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
+                    time_steps, height, width = image_grid_thw[image_index]
+                    height //= self.spatial_merge_size
+                    width //= self.spatial_merge_size
+
+                    # calculate the length of the text and image tokens
+                    text_length = next_image_pos - current_pos
+                    start_idx = (
+                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    )
+
+                    # text position ids
+                    text_pos_ids = torch.arange(text_length, device=d)
+                    text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
+                    llm_pos_ids_list.append(text_pos_ids)
+
+                    # image position ids
+                    t_indices = torch.arange(time_steps, device=d).repeat_interleave(
+                        height * width
+                    )
+                    h_indices = (
+                        torch.arange(height, device=d)
+                        .repeat_interleave(width)
+                        .repeat(time_steps)
+                    )
+                    w_indices = torch.arange(width, device=d).repeat(
+                        height * time_steps
+                    )
+                    image_pos_ids = (
+                        torch.stack([t_indices, h_indices, w_indices])
+                        + text_length
+                        + start_idx
+                    )
+                    llm_pos_ids_list.append(image_pos_ids)
+
+                    current_pos = next_image_pos + time_steps * height * width
+                    image_index += 1
+
+                if current_pos < batch_input_ids.size(1):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_len = batch_input_ids.size(1) - current_pos
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[:, i, :] = llm_positions.to(position_ids.device)
+
+        return position_ids
+
     def forward(
         self,
         input_ids: torch.Tensor,
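
To illustrate what `get_position_ids` produces: text tokens advance the temporal, height and width axes together, while the tokens of an image receive a (t, h, w) grid offset by the position reached so far. A worked toy example for 3 text tokens followed by a single image whose merged grid is 1 x 2 x 2 (token counts and sizes are made up):

import torch

text_positions = torch.arange(3).view(1, -1).expand(3, -1)  # same ids on all 3 axes
t = torch.tensor([0, 0, 0, 0])                               # one frame
h = torch.tensor([0, 0, 1, 1])                               # two merged rows
w = torch.tensor([0, 1, 0, 1])                               # two merged columns
image_positions = torch.stack([t, h, w]) + 3                 # offset by the text length
position_ids = torch.cat([text_positions, image_positions], dim=1)
print(position_ids)
# tensor([[0, 1, 2, 3, 3, 3, 3],
#         [0, 1, 2, 3, 3, 4, 4],
#         [0, 1, 2, 3, 4, 3, 4]])
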
@@ -392,147 +481,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
     ):
-        # make an attention_mask that is (batch_size, sequence_length)
-        attention_mask = torch.ones_like(
-            input_ids, dtype=torch.bool, device=input_ids.device
-        )
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
             if pixel_values is not None:
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                image_mask = (
-                    (input_ids == self.image_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_embeds = image_embeds.to(
-                    inputs_embeds.device, inputs_embeds.dtype
-                )
-                # input embeddings are masked with image embeddings
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+                inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
-        # handle the position_ids taking the multimodal inputs into account
-        mrope_position_deltas = []
-        if image_grid_thw is not None or video_grid_thw is not None:
-            total_input_ids = input_ids
-            position_ids = torch.ones(
-                3,
-                input_ids.shape[0],
-                input_ids.shape[1],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-            image_index, video_index = 0, 0
-            for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
-                image_nums, video_nums = 0, 0
-                vision_start_indices = torch.argwhere(
-                    input_ids == self.vision_start_token_id
-                ).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                # determine the number of images and videos in the input
-                image_nums = (vision_tokens == self.image_token_id).sum()
-                video_nums = (vision_tokens == self.video_token_id).sum()
-                input_tokens = input_ids.tolist()
-                llm_pos_ids_list: list = []
-                st = 0
-                remain_images, remain_videos = image_nums, video_nums
-                # process each input based on it's token type and grid size
-                for _ in range(image_nums + video_nums):
-                    if self.image_token_id in input_tokens and remain_images > 0:
-                        ed_image = input_tokens.index(self.image_token_id, st)
-                    else:
-                        ed_image = len(input_tokens) + 1
-                    if self.video_token_id in input_tokens and remain_videos > 0:
-                        ed_video = input_tokens.index(self.video_token_id, st)
-                    else:
-                        ed_video = len(input_tokens) + 1
-                    if ed_image < ed_video:
-                        t, h, w = (
-                            image_grid_thw[image_index][0],
-                            image_grid_thw[image_index][1],
-                            image_grid_thw[image_index][2],
-                        )
-                        image_index += 1
-                        remain_images -= 1
-                        ed = ed_image
-                    else:
-                        t, h, w = (
-                            video_grid_thw[video_index][0],
-                            video_grid_thw[video_index][1],
-                            video_grid_thw[video_index][2],
-                        )
-                        video_index += 1
-                        remain_videos -= 1
-                        ed = ed_video
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t.item(),
-                        h.item() // self.spatial_merge_size,
-                        w.item() // self.spatial_merge_size,
-                    )
-                    text_len = ed - st
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-                    )
-                    t_index = (
-                        torch.arange(llm_grid_t)
-                        .view(-1, 1)
-                        .expand(-1, llm_grid_h * llm_grid_w)
-                        .flatten()
-                    )
-                    h_index = (
-                        torch.arange(llm_grid_h)
-                        .view(1, -1, 1)
-                        .expand(llm_grid_t, -1, llm_grid_w)
-                        .flatten()
-                    )
-                    w_index = (
-                        torch.arange(llm_grid_w)
-                        .view(1, 1, -1)
-                        .expand(llm_grid_t, llm_grid_h, -1)
-                        .flatten()
-                    )
-                    llm_pos_ids_list.append(
-                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-                    )
-                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-                if st < len(input_tokens):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    text_len = len(input_tokens) - st
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-                    )
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
-                    position_ids.device
-                )
-                mrope_position_deltas.append(
-                    llm_positions.max() + 1 - len(total_input_ids[i])
-                )
-            mrope_position_deltas = torch.tensor(
-                mrope_position_deltas, device=input_ids.device
-            ).unsqueeze(1)
+        position_ids = self.get_position_ids(input_ids, image_grid_thw)
         outputs = self.text_model(
-            input_ids=input_ids,
+            inputs_embeds=inputs_embeds.squeeze(0),
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
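
The rewritten forward above replaces the masked_scatter construction with a single boolean-indexed assignment: every position whose token id equals the image token id is overwritten with one row of the vision model output, so the number of image tokens must equal `image_embeds.shape[0]`. A standalone sketch of that merge (hypothetical token id and shapes):

import torch

hidden = 4
IMAGE_TOKEN_ID = 9                         # made-up id for the sketch
input_ids = torch.tensor([[5, 9, 9, 6]])   # two image tokens
inputs_embeds = torch.zeros(1, 4, hidden)  # from the embedding layer
image_embeds = torch.ones(2, hidden)       # one row per image token

# boolean-mask indexing writes the image rows into the matching positions in place
inputs_embeds[input_ids == IMAGE_TOKEN_ID] = image_embeds
assert inputs_embeds[0, 1].equal(torch.ones(hidden))
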
@@ -542,8 +501,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
         )
         logits = self.lm_head(outputs)