mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing features for llava_next. Still issues with warmup and truncation
atm.
This commit is contained in:
parent
ae2b4e1c23
commit
f2d8c2e76f
@ -293,6 +293,7 @@ def launcher(event_loop):
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
@ -334,6 +335,9 @@ def launcher(event_loop):
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_batch_prefill_tokens:
|
||||
args.append("--max-batch-prefill-tokens")
|
||||
args.append(str(max_batch_prefill_tokens))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
@ -371,6 +375,7 @@ def launcher(event_loop):
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
@ -395,6 +400,9 @@ def launcher(event_loop):
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_batch_prefill_tokens:
|
||||
args.append("--max-batch-prefill-tokens")
|
||||
args.append(str(max_batch_prefill_tokens))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
@ -15,6 +15,7 @@ def flash_llava_next_handle(launcher):
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
num_shard=4,
|
||||
max_input_length=4000,
|
||||
max_batch_prefill_tokens=8000,
|
||||
max_total_tokens=4096,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
@ -114,8 +114,8 @@ impl Client {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str("");
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
inputs.push_str("");
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
|
@ -57,20 +57,46 @@ fn select_best_resolution(
|
||||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
|
||||
|
||||
fn get_unpadded_features(
|
||||
height: usize,
|
||||
width: usize,
|
||||
npatches: usize,
|
||||
num_patch_height: usize,
|
||||
num_patch_width: usize,
|
||||
) -> (usize, usize) {
|
||||
let current_height = npatches * num_patch_height;
|
||||
let current_width = npatches * num_patch_width;
|
||||
|
||||
let aspect_ratio: f64 = width as f64 / height as f64;
|
||||
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||
let new_height = (height * current_width) / width;
|
||||
// let padding = (current_height - new_height) / 2;
|
||||
// let new_height = current_height - 2 * padding;
|
||||
(new_height, current_width)
|
||||
} else {
|
||||
let new_width = (width * current_height) / height;
|
||||
(current_height, new_width)
|
||||
};
|
||||
println!("{current_height} {current_width}");
|
||||
|
||||
let unpadded_features = current_height * current_width;
|
||||
let newline_features = current_height;
|
||||
(unpadded_features, newline_features)
|
||||
}
|
||||
|
||||
impl LlavaNext {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let image_size = self.vision_config.image_size;
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
println!("{npatches} {image_size} {patch_size}");
|
||||
let (num_patch_height, num_patch_width) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
// Ceil
|
||||
// TODO Very odd artifact when the rounding is super close
|
||||
let height_of_patch = (height * npatches + width - 10) / width;
|
||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
||||
// They are only added after width
|
||||
let newline_features = height_of_patch * num_patch_width;
|
||||
|
||||
let (unpadded_features, newline_features) =
|
||||
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
|
||||
// The base patch covers the entire image
|
||||
let base_features = npatches.pow(2);
|
||||
unpadded_features + newline_features + base_features
|
||||
@ -159,14 +185,16 @@ mod test {
|
||||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(20, 20);
|
||||
assert_eq!(slots, 1176);
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
assert_eq!(slots, 2634);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
assert_eq!(slots, 2640);
|
||||
let slots = config.get_number_of_features(1067, 1600);
|
||||
assert_eq!(slots, 2144);
|
||||
}
|
||||
|
@ -229,6 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
if image_sizes[image_idx][0].item() == 22:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
|
@ -64,7 +64,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(image_input, config) -> str:
|
||||
def image_text_replacement(image_input, config, image_id) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
# TODO technically depends on image splitting which is not implemented.
|
||||
num_features = 320
|
||||
@ -74,7 +74,7 @@ def image_text_replacement(image_input, config) -> str:
|
||||
+ "<fake_token_around_image>"
|
||||
)
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][0]
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
from loguru import logger
|
||||
|
||||
@ -84,6 +84,26 @@ def image_text_replacement(image_input, config) -> str:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
|
||||
def get_unpadded_features(
|
||||
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
aspect_ratio: float = width / height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
current_height = new_height
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
current_width = new_width
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
|
||||
def get_number_of_features(height: int, width: int, config) -> int:
|
||||
# From config
|
||||
# Hardcoded for CLIP for now
|
||||
@ -101,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
)
|
||||
|
||||
height_of_patch = (height * npatches + width - 10) // width
|
||||
|
||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
||||
# They are only added after width
|
||||
newline_features = height_of_patch * num_patch_width
|
||||
unpadded_features, newline_features = get_unpadded_features(
|
||||
height, width, npatches, num_patch_height, num_patch_width
|
||||
)
|
||||
# The base patch covers the entire image
|
||||
base_features = npatches**2
|
||||
return unpadded_features + newline_features + base_features
|
||||
@ -149,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
||||
for r in requests:
|
||||
chunks = split(r.inputs)
|
||||
full_text = ""
|
||||
image_id = 0
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "text":
|
||||
full_text += chunk["content"]
|
||||
@ -166,7 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
||||
"Cannot process input image not starting with data:"
|
||||
)
|
||||
image_input = processor.image_processor(image, return_tensors="pt")
|
||||
full_text += image_text_replacement(image_input, config)
|
||||
full_text += image_text_replacement(image_input, config, image_id)
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||
|
Loading…
Reference in New Issue
Block a user