mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Remove a lot of duplicated code.
This commit is contained in:
parent
69cb084b5f
commit
ed34cf0222
@ -53,47 +53,62 @@ FLASH_ATTENTION = True
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
|
||||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
|
||||||
|
|
||||||
# from text_generation_server.models.flash_llama import (
|
|
||||||
# FlashLlama,
|
|
||||||
# )
|
|
||||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
FlashLlamaForCausalLM,
|
FlashLlamaForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_qwen2 import (
|
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||||
FlashQwen2,
|
FlashCohereForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_cohere import (
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
FlashCohere,
|
FlashGemmaForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_gemma import (
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
FlashGemma,
|
FlashGemma2ForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_gemma2 import (
|
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||||
FlashGemma2,
|
FlashDbrxForCausalLM,
|
||||||
|
DbrxConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||||
|
RWConfig,
|
||||||
|
FlashRWForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||||
|
FlashGPTNeoXForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.pali_gemma import (
|
from text_generation_server.models.pali_gemma import (
|
||||||
PaliGemma,
|
PaliGemmaBatch,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_santacoder import (
|
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||||
FlashSantacoderSharded,
|
FlashPhiForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
from text_generation_server.models.idefics import IDEFICSSharded
|
||||||
from text_generation_server.models.llava_next import LlavaNext
|
from text_generation_server.models.custom_modeling.llava_next import (
|
||||||
from text_generation_server.models.idefics2 import Idefics2
|
LlavaNextForConditionalGeneration,
|
||||||
|
)
|
||||||
from text_generation_server.models.flash_mistral import FlashMistral
|
from text_generation_server.models.flash_mistral import FlashMistral
|
||||||
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||||
|
FlashSantacoderForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||||
|
FlashStarcoder2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||||
|
Qwen2ForCausalLM,
|
||||||
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
FlashMistralForCausalLM,
|
FlashMistralForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||||
FlashMixtralForCausalLM,
|
FlashMixtralForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_phi import FlashPhi
|
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
FlashGPT2ForCausalLM,
|
||||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||||
|
Idefics2ForConditionalGeneration,
|
||||||
|
)
|
||||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||||
@ -102,20 +117,8 @@ except ImportError as e:
|
|||||||
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashCausalLM)
|
__all__.append(FlashCausalLM)
|
||||||
__all__.append(FlashGPT2)
|
|
||||||
__all__.append(FlashNeoXSharded)
|
|
||||||
__all__.append(FlashRWSharded)
|
|
||||||
__all__.append(FlashSantacoderSharded)
|
|
||||||
# __all__.append(FlashLlama)
|
|
||||||
__all__.append(IDEFICSSharded)
|
__all__.append(IDEFICSSharded)
|
||||||
__all__.append(FlashMistral)
|
__all__.append(FlashMistral)
|
||||||
__all__.append(FlashDbrx)
|
|
||||||
__all__.append(FlashPhi)
|
|
||||||
__all__.append(FlashQwen2)
|
|
||||||
__all__.append(FlashStarcoder2)
|
|
||||||
__all__.append(FlashGemma)
|
|
||||||
__all__.append(FlashGemma2)
|
|
||||||
__all__.append(FlashCohere)
|
|
||||||
|
|
||||||
MAMBA_AVAILABLE = True
|
MAMBA_AVAILABLE = True
|
||||||
try:
|
try:
|
||||||
@ -468,13 +471,16 @@ def get_model(
|
|||||||
and model_id.startswith("bigcode/")
|
and model_id.startswith("bigcode/")
|
||||||
):
|
):
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashSantacoderSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashSantacoderForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -511,13 +517,15 @@ def get_model(
|
|||||||
elif model_type == GPT2:
|
elif model_type == GPT2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
try:
|
try:
|
||||||
return FlashGPT2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGPT2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Lots of legacy models with various weight names.
|
# Lots of legacy models with various weight names.
|
||||||
@ -543,13 +551,15 @@ def get_model(
|
|||||||
)
|
)
|
||||||
elif model_type == GPT_NEOX:
|
elif model_type == GPT_NEOX:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashNeoXSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGPTNeoXForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
return GPTNeoxSharded(
|
return GPTNeoxSharded(
|
||||||
@ -572,13 +582,15 @@ def get_model(
|
|||||||
|
|
||||||
elif model_type == PHI:
|
elif model_type == PHI:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashPhi(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashPhiForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM(
|
||||||
@ -630,13 +642,15 @@ def get_model(
|
|||||||
)
|
)
|
||||||
if model_type == GEMMA:
|
if model_type == GEMMA:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashGemma(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGemmaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
||||||
@ -651,13 +665,15 @@ def get_model(
|
|||||||
)
|
)
|
||||||
elif model_type == GEMMA2:
|
elif model_type == GEMMA2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashGemma2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGemma2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||||
@ -673,13 +689,15 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == COHERE:
|
if model_type == COHERE:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashCohere(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashCohereForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
||||||
@ -695,13 +713,16 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == DBRX:
|
if model_type == DBRX:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashDbrx(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashDbrxForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=DbrxConfig,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
||||||
@ -720,24 +741,30 @@ def get_model(
|
|||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
if config_dict.get("alibi", False):
|
if config_dict.get("alibi", False):
|
||||||
raise NotImplementedError("sharded is not supported for this model")
|
raise NotImplementedError("sharded is not supported for this model")
|
||||||
return FlashRWSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashRWForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=RWConfig,
|
||||||
)
|
)
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
||||||
else:
|
else:
|
||||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||||
return FlashRWSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashRWForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=RWConfig,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return RW(
|
return RW(
|
||||||
@ -799,12 +826,15 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == STARCODER2:
|
if model_type == STARCODER2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashStarcoder2(
|
return FlashMistral(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashStarcoder2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -822,12 +852,15 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == QWEN2:
|
if model_type == QWEN2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashQwen2(
|
return FlashMistral(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=Qwen2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
||||||
@ -874,34 +907,43 @@ def get_model(
|
|||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
if model_type == IDEFICS2:
|
if model_type == IDEFICS2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return Idefics2(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=Idefics2ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
# XXX: Extremely important to cap resolution in order to limit
|
||||||
|
# VRAM usage.
|
||||||
|
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
if model_type == "paligemma":
|
if model_type == "paligemma":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return PaliGemma(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=PaliGemmaForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
batch_class=PaliGemmaBatch,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
|
|
||||||
if model_type == LLAVA_NEXT:
|
if model_type == LLAVA_NEXT:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return LlavaNext(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_class=LlavaNextForConditionalGeneration,
|
||||||
revision,
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
@ -466,6 +466,7 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
class FlashSantacoderForCausalLM(nn.Module):
|
class FlashSantacoderForCausalLM(nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config.transpose = config.architectures[0].startswith("GPT2")
|
||||||
self.transformer = FlashSantacoderModel(config, weights)
|
self.transformer = FlashSantacoderModel(config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config, prefix="transformer.wte", weights=weights
|
config, prefix="transformer.wte", weights=weights
|
||||||
|
@ -822,19 +822,9 @@ class FlashCausalLM(Model):
|
|||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
lora_adapter_ids: Optional[list] = [],
|
lora_adapter_ids: Optional[list] = [],
|
||||||
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
|
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
|
||||||
|
config_class: PreTrainedTokenizerBase = AutoConfig,
|
||||||
default_dtype=torch.float16,
|
default_dtype=torch.float16,
|
||||||
# self,
|
aliases=None,
|
||||||
# model_id: str,
|
|
||||||
# model_class,
|
|
||||||
# tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
|
|
||||||
# num_layers: int,
|
|
||||||
# num_kv_heads: int,
|
|
||||||
# head_size: int,
|
|
||||||
# dtype: torch.dtype,
|
|
||||||
# device: torch.device,
|
|
||||||
# rank: int = 0,
|
|
||||||
# world_size: int = 1,
|
|
||||||
# sliding_window: Optional[int] = None,
|
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -868,7 +858,7 @@ class FlashCausalLM(Model):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = config_class.from_pretrained(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
@ -881,7 +871,9 @@ class FlashCausalLM(Model):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(
|
||||||
|
filenames, device, dtype, process_group=self.process_group, aliases=aliases
|
||||||
|
)
|
||||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
|
@ -1,100 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from typing import Optional
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
|
||||||
FlashDbrxForCausalLM,
|
|
||||||
DbrxConfig,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashDbrx(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashDBRX is only available on GPU")
|
|
||||||
|
|
||||||
try:
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
use_fast=True,
|
|
||||||
from_slow=False,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
try:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
use_fast=True,
|
|
||||||
from_slow=False,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
# FIXME: change back to model id once the tokenizer.json is merged
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained(
|
|
||||||
"Xenova/dbrx-instruct-tokenizer",
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
use_fast=True,
|
|
||||||
from_slow=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = DbrxConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashDbrxForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashDbrx, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,83 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from typing import Optional
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
|
||||||
FlashGemmaForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
# TODO hardcoded
|
|
||||||
prefix = ""
|
|
||||||
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashGemma, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,83 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from typing import Optional
|
|
||||||
from transformers import PretrainedConfig, AutoTokenizer
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
|
||||||
FlashGemma2ForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma2(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PretrainedConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
# TODO hardcoded
|
|
||||||
prefix = ""
|
|
||||||
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashGemma2, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,82 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
|
||||||
from transformers.models.gpt2 import GPT2Tokenizer
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
|
||||||
FlashGPT2ForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashGPT2(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashGPT2 is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
prefix = ""
|
|
||||||
model = FlashGPT2ForCausalLM(prefix, config, weights)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashGPT2, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,82 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
|
||||||
FlashGPTNeoXForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashNeoXSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
|
||||||
)
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashNeoXSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.gpt_neox.layers),
|
|
||||||
num_kv_heads=model.gpt_neox.num_heads,
|
|
||||||
head_size=model.gpt_neox.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,111 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
|
||||||
FlashPhiForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashPhi(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashPhiForCausalLM(config, weights)
|
|
||||||
if speculator:
|
|
||||||
from text_generation_server.utils.medusa import MedusaModel
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
is_local_model = (
|
|
||||||
Path(speculator).exists() and Path(speculator).is_dir()
|
|
||||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
|
||||||
|
|
||||||
if not is_local_model:
|
|
||||||
medusa_config = hf_hub_download(
|
|
||||||
speculator, revision=revision, filename="config.json"
|
|
||||||
)
|
|
||||||
medusa_head = hf_hub_download(
|
|
||||||
speculator, revision=revision, filename="medusa_lm_head.pt"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
medusa_config = str(Path(speculator) / "config.json")
|
|
||||||
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
|
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
|
||||||
weights = Weights(
|
|
||||||
[medusa_sf], device, dtype, process_group=self.process_group
|
|
||||||
)
|
|
||||||
lm_head = model.lm_head
|
|
||||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashPhi, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,88 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models.flash_mistral import (
|
|
||||||
FlashMistral,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
|
||||||
Qwen2ForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashQwen2(FlashMistral):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = Qwen2ForCausalLM(config, weights)
|
|
||||||
|
|
||||||
self.cuda_graphs = {}
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashMistral, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
sliding_window=config.sliding_window,
|
|
||||||
)
|
|
@ -1,91 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
|
||||||
RWConfig,
|
|
||||||
FlashRWForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashRWSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = RWConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device,
|
|
||||||
dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
aliases={
|
|
||||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
|
||||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashRWForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashRWSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.transformer.h),
|
|
||||||
num_kv_heads=model.transformer.cache_size,
|
|
||||||
head_size=model.transformer.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
@ -1,99 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional, List
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
|
||||||
FlashSantacoderForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashSantacoderSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
config.transpose = config.architectures[0].startswith("GPT2")
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
|
||||||
)
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashSantacoderForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashSantacoderSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.transformer.h),
|
|
||||||
num_kv_heads=1,
|
|
||||||
head_size=model.transformer.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
|
||||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
@ -1,51 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
AutoProcessor,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
|
||||||
Idefics2ForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
class Idefics2(VlmCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
# XXX: Extremely important to cap resolution in order to limit
|
|
||||||
# VRAM usage.
|
|
||||||
size={"longest_edge": 448, "shortest_edge": 378},
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
model_cls=Idefics2ForConditionalGeneration,
|
|
||||||
model_id=model_id,
|
|
||||||
revision=revision,
|
|
||||||
quantize=quantize,
|
|
||||||
speculator=speculator,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
|
||||||
return (
|
|
||||||
len(model.text_model.model.layers),
|
|
||||||
model.text_model.model.num_key_value_heads,
|
|
||||||
model.text_model.model.head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def max_past(self) -> Optional[int]:
|
|
||||||
return getattr(self.model.text_model, "max_past", None)
|
|
@ -1,46 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
AutoProcessor,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.custom_modeling.llava_next import (
|
|
||||||
LlavaNextForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaNext(VlmCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
model_cls=LlavaNextForConditionalGeneration,
|
|
||||||
model_id=model_id,
|
|
||||||
revision=revision,
|
|
||||||
quantize=quantize,
|
|
||||||
speculator=speculator,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
|
||||||
return (
|
|
||||||
len(model.language_model.model.layers),
|
|
||||||
model.language_model.model.num_key_value_heads,
|
|
||||||
model.language_model.model.head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def max_past(self) -> Optional[int]:
|
|
||||||
return getattr(self.model.language_model, "max_past", None)
|
|
@ -77,32 +77,6 @@ class PaliGemmaBatch(VlmCausalLMBatch):
|
|||||||
|
|
||||||
|
|
||||||
class PaliGemma(VlmCausalLM):
|
class PaliGemma(VlmCausalLM):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
config_cls=AutoConfig,
|
|
||||||
model_cls=PaliGemmaForConditionalGeneration,
|
|
||||||
model_id=model_id,
|
|
||||||
revision=revision,
|
|
||||||
quantize=quantize,
|
|
||||||
speculator=speculator,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self):
|
def batch_type(self):
|
||||||
return PaliGemmaBatch
|
return PaliGemmaBatch
|
||||||
|
@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
|
|||||||
from text_generation_server.models.flash_mistral import (
|
from text_generation_server.models.flash_mistral import (
|
||||||
FlashMistral,
|
FlashMistral,
|
||||||
)
|
)
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
@ -240,9 +241,34 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||||||
|
|
||||||
|
|
||||||
class VlmCausalLM(FlashMistral):
|
class VlmCausalLM(FlashMistral):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
*,
|
||||||
|
processor_class=AutoProcessor,
|
||||||
|
processor_kwargs=None,
|
||||||
|
batch_class=VlmCausalLMBatch,
|
||||||
|
revision,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if processor_kwargs is None:
|
||||||
|
processor_kwargs = {}
|
||||||
|
self.processor = processor_class.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
**processor_kwargs,
|
||||||
|
)
|
||||||
|
self.batch_class = batch_class
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||||
return VlmCausalLMBatch
|
return self.batch_class
|
||||||
|
|
||||||
|
def max_past(self) -> Optional[int]:
|
||||||
|
return getattr(self.model.text_model, "max_past", None)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user