# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Qwen2 VL model. """
from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
import numpy as np
from torch import nn

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda

from transformers.activations import ACT2FN
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
)
from text_generation_server.layers.attention import (
    Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,
)

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
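
# Illustrative example (not part of the model): for x = [1, 2, 3, 4] the halves
# are x1 = [1, 2] and x2 = [3, 4], so rotate_half(x) = [-3, -4, 1, 2]; combined
# with the cos/sin terms in apply_rotary_pos_emb_vision below, this realizes the
# rotary "complex rotation".
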
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output
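
# Shape walk (illustrative): freqs is (num_patches, head_dim // 2); the
# unsqueeze/repeat(1, 1, 2) calls expand cos/sin to (1, num_patches, 1, head_dim),
# duplicating each frequency across the two rotated halves so they broadcast
# against the (1, num_patches, num_heads, head_dim) query/key tensors passed in
# by the attention layer.
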
class Qwen2VLAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.embed_dim // weights.process_group.size()
        self.head_dim = config.hidden_size // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()
        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
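        # Illustrative: with Qwen2-VL's released vision config (embed_dim=1280,
        # num_heads=16), each head has size 80 and softmax_scale = 1/sqrt(80)
        # ≈ 0.1118, the standard scaled-dot-product normalization.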

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )

        # reshape the query, key, and value tensors
        _shape = (
            hidden_state.shape[0],
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)

        # apply rotary positional embeddings
        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
            0
        )
        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        causal = False

        # execute flash attention
        if SYSTEM == "ipex":
            attn_output = torch.empty_like(query)
            ipex.llm.functional.varlen_attention(
                (query.contiguous() if query.device.type == "xpu" else query),
                (key.contiguous() if key.device.type == "xpu" else key),
                (value.contiguous() if value.device.type == "xpu" else value),
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                0.0,
                self.softmax_scale,
                False,
                causal,
                False,
                None,
            )
        else:
            attn_output = flash_attn_2_cuda.varlen_fwd(
                query,
                key,
                value,
                None,  # tmp buffer (auto-allocated)
                cu_seqlens,  # cu_seqlens_q
                cu_seqlens,  # cu_seqlens_k
                None,  # max_seqlen_q (auto-computed)
                None,  # max_seqlen_k (auto-computed)
                None,  # block_tables
                None,  # broadcast_mask
                max_seqlen,  # max_seqlen_q
                max_seqlen,  # max_seqlen_k
                0.0,  # dropout_p
                self.softmax_scale,
                False,  # zero_tensors
                causal,  # causal attention within each sequence
                -1,  # window_size_left
                -1,  # window_size_right
                0.0,  # softmax_cap
                False,  # deterministic
                None,  # rng_state
            )[0]

        # reshape output to original dimensions
        attn_output = attn_output.reshape(hidden_state.shape[0], -1)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2VLVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Qwen2VLVisionBlock(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2VLAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
        )
        self.norm1 = FastLayerNorm.load(
            prefix=f"{prefix}.norm1",
            weights=weights,
            eps=1e-6,
        )
        self.norm2 = FastLayerNorm.load(
            prefix=f"{prefix}.norm2",
            weights=weights,
            eps=1e-6,
        )
        self.mlp = Qwen2VLVisionMLP(
            prefix=f"{prefix}.mlp",
            config=config,
            weights=weights,
        )

    def forward(
        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
    ) -> torch.Tensor:
        norm1_out, residual = self.norm1(hidden_states)
        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
        hidden_states = attn_out + residual
        norm2_out, residual = self.norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(norm2_out)
        return hidden_states


class Qwen2VLPatchMerger(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
        self.patch_merger_ln_q = FastLayerNorm.load(
            prefix=f"{prefix}.ln_q",
            weights=weights,
            eps=1e-6,
        )
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.patch_merger_ln_q(hidden_states)
        hidden_states = hidden_states.view(-1, self.hidden_size)
        hidden_states = self.fc1(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
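
# Illustrative sizes for the merger above, assuming Qwen2-VL's released config
# (embed_dim=1280, spatial_merge_size=2): each 2x2 group of neighboring patch
# embeddings is flattened into a 1280 * 4 = 5120-dim vector, passed through
# mlp.0 -> GELU -> mlp.2, and projected to the text model's hidden size.
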

class Qwen2VisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.spatial_merge_size = config.spatial_merge_size
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=config.embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.embed_dim // config.num_heads
        # TODO: replace with static positional embeddings once implemented
        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
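        # Illustrative: with embed_dim=1280 and num_heads=16, head_dim=80 and
        # dim=40, so inv_freq holds the 20 frequencies 10000**(-k/40) for even
        # k in [0, 40), i.e. the standard RoPE frequency schedule.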
        self.blocks = nn.ModuleList(
            [
                Qwen2VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.embed_dim

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # reshape the input tensor for processing
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
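        # Illustrative shapes, assuming the released patch sizes (in_channels=3,
        # temporal_patch_size=2, spatial_patch_size=14): each patch is a
        # (3, 2, 14, 14) voxel and, for grid_thw = [[1, 12, 16]], the Conv3d
        # produces 1 * 12 * 16 = 192 patch embeddings of size embed_dim.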
        # TODO: revisit to see if we can avoid some of these reshapes

        # find the position ids for the input tensor based on the grid_thw
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
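        # Illustrative: for h = w = 4 with spatial_merge_size = 2, the
        # reshape/permute above orders patches by 2x2 merge block, so the
        # (hpos, wpos) pairs come out as (0,0), (0,1), (1,0), (1,1), (0,2),
        # (0,3), (1,2), (1,3), (2,0), ... rather than plain row-major order.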
        max_grid_size = grid_thw[:, 1:].max()

        # apply the positional embeddings to the position ids
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
        # create a cu_seqlens tensor to be used in the attention mask
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
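        # Illustrative: for grid_thw = [[1, 12, 16], [1, 8, 8]] the per-image
        # sequence lengths are [192, 64], so cu_seqlens = [0, 192, 256] and each
        # [start, end) slice is attended to as an independent sequence.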
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

        # iteratively apply the blocks to the hidden states
        for block in self.blocks:
            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)

        # apply the final patch merger to the hidden states
        hidden_states = self.merger(hidden_states)
        return hidden_states


class Qwen2VLForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
        # returns rope_scaling.type == "default" for Qwen2-VL model at the moment
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        self.norm = FastRMSNorm.load(
            prefix="model.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find the segments, then initialize position ids for each
    # segment. Steps:
    #   1. locate all vision and text segments
    #   2. calculate `vision_segment_lengths` for each vision segment, used as an offset
    #   3. calculate `text_segment_lengths` for each text segment, used as an offset
    #   4. create position ids for each vision segment based on the image grid
    #   5. create position ids for each text segment
    #   6. the final text segment runs from the last vision segment to the end of the input
    #   7. combine all position ids, reshape to (3, input_ids_len), then transpose
    #      to (input_ids_len, 3)
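    # Illustrative example (hypothetical token layout): for an input of the form
    #   [text x3, vision_start, image x4 (merged 2x2 grid), vision_end, text x2]
    # the text tokens get identical (t, h, w) coordinates increasing by one per
    # token, the four image tokens share one temporal index while enumerating the
    # 2x2 (h, w) grid, and the trailing text resumes from max position + 1.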
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # without images, positions are identical along all three axes
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )
        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
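        # vision_segment_lengths[i] is the starting position offset for vision
        # segment i; text_segment_lengths[i] is the offset for the text run that
        # precedes it.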
        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        # interleave text and vision position ids: text_0, vision_0, text_1, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        max_s = full_llm_pos_ids_list[-1].max() + 1

        # the final text segment runs from the last vision segment to the end
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        inputs_embeds = self.embed_tokens(input_ids)

        # apply the visual model to the pixel values if they are provided
        if pixel_values is not None and len(pixel_values) > 0:
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
            inputs_embeds[input_ids == self.image_token_id] = image_embeds
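            # Note (illustrative): this masked scatter assumes the prompt holds
            # exactly one image placeholder token per merged patch, e.g. an 8x8
            # grid of merged patches contributes 64 rows of image_embeds and 64
            # image tokens.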

        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits