Operational.

This commit is contained in:
Nicolas Patry 2024-04-19 22:39:30 +00:00
parent 613dc93617
commit ae2b4e1c23
6 changed files with 61 additions and 25 deletions

View File

@ -114,8 +114,8 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
requests.push(Request {
id: 0,

View File

@ -66,7 +66,8 @@ impl LlavaNext {
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width;
// TODO Very odd artifact when the rounding is super close
let height_of_patch = (height * npatches + width - 10) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
@ -166,5 +167,7 @@ mod test {
assert_eq!(slots, 2732);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
}
}

View File

@ -739,6 +739,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:

View File

@ -170,6 +170,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)

View File

@ -26,6 +26,8 @@ class Idefics2(VlmCausalLM):
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(

View File

@ -64,6 +64,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(image_input, config) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
@ -82,7 +102,7 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
height_of_patch = (height * npatches + width - 10) // width
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
@ -99,12 +119,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@ -112,6 +129,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@ -119,6 +137,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@ -147,11 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
# import ipdb;ipdb.set_trace()
# height, width = image_input["image_sizes"][0]
# num_features = get_number_of_features(height, width, config)
num_features = 320
full_text += "<image>" * num_features
full_text += image_text_replacement(image_input, config)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@ -163,15 +178,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"pixel_attention_mask": torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
),
# "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@ -192,14 +213,20 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
# batch.image_sizes = image_inputs["image_sizes"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
# batch.image_sizes = None
batch.image_sizes = None
return batch
@ -291,7 +318,7 @@ class VlmCausalLM(BaseFlashMistral):
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
# image_sizes=batch.image_sizes,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
@ -299,8 +326,8 @@ class VlmCausalLM(BaseFlashMistral):
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
# if batch.image_sizes is not None:
# batch.image_sizes = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph