Mirror of https://github.com/huggingface/text-generation-inference.git

commit ae2b4e1c23
parent 613dc93617

    Operational.

@@ -114,8 +114,8 @@ impl Client {
             let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str("");
             inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));

             requests.push(Request {
                 id: 0,

@@ -66,7 +66,8 @@ impl LlavaNext {
         let (num_patch_height, num_patch_width) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
         // Ceil
-        let height_of_patch = (height * npatches + width - 1) / width;
+        // TODO Very odd artifact when the rounding is super close
+        let height_of_patch = (height * npatches + width - 10) / width;
         let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
         // They are only added after width
         let newline_features = height_of_patch * num_patch_width;

@@ -166,5 +167,7 @@ mod test {
         assert_eq!(slots, 2732);
+        let slots = config.get_number_of_features(1024, 899);
+        assert_eq!(slots, 3320);
         let slots = config.get_number_of_features(1067, 1600);
         assert_eq!(slots, 2144);
     }
 }

@@ -739,6 +739,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        # Unused here
+        image_sizes: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:

@@ -170,6 +170,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
+        # Unused for this model
+        pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.language_model.embed_tokens(input_ids)

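Both forward signatures feed the same downstream step: the rows of inputs_embeds that sit at "<image>" placeholder positions get overwritten with projected vision features before the language model runs. A minimal sketch of that merge, assuming a hypothetical image_token_id for the placeholder (an illustration, not code from this commit):

import torch

def merge_image_embeddings(
    inputs_embeds: torch.Tensor,   # [num_tokens, hidden_size]
    input_ids: torch.Tensor,       # [num_tokens]
    image_features: torch.Tensor,  # [num_image_tokens, hidden_size]
    image_token_id: int,           # assumed id of the "<image>" placeholder
) -> torch.Tensor:
    # Each "<image>" slot emitted by image_text_replacement receives exactly
    # one row of vision features, so the two counts must match.
    mask = input_ids == image_token_id
    inputs_embeds[mask] = image_features.to(inputs_embeds.dtype)
    return inputs_embeds
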
@@ -26,6 +26,8 @@ class Idefics2(VlmCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
+            # XXX: Extremely important to cap resolution in order to limit
+            # VRAM usage.
             size={"longest_edge": 448, "shortest_edge": 378},
         )
         super().__init__(

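The XXX comment is plain tensor arithmetic: pixel buffers grow with the square of the resolution, so an uncapped image dominates VRAM during prefill. A back-of-the-envelope sketch (illustration only, not TGI code):

def image_tensor_mib(height: int, width: int, channels: int = 3, bytes_per_elem: int = 4) -> float:
    # float32 pixel buffer for one image, in MiB
    return channels * height * width * bytes_per_elem / 2**20

print(image_tensor_mib(4000, 3000))  # ~137 MiB for an uncapped photo
print(image_tensor_mib(448, 448))    # ~2.3 MiB once longest_edge=448 applies
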
@@ -64,6 +64,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


+def image_text_replacement(image_input, config) -> str:
+    if config.model_type == "idefics2":
+        # TODO: technically depends on image splitting, which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][0]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} features in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now

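A quick illustration of the replacement contract above, using types.SimpleNamespace as a hypothetical stand-in for the real model config (test scaffolding, not part of the commit):

from types import SimpleNamespace

config = SimpleNamespace(model_type="idefics2")
text = image_text_replacement(image_input=None, config=config)
# 320 "<image>" slots bracketed by the surrounding fake tokens
assert text.count("<image>") == 320
assert text.startswith("<fake_token_around_image>")
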
@@ -82,7 +102,7 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_size,
     )

-    height_of_patch = math.ceil(height / width * npatches)
+    height_of_patch = (height * npatches + width - 10) // width

     unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
     # They are only added after width

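Both the Rust and the Python version replace a float-based ceiling with integer arithmetic. The exact idiom is ceil(a / b) == (a + b - 1) // b for positive integers; the "- 10" variant deliberately rounds down whenever height * npatches lands within 9 of a multiple of width, presumably to paper over the float-rounding artifact flagged in the TODO. A small sketch of the difference (illustration, not commit code):

import math

height, width = 640, 640
npatches = 336 // 14  # image_size // patch_size for CLIP, i.e. 24

exact    = (height * npatches + width - 1) // width   # true ceiling: 24
fudged   = (height * npatches + width - 10) // width  # 24 here; drops to the floor
                                                      # when the remainder is <= 9
as_float = math.ceil(height / width * npatches)       # 24, but float error can
                                                      # flip it near exact multiples
assert exact == fudged == as_float == 24

With height_of_patch = 24, and assuming the 2x2 anyres grid selected for a 640x640 input, the feature count works out to 24 * 24 * 2 * 2 + 24 * 2 + 24 ** 2 = 2304 + 48 + 576 = 2928, which matches the debug assert removed in the next hunk.
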
@@ -99,12 +119,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image


-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928
-

 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]

     @classmethod

@@ -112,6 +129,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
+        batch.image_sizes = None
         return batch

@@ -119,6 +137,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
+        batch.image_sizes = None
         return batch

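Both overrides drop the vision tensors on purpose: pixel_values, pixel_attention_mask, and now image_sizes are only consumed during the prefill forward pass (the forward hunk below clears them right after use), so a batch that has been filtered or concatenated no longer needs them.
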
@@ -147,11 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                         "Cannot process input image not starting with data:"
                     )
                 image_input = processor.image_processor(image, return_tensors="pt")
-                # import ipdb; ipdb.set_trace()
-                # height, width = image_input["image_sizes"][0]
-                # num_features = get_number_of_features(height, width, config)
-                num_features = 320
-                full_text += "<image>" * num_features
+                full_text += image_text_replacement(image_input, config)
                 image_inputs.append(image_input)
             else:
                 raise RuntimeError(f"Invalid chunk type {chunk['type']}")

@@ -163,15 +178,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "pixel_attention_mask": torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                ),
-                # "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs

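The dict is rebuilt key-by-key because the two processors return different fields: the Idefics2 processor emits pixel_attention_mask and no image_sizes, while the LLaVA-NeXT processor emits image_sizes and no pixel_attention_mask (this mirrors the "unused" arguments in the two forward signatures above). A hypothetical illustration with dummy tensors in place of real processor output:

import torch

idefics2_like = {
    "pixel_values": torch.zeros(1, 3, 378, 378),
    "pixel_attention_mask": torch.ones(1, 378, 378, dtype=torch.bool),
}
llava_next_like = {
    "pixel_values": torch.zeros(5, 3, 336, 336),  # base view + 2x2 anyres tiles
    "image_sizes": torch.tensor([[640, 640]]),
}
# Only the keys present in the first image's dict survive concatenation.
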
@@ -192,14 +213,20 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
                 batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                     device=device
                 )
-            # batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
+                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
-            # batch.image_sizes = None
+            batch.image_sizes = None
         return batch

@@ -291,7 +318,7 @@ class VlmCausalLM(BaseFlashMistral):
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
                 pixel_attention_mask=batch.pixel_attention_mask,
-                # image_sizes=batch.image_sizes,
+                image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None

@@ -299,8 +326,8 @@ class VlmCausalLM(BaseFlashMistral):
             batch.pixel_values = None
         if batch.pixel_attention_mask is not None:
             batch.pixel_attention_mask = None
-        # if batch.image_sizes is not None:
-        #     batch.image_sizes = None
+        if batch.image_sizes is not None:
+            batch.image_sizes = None
         return logits, speculative_logits

     # Copy inputs to the static inputs of the cuda graph