fix: vendor processor and config from transformers

This commit is contained in:
parent 07c0080970
commit e4e6ea2598
@@ -166,6 +166,8 @@ try:
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
        Qwen2_5_VLConfig,
        Qwen2_5_VLProcessor,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
@@ -1388,6 +1390,8 @@ def get_model(
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
            config_class=Qwen2_5_VLConfig,
            processor_class=Qwen2_5_VLProcessor,
        )
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
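For context, a minimal sketch of the pattern this hunk enables: the caller supplies the vendored `config_class`/`processor_class` overrides, and a loader falls back to the transformers auto classes when none are given. The helper name and signature below are hypothetical, not the actual TGI code.

from transformers import AutoConfig, AutoProcessor


# Hypothetical helper illustrating the override pattern used above: callers may
# pass the vendored Qwen2_5_VLConfig / Qwen2_5_VLProcessor instead of relying on
# the transformers auto classes.
def load_config_and_processor(
    model_id: str,
    trust_remote_code: bool = False,
    config_class=None,
    processor_class=None,
):
    config_cls = config_class or AutoConfig
    processor_cls = processor_class or AutoProcessor
    config = config_cls.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    processor = processor_cls.from_pretrained(
        model_id, trust_remote_code=trust_remote_code
    )
    return config, processor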
@@ -29,6 +29,8 @@ else:
    import numpy as np

from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

import torch.nn.functional as F

from text_generation_server.layers.layernorm import FastRMSNorm
@@ -45,6 +47,334 @@ from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,
)

# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import (
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
    VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    fps: Union[List[float], float]


class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "videos_kwargs": {"fps": 2.0},
    }

class Qwen2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        self.image_token = (
            "<|image_pad|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.video_token = (
            "<|video_pad|>"
            if not hasattr(tokenizer, "video_token")
            else tokenizer.video_token
        )
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>"
                        * (image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    text[i] = text[i].replace(
                        self.video_token,
                        "<|placeholder|>"
                        * (video_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(
            dict.fromkeys(tokenizer_input_names + image_processor_input_names)
        )
        return names_from_processor + ["second_per_grid_ts"]

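A minimal usage sketch of the vendored processor, assuming a Qwen2.5-VL checkpoint and a local image file (the checkpoint id, file name, grid size, and prompt format below are illustrative): each `<|image_pad|>` in the prompt is expanded to `image_grid_thw.prod() // merge_size**2` image tokens before tokenization.

from PIL import Image

# Illustrative checkpoint id and image path.
processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
image = Image.open("example.jpg")

prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
inputs = processor(images=image, text=prompt, return_tensors="pt")

# If the image is patched into a (t=1, h=34, w=46) grid and merge_size is 2,
# the single <|image_pad|> above becomes 1 * 34 * 46 // 4 = 391 image tokens.
print(inputs["input_ids"].shape)
print(inputs["image_grid_thw"])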
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
class Qwen2_5_VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=32,
        hidden_size=3584,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        spatial_patch_size=14,
        temporal_patch_size=2,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=3584,
        fullatt_block_indexes=[7, 15, 23, 31],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_patch_size = spatial_patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        self.fullatt_block_indexes = fullatt_block_indexes
        self.out_hidden_size = out_hidden_size


class Qwen2_5_VLConfig(PretrainedConfig):

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if vision_config is not None:
            self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

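A short sketch of the `rope_scaling` backward-compatibility rewrite in the vendored config; the field values below are illustrative, not taken from a real checkpoint.

# Illustrative values only; a real checkpoint's config.json supplies these.
config = Qwen2_5_VLConfig(
    vision_config={"depth": 32, "hidden_size": 1280, "out_hidden_size": 3584},
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},
)

# "mrope" is rewritten to "default" and mirrored under "rope_type".
print(config.rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}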
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
@@ -273,7 +603,7 @@ class Qwen2_5VisionModel(nn.Module):
        self.spatial_merge_size = config.spatial_merge_size
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            in_channels=config.in_channels,
            out_channels=config.hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
@@ -304,7 +634,7 @@ class Qwen2_5VisionModel(nn.Module):
            config=config,
            weights=weights,
        )

        # import ipdb; ipdb.set_trace()
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
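For intuition about the Conv3d patch embedding touched above (the rename from `config.in_chans` to `config.in_channels` matches the attribute name defined by the vendored `Qwen2_5_VLVisionConfig`), here is a standalone sketch with illustrative sizes.

import torch
import torch.nn as nn

# Illustrative sizes; temporal_patch_size, patch_size and in_channels follow the
# vendored vision-config defaults, hidden_size is made up.
temporal_patch_size, patch_size, in_channels, hidden_size = 2, 14, 3, 1280

kernel_size = [temporal_patch_size, patch_size, patch_size]
patch_embedding = nn.Conv3d(
    in_channels=in_channels,
    out_channels=hidden_size,
    kernel_size=kernel_size,
    stride=kernel_size,
)

# One 2-frame, 14x14-pixel patch collapses to a single hidden vector.
pixel_patch = torch.randn(1, in_channels, temporal_patch_size, patch_size, patch_size)
print(patch_embedding(pixel_patch).shape)  # torch.Size([1, 1280, 1, 1, 1])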
@@ -506,52 +836,52 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    # locate all vision and text segments
    # calculate `vision_segment_lengths` for each vision segment to be used as offset
    # calculate `text_segment_lengths` for each text segment to be used as offset
    # create position ids for each vision segment based on the image grid
    # create position ids for each text segment
    # combine all the position ids
    # the final segment is the difference between the last vision segment and the end of the input
    # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if image_grid_thw is None:
            # (batch_size, 3)
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        # if an image grid is provided then we need to calculate the position ids
        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id

        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        # capture vision segments
        starts = torch.where(input_ids == vision_start_token_id)[0]
        ends = torch.where(input_ids == vision_end_token_id)[0]
        # ie. [[ 14, 2181], [2212, 4379]]
        vision_segments = torch.stack((starts, ends), dim=1)
        # capture text lengths as the space between vision segments

        prev_end = torch.cat(  # shift to the left to get the previous end
            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
        )  # ie. [0, 2181]

        # text is the space between the end of one vision segment and the start of the next
        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]

        # calculate the max id from the image width for each segment
        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        total_segment_lengths = vision_widths_max + text_lengths
        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
        text_diff = total_segment_lengths - text_lengths
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
@@ -567,29 +897,28 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + total_segment_lengths[i]
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_diff[i]
            for i, seq_len in enumerate(text_lengths)
        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        # combine by alternating text and vision segments (text, vision, text, vision, ...)
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # import ipdb

        # the final segment is the difference between the last vision segment and the end of the input
        # ipdb.set_trace()
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - ends[-1]
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
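To make the segment bookkeeping above concrete, here is a toy sketch that reuses the same formulas on made-up inputs (the token ids, grid, and merge size are all illustrative): text positions advance by the length of each text segment, while each vision segment is offset by the accumulated text lengths plus the widths of previous images.

import torch

# Made-up vocabulary: 1000 = vision_start, 1001 = vision_end, 7 = image token.
vision_start_token_id, vision_end_token_id = 1000, 1001
spatial_merge_size = 2

# 5 text tokens, <vision_start>, 4 image tokens, <vision_end>, 3 text tokens.
input_ids = torch.tensor([1, 2, 3, 4, 5, 1000, 7, 7, 7, 7, 1001, 8, 9, 10])
image_grid_thw = torch.tensor([[1, 2, 2]])  # one image: t=1, h=2, w=2

vision_starts = torch.where(input_ids == vision_start_token_id)[0]  # tensor([5])
vision_ends = torch.where(input_ids == vision_end_token_id)[0]      # tensor([10])
prev_vision_end = torch.cat([torch.zeros(1, dtype=input_ids.dtype), vision_ends[:-1]])
text_lengths_between_vision = vision_starts - prev_vision_end + 1   # tensor([6])

vision_widths_max = torch.cat(
    [torch.zeros(1, dtype=input_ids.dtype), image_grid_thw[:-1, 2] // spatial_merge_size]
)
vision_segment_lengths = (vision_widths_max + text_lengths_between_vision).cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

# Text positions for the first segment start at offset 0; the image grid
# positions are shifted by vision_segment_lengths[0].
print(text_segment_lengths)    # tensor([0])
print(vision_segment_lengths)  # tensor([6])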