mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing features for llava_next. Still issues with warmup and truncation
atm.
This commit is contained in:
parent
ae2b4e1c23
commit
f2d8c2e76f
@ -293,6 +293,7 @@ def launcher(event_loop):
|
|||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
max_input_length: Optional[int] = None,
|
max_input_length: Optional[int] = None,
|
||||||
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
@ -334,6 +335,9 @@ def launcher(event_loop):
|
|||||||
if max_input_length:
|
if max_input_length:
|
||||||
args.append("--max-input-length")
|
args.append("--max-input-length")
|
||||||
args.append(str(max_input_length))
|
args.append(str(max_input_length))
|
||||||
|
if max_batch_prefill_tokens:
|
||||||
|
args.append("--max-batch-prefill-tokens")
|
||||||
|
args.append(str(max_batch_prefill_tokens))
|
||||||
if max_total_tokens:
|
if max_total_tokens:
|
||||||
args.append("--max-total-tokens")
|
args.append("--max-total-tokens")
|
||||||
args.append(str(max_total_tokens))
|
args.append(str(max_total_tokens))
|
||||||
@ -371,6 +375,7 @@ def launcher(event_loop):
|
|||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
max_input_length: Optional[int] = None,
|
max_input_length: Optional[int] = None,
|
||||||
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
@ -395,6 +400,9 @@ def launcher(event_loop):
|
|||||||
if max_input_length:
|
if max_input_length:
|
||||||
args.append("--max-input-length")
|
args.append("--max-input-length")
|
||||||
args.append(str(max_input_length))
|
args.append(str(max_input_length))
|
||||||
|
if max_batch_prefill_tokens:
|
||||||
|
args.append("--max-batch-prefill-tokens")
|
||||||
|
args.append(str(max_batch_prefill_tokens))
|
||||||
if max_total_tokens:
|
if max_total_tokens:
|
||||||
args.append("--max-total-tokens")
|
args.append("--max-total-tokens")
|
||||||
args.append(str(max_total_tokens))
|
args.append(str(max_total_tokens))
|
||||||
|
@ -15,6 +15,7 @@ def flash_llava_next_handle(launcher):
|
|||||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||||
num_shard=4,
|
num_shard=4,
|
||||||
max_input_length=4000,
|
max_input_length=4000,
|
||||||
|
max_batch_prefill_tokens=8000,
|
||||||
max_total_tokens=4096,
|
max_total_tokens=4096,
|
||||||
) as handle:
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
@ -114,8 +114,8 @@ impl Client {
|
|||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str("");
|
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
|
inputs.push_str("");
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
|
@ -57,20 +57,46 @@ fn select_best_resolution(
|
|||||||
best_fit.unwrap_or((original_height, original_width))
|
best_fit.unwrap_or((original_height, original_width))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unpadded_features(
|
||||||
|
height: usize,
|
||||||
|
width: usize,
|
||||||
|
npatches: usize,
|
||||||
|
num_patch_height: usize,
|
||||||
|
num_patch_width: usize,
|
||||||
|
) -> (usize, usize) {
|
||||||
|
let current_height = npatches * num_patch_height;
|
||||||
|
let current_width = npatches * num_patch_width;
|
||||||
|
|
||||||
|
let aspect_ratio: f64 = width as f64 / height as f64;
|
||||||
|
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||||
|
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||||
|
let new_height = (height * current_width) / width;
|
||||||
|
// let padding = (current_height - new_height) / 2;
|
||||||
|
// let new_height = current_height - 2 * padding;
|
||||||
|
(new_height, current_width)
|
||||||
|
} else {
|
||||||
|
let new_width = (width * current_height) / height;
|
||||||
|
(current_height, new_width)
|
||||||
|
};
|
||||||
|
println!("{current_height} {current_width}");
|
||||||
|
|
||||||
|
let unpadded_features = current_height * current_width;
|
||||||
|
let newline_features = current_height;
|
||||||
|
(unpadded_features, newline_features)
|
||||||
|
}
|
||||||
|
|
||||||
impl LlavaNext {
|
impl LlavaNext {
|
||||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||||
let image_size = self.vision_config.image_size;
|
let image_size = self.vision_config.image_size;
|
||||||
let patch_size = self.vision_config.patch_size;
|
let patch_size = self.vision_config.patch_size;
|
||||||
assert!(image_size % patch_size == 0);
|
assert!(image_size % patch_size == 0);
|
||||||
let npatches = image_size / patch_size;
|
let npatches = image_size / patch_size;
|
||||||
|
println!("{npatches} {image_size} {patch_size}");
|
||||||
let (num_patch_height, num_patch_width) =
|
let (num_patch_height, num_patch_width) =
|
||||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||||
// Ceil
|
|
||||||
// TODO Very odd artifact when the rounding is super close
|
let (unpadded_features, newline_features) =
|
||||||
let height_of_patch = (height * npatches + width - 10) / width;
|
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
|
||||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
|
||||||
// They are only added after width
|
|
||||||
let newline_features = height_of_patch * num_patch_width;
|
|
||||||
// The base patch covers the entire image
|
// The base patch covers the entire image
|
||||||
let base_features = npatches.pow(2);
|
let base_features = npatches.pow(2);
|
||||||
unpadded_features + newline_features + base_features
|
unpadded_features + newline_features + base_features
|
||||||
@ -159,14 +185,16 @@ mod test {
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let slots = config.get_number_of_features(20, 20);
|
||||||
|
assert_eq!(slots, 1176);
|
||||||
let slots = config.get_number_of_features(640, 640);
|
let slots = config.get_number_of_features(640, 640);
|
||||||
assert_eq!(slots, 2928);
|
assert_eq!(slots, 2928);
|
||||||
let slots = config.get_number_of_features(480, 640);
|
let slots = config.get_number_of_features(480, 640);
|
||||||
assert_eq!(slots, 2340);
|
assert_eq!(slots, 2340);
|
||||||
let slots = config.get_number_of_features(899, 1024);
|
let slots = config.get_number_of_features(899, 1024);
|
||||||
assert_eq!(slots, 2732);
|
assert_eq!(slots, 2634);
|
||||||
let slots = config.get_number_of_features(1024, 899);
|
let slots = config.get_number_of_features(1024, 899);
|
||||||
assert_eq!(slots, 3320);
|
assert_eq!(slots, 2640);
|
||||||
let slots = config.get_number_of_features(1067, 1600);
|
let slots = config.get_number_of_features(1067, 1600);
|
||||||
assert_eq!(slots, 2144);
|
assert_eq!(slots, 2144);
|
||||||
}
|
}
|
||||||
|
@ -229,6 +229,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||||||
self.config.image_grid_pinpoints,
|
self.config.image_grid_pinpoints,
|
||||||
self.config.vision_config.image_size,
|
self.config.vision_config.image_size,
|
||||||
)
|
)
|
||||||
|
if image_sizes[image_idx][0].item() == 22:
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
||||||
image_feature = image_feature.view(
|
image_feature = image_feature.view(
|
||||||
num_patch_height, num_patch_width, height, width, -1
|
num_patch_height, num_patch_width, height, width, -1
|
||||||
)
|
)
|
||||||
|
@ -64,7 +64,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
|||||||
return height // patch_size, width // patch_size
|
return height // patch_size, width // patch_size
|
||||||
|
|
||||||
|
|
||||||
def image_text_replacement(image_input, config) -> str:
|
def image_text_replacement(image_input, config, image_id) -> str:
|
||||||
if config.model_type == "idefics2":
|
if config.model_type == "idefics2":
|
||||||
# TODO technically depends on image splitting which is not implemented.
|
# TODO technically depends on image splitting which is not implemented.
|
||||||
num_features = 320
|
num_features = 320
|
||||||
@ -74,7 +74,7 @@ def image_text_replacement(image_input, config) -> str:
|
|||||||
+ "<fake_token_around_image>"
|
+ "<fake_token_around_image>"
|
||||||
)
|
)
|
||||||
elif config.model_type == "llava_next":
|
elif config.model_type == "llava_next":
|
||||||
height, width = image_input["image_sizes"][0]
|
height, width = image_input["image_sizes"][image_id]
|
||||||
num_features = get_number_of_features(height, width, config)
|
num_features = get_number_of_features(height, width, config)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -84,6 +84,26 @@ def image_text_replacement(image_input, config) -> str:
|
|||||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||||
|
|
||||||
|
|
||||||
|
def get_unpadded_features(
|
||||||
|
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
current_height = npatches * num_patch_height
|
||||||
|
current_width = npatches * num_patch_width
|
||||||
|
|
||||||
|
aspect_ratio: float = width / height
|
||||||
|
current_aspect_ratio: float = current_width / current_height
|
||||||
|
if aspect_ratio > current_aspect_ratio:
|
||||||
|
new_height = (height * current_width) // width
|
||||||
|
current_height = new_height
|
||||||
|
else:
|
||||||
|
new_width = (width * current_height) // height
|
||||||
|
current_width = new_width
|
||||||
|
|
||||||
|
unpadded_features = current_height * current_width
|
||||||
|
newline_features = current_height
|
||||||
|
return (unpadded_features, newline_features)
|
||||||
|
|
||||||
|
|
||||||
def get_number_of_features(height: int, width: int, config) -> int:
|
def get_number_of_features(height: int, width: int, config) -> int:
|
||||||
# From config
|
# From config
|
||||||
# Hardcoded for CLIP for now
|
# Hardcoded for CLIP for now
|
||||||
@ -101,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
|||||||
image_grid_pinpoints,
|
image_grid_pinpoints,
|
||||||
image_size,
|
image_size,
|
||||||
)
|
)
|
||||||
|
unpadded_features, newline_features = get_unpadded_features(
|
||||||
height_of_patch = (height * npatches + width - 10) // width
|
height, width, npatches, num_patch_height, num_patch_width
|
||||||
|
)
|
||||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
|
||||||
# They are only added after width
|
|
||||||
newline_features = height_of_patch * num_patch_width
|
|
||||||
# The base patch covers the entire image
|
# The base patch covers the entire image
|
||||||
base_features = npatches**2
|
base_features = npatches**2
|
||||||
return unpadded_features + newline_features + base_features
|
return unpadded_features + newline_features + base_features
|
||||||
@ -149,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||||||
for r in requests:
|
for r in requests:
|
||||||
chunks = split(r.inputs)
|
chunks = split(r.inputs)
|
||||||
full_text = ""
|
full_text = ""
|
||||||
|
image_id = 0
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
if chunk["type"] == "text":
|
if chunk["type"] == "text":
|
||||||
full_text += chunk["content"]
|
full_text += chunk["content"]
|
||||||
@ -166,7 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||||||
"Cannot process input image not starting with data:"
|
"Cannot process input image not starting with data:"
|
||||||
)
|
)
|
||||||
image_input = processor.image_processor(image, return_tensors="pt")
|
image_input = processor.image_processor(image, return_tensors="pt")
|
||||||
full_text += image_text_replacement(image_input, config)
|
full_text += image_text_replacement(image_input, config, image_id)
|
||||||
image_inputs.append(image_input)
|
image_inputs.append(image_input)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||||
|
Loading…
Reference in New Issue
Block a user