text-generation-inference/backends/gaudi/server/text_generation_server/models/model.py
Wang, Yi d62c941c56
Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)
* clean cuda/rocm code in hpu backend, enable flat_hpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix TP in pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust block table in hpu to improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable all the models. not tested yet

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use tensor cache in hpu graph to avoid replay issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix phimoe issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* gpt_bigcode could also go pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable dbrx remove some unused code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* multi-modality initial PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust warmup and enable vlm

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix incorrect output in qwen2 idefics if hpu graph is used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove unused quantization code and enable awq/gptq int4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix gptq issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable fp8

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup prefill

remove model where pageattn is not used, set block table to None since it's not used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add warmup_decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove block_tables and prefill_cache_indices which will lead to dynamic shape

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* missing gptj change...

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix some issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove torch.where to fix incorrect output in hpu graph model

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-04-14 15:58:13 +02:00

143 lines
4.9 KiB
Python
Raw Blame History

import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, Generation
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.pb import generate_pb2
# Adapter id used as the default `adapter_id` of `Model.__init__`, i.e. when
# no specific adapter is requested.
BASE_MODEL_ADAPTER_ID = "__base_model__"
# Type variable for batch classes; constrained to `Batch` and its subclasses.
# Used in the annotations of `Model.batch_type` and `Model.generate_token`.
B = TypeVar("B", bound=Batch)
class Model(ABC):
    """Abstract base class pairing a torch module with its tokenizer.

    Subclasses must implement `batch_type` and `generate_token`. The base
    class stores the runtime configuration, exposes it through `info`, and
    provides incremental token decoding and a parameter-initialization check.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        support_chunking: bool = False,
    ):
        """Store the model, tokenizer and runtime settings.

        Args:
            model_id: Identifier of the model; stored as-is.
            model: Module to serve; it is switched to eval mode.
            tokenizer: Tokenizer used by `decode_token` and to collect the
                special token ids.
            requires_padding: Reported through `info`.
            dtype: Reported (as a string) through `info`.
            device: Its `type` is reported through `info`.
            rank: Stored as `self.rank`.
            world_size: Stored as `self.world_size`.
            sliding_window: Window size; `-1` is normalized to `None`.
            speculate: Speculation setting; when `None`, the value returned
                by `get_speculate()` is used instead.
            adapter_id: Stored as `self.static_adapter_id`.
            support_chunking: Accepted for interface compatibility; it is
                not read anywhere in this class.

        Raises:
            RuntimeError: If any model parameter is still on the meta device
                (see `check_initialized`).
        """
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # -1 is treated as "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Record whether the wrapped module's forward() takes `position_ids`.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Return an `InfoResponse` describing this model's configuration.

        Raises:
            NotImplementedError: If the model requires padding and a sliding
                window is configured at the same time.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
            block_size=BLOCK_SIZE,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """Return the batch class handled by this model (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on `batch` (subclass hook).

        Returns a tuple of the produced generations, an optional batch, and
        a pair of integers.
        NOTE(review): the optional batch is presumably the batch to continue
        with and the integer pair timing information — confirm against the
        subclasses.
        """
        raise NotImplementedError

    def warmup(
        self, batch: generate_pb2.WarmupRequest
    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """Run a single `generate_token` call on `batch` as a warmup.

        The base implementation reports nothing and always returns
        `(None, None, None)`.
        """
        self.generate_token(batch)
        return None, None, None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers.

        Decodes `all_input_ids[prefix_offset:]` and returns only the text that
        comes after the already-emitted prefix
        `all_input_ids[prefix_offset:read_offset]`.

        Returns:
            `(new_text, new_prefix_offset, new_read_offset)`. When new text is
            available the offsets advance to `(read_offset, len(all_input_ids))`;
            otherwise `("", prefix_offset, read_offset)` is returned so the
            caller can retry once more ids are available.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        # U+FFFD (the Unicode replacement character) at the end means it's a
        # potential unfinished byte sequence from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model.
        if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)

        return "", prefix_offset, read_offset

    def check_initialized(self):
        """Verify that no model parameter is left on the meta device.

        Raises:
            RuntimeError: Listing the names of all parameters whose data is
                still on `torch.device("meta")`.
        """
        meta_device = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta_device
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )