from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.flash_vlm_causal_lm import (
    FlashVlmCausalLMBatch,
    image_text_replacement,
)
from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(FlashVlmCausalLMBatch):
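    """PaliGemma prompt batch for the Gaudi/HPU flash VLM path.

    Specializes FlashVlmCausalLMBatch tokenization: each text chunk is
    prefixed with "<bos>", while image chunks are decoded, run through the
    image processor, and expanded into their placeholder tokens in the prompt.
    """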
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
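        # Walk each request's input chunks, building the prompt text and
        # collecting the processed image tensors in request order.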
        for r in requests:
            full_text = ""
            image_id = 0
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += "<bos>" + chunk.text + "\n"
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO: do_convert_RGB should be on by default?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(
                        processor, image_input, config, image_id
                    )
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
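            # All chunks for this request have been consumed: record the
            # assembled prompt and track the largest truncation limit seen.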
            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)
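        # Tokenize all prompts in one call; add_special_tokens is disabled
        # because "<bos>" was already inserted for each text chunk above.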
        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
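        # Fold the per-image processor outputs into single batched tensors so
        # the model receives one pixel_values tensor for the whole batch.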
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
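            # pixel_attention_mask and image_sizes are only produced by some
            # image processors; concatenate them when the first image has them.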
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs
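
# Minimal usage sketch, assuming a PaliGemma checkpoint id and that `requests`
# already holds generate_pb2 Request messages with text/image input chunks
# (as the text-generation-inference router produces them):
#
#     from transformers import AutoConfig, AutoProcessor, AutoTokenizer
#
#     model_id = "google/paligemma-3b-pt-224"  # assumed checkpoint
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     processor = AutoProcessor.from_pretrained(model_id)
#     config = AutoConfig.from_pretrained(model_id)
#
#     input_ids, image_inputs = PaliGemmaBatch.batch_tokenized_inputs(
#         requests, tokenizer, processor, config
#     )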