diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c8b68960..a96cb37f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1269,11 +1269,12 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
             # XXX: Extremely important to cap resolution in order to limit
             # VRAM usage.
-            processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+            processor_kwargs={"size": {"longest_edge": 1456}},
         )
     else:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index bc1fd073..0548fbc6 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -127,11 +127,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         # TODO: implement this in a more general way
         n_rows = image_input["rows"][0][image_id]
         n_cols = image_input["cols"][0][image_id]
-
-        # TODO: avoid using hardcoded values
-        image_seq_len = 169  # default value
-        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
-
+        image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+            / (config.scale_factor**2)
+        )
         image_str = get_image_prompt_string(
             n_rows,
             n_cols,
@@ -301,10 +300,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    replacement_text = image_text_replacement(
+                    full_text += image_text_replacement(
                         processor, image_inputs, config, image_id
                     )
-                    full_text += replacement_text
                     image_id += 1
 
             full_text = image_text_replacement_fixup(config, full_text)
@@ -379,10 +377,9 @@ class VlmCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            # **processor_kwargs,
+            **processor_kwargs,
         )
         self.batch_class = batch_class
-        # import ipdb; ipdb.set_trace()
         super().__init__(
            model_id=model_id,
            revision=revision,
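
# A quick sanity check on the image_seq_len formula introduced above (a
# standalone sketch, not part of the patch). Assuming the Idefics3-8B-Llama3
# defaults of vision_config.image_size=364, vision_config.patch_size=14 and
# config.scale_factor=2, the formula reproduces the previously hardcoded
# default of 169, and the new longest_edge=1456 cap is 4 * 364, i.e. an
# image is resized to split into at most a 4x4 grid of 364px tiles.

image_size = 364   # assumed vision_config.image_size
patch_size = 14    # assumed vision_config.patch_size
scale_factor = 2   # assumed config.scale_factor

# (364 // 14) ** 2 == 676 patches per tile; the pixel-shuffle connector
# merges each scale_factor x scale_factor group of patches, so
# 676 / 4 == 169 image tokens per tile.
image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
assert image_seq_len == 169  # matches the old hardcoded default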