Fixed PaliGemma.

This commit is contained in:
Nicolas Patry 2024-05-14 15:58:19 +00:00
parent 67e833cedb
commit c119ac4d1d
12 changed files with 195 additions and 371 deletions

View File

@ -100,7 +100,6 @@ impl LlavaNext {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
image_size: usize,
@ -108,7 +107,6 @@ pub struct ClipVisionModel {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
@ -119,18 +117,20 @@ impl Idefics2 {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {}
pub struct PaliTextConfig {
num_image_tokens: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
text_config: PaliTextConfig,
}
impl Paligemma {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
// TODO: improve to calculate based on height and width
// 224 = 256 image tokens
// 448 = 1024 image tokens
// 896 = 4096 image tokens
256
self.text_config.num_image_tokens
}
}

View File

@ -64,6 +64,9 @@ try:
from text_generation_server.models.flash_gemma import (
FlashGemma,
)
from text_generation_server.models.pali_gemma import (
PaliGemma,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
@ -654,6 +657,18 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma":
if FLASH_ATTENTION:
return PaliGemma(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if FLASH_ATTENTION:

View File

@ -153,15 +153,11 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
@ -238,6 +234,7 @@ class FlashGemmaAttention(torch.nn.Module):
cu_seqlen_prefill,
max_s,
self.softmax_scale,
causal=self.causal,
)
# Decode
else:
@ -295,10 +292,10 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
)
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -350,7 +347,7 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
@ -362,6 +359,7 @@ class FlashGemmaModel(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
causal=causal,
)
for layer_id in range(config.num_hidden_layers)
]
@ -378,7 +376,7 @@ class FlashGemmaModel(torch.nn.Module):
def forward(
self,
input_embeds: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -387,7 +385,7 @@ class FlashGemmaModel(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = input_embeds
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@ -416,7 +414,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
embed_norm = config.hidden_size**0.5
@ -430,7 +428,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
self.model = FlashGemmaModel(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"

View File

@ -19,134 +19,14 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.layers import TensorParallelColumnLinear
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
GemmaConfig,
)
class VisionConfig(PretrainedConfig):
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
model_type: str = "siglip_vision_model",
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
num_image_tokens: int = 256,
patch_size: int = 14,
projection_dim: int = 2048,
projector_hidden_act: str = "gelu_fast",
vision_use_head: bool = False,
vocab_size: int = 257152,
quantize: Optional[str] = None,
image_size: int = 224,
layer_norm_eps: float = 1e-06,
attention_dropout: float = 0.0,
hidden_act: str = "gelu_pytorch_tanh",
num_channels: int = 3,
**kwargs,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.model_type = model_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_image_tokens = num_image_tokens
self.patch_size = patch_size
self.projection_dim = projection_dim
self.projector_hidden_act = projector_hidden_act
self.vision_use_head = vision_use_head
self.vocab_size = vocab_size
self.quantize = quantize
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.num_channels = num_channels
super().__init__(**kwargs)
class PaliGemmaConfig(PretrainedConfig):
model_type = "paligemma"
def __init__(
self,
text_config: GemmaConfig,
vision_config: VisionConfig,
vocab_size: int = 257152,
image_token_index: int = 256000,
**kwargs,
):
self.text_config = text_config
self.vision_config = vision_config
self.vocab_size = vocab_size
self.image_token_index = image_token_index
self.intermediate_size = text_config.intermediate_size
self.num_hidden_layers = text_config.num_hidden_layers
self.num_key_value_heads = text_config.num_key_value_heads
self.num_attention_heads = text_config.num_attention_heads
super().__init__(**kwargs)
def from_pretrained(pretrained_model_name_or_path, **kwargs):
vision_config = VisionConfig(
hidden_size=1152,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
num_image_tokens=256,
patch_size=14,
projection_dim=2048,
projector_hidden_act="gelu_fast",
vision_use_head=False,
vocab_size=257152,
)
text_config = GemmaConfig.from_pretrained(
pretrained_model_name_or_path,
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
hidden_size=2048,
initializer_range=0.02,
intermediate_size=16384,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_image_tokens=256,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.40.0.dev0",
use_cache=True,
vocab_size=257216,
**kwargs,
)
return PaliGemmaConfig(
text_config=text_config,
vision_config=vision_config,
**kwargs,
)
class FlashPaliGemmaForConditionalGeneration(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
@ -166,6 +46,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
self.vocab_size = config.vocab_size
self.config = config
text_config = config.text_config
text_config.speculator = config.speculator
text_config.quantize = config.quantize
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
@ -188,36 +71,28 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pixel_attention_mask=None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None and len(pixel_values) > 0:
# TODO: avoid these casts upstream
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
# TODO: now we scale them? maybe we can do this up or downstream
scaled_image_features = image_features / (
self.config.text_config.hidden_size**0.5
)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index | (input_ids == 2)
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
# normalizer = torch.tensor(
# self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
# )
# inputs_embeds = inputs_embeds * normalizer
inputs_embeds[mask] = scaled_image_features.view(
-1, scaled_image_features.shape[-1]
)
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.language_model.model(
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
@ -230,6 +105,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -16,7 +16,7 @@ from transformers.modeling_outputs import (
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from text_generation_server.utils.layers import (
from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,

View File

@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights)
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,

View File

@ -3,8 +3,7 @@ import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers.models.gemma import GemmaTokenizerFast
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
@ -36,14 +35,12 @@ class FlashGemma(FlashCausalLM):
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = GemmaTokenizerFast.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = AutoConfig.from_pretrained(
@ -61,7 +58,7 @@ class FlashGemma(FlashCausalLM):
# TODO hardcoded
prefix = "language_model"
model = FlashGemmaForCausalLM(prefix, config, weights)
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(

View File

@ -1,50 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
FlashPaliGemmaForConditionalGeneration,
PaliGemmaConfig,
)
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__)
class FlashPaliGemma(PaliVlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
"google/siglip-base-patch16-224",
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=PaliGemmaConfig,
model_cls=FlashPaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)

View File

@ -0,0 +1,123 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
load_data_uri,
split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(VlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += "<bos>" + chunk["content"] + "\n"
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=AutoConfig,
model_cls=PaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self):
return PaliGemmaBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
return "<image>" * num_features
elif config.model_type == "paligemma":
# TODO: use correct number of features
return "<image>" * 256
return "<image>" * config.text_config.num_image_tokens
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -174,7 +173,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
full_text += "<bos>" + chunk["content"] + "\n"
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
@ -198,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
@ -376,142 +378,3 @@ class VlmCausalLM(BaseFlashMistral):
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class PaliVlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
text_inputs = []
image_text_replacements = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
text_inputs.append(chunk["content"])
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
text_replacement = image_text_replacement(
image_input, config, image_id
)
full_text += text_replacement
image_text_replacements.append(text_replacement)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
image_token = tokenizer.get_added_vocab()["<image>"]
# find the index of the first non-image token
for batch in batch_tokenized_inputs:
first_non_image = 0
for i, token in enumerate(batch):
if token != image_token:
first_non_image = i
break
# manually add the bos to the left of the text
batch_tokenized_inputs = [
batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:]
for batch in batch_tokenized_inputs
]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "PaliVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch

View File

@ -14,9 +14,9 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
PaliVlmCausalLMBatch,
)
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if self.model.batch_type in {
IdeficsCausalLMBatch,
VlmCausalLMBatch,
PaliVlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,
@ -126,7 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if self.model.batch_type in {
IdeficsCausalLMBatch,
VlmCausalLMBatch,
PaliVlmCausalLMBatch,
PaliGemmaBatch,
}: # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,

View File

@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@ -134,7 +135,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
0.0,
softmax_scale,
False,
True,
causal,
window_size_left,
0,
False,