mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Remove a lot of duplicated code.
This commit is contained in:
parent
69cb084b5f
commit
ed34cf0222
@ -53,47 +53,62 @@ FLASH_ATTENTION = True
|
||||
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
|
||||
# from text_generation_server.models.flash_llama import (
|
||||
# FlashLlama,
|
||||
# )
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_qwen2 import (
|
||||
FlashQwen2,
|
||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||
FlashCohereForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_cohere import (
|
||||
FlashCohere,
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma import (
|
||||
FlashGemma,
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma2 import (
|
||||
FlashGemma2,
|
||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||
FlashDbrxForCausalLM,
|
||||
DbrxConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.pali_gemma import (
|
||||
PaliGemma,
|
||||
PaliGemmaBatch,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.idefics2 import Idefics2
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||
FlashStarcoder2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||
FlashMixtralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||
FlashGPT2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||
@ -102,20 +117,8 @@ except ImportError as e:
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
# __all__.append(FlashLlama)
|
||||
__all__.append(IDEFICSSharded)
|
||||
__all__.append(FlashMistral)
|
||||
__all__.append(FlashDbrx)
|
||||
__all__.append(FlashPhi)
|
||||
__all__.append(FlashQwen2)
|
||||
__all__.append(FlashStarcoder2)
|
||||
__all__.append(FlashGemma)
|
||||
__all__.append(FlashGemma2)
|
||||
__all__.append(FlashCohere)
|
||||
|
||||
MAMBA_AVAILABLE = True
|
||||
try:
|
||||
@ -468,13 +471,16 @@ def get_model(
|
||||
and model_id.startswith("bigcode/")
|
||||
):
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashSantacoderForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
@ -511,13 +517,15 @@ def get_model(
|
||||
elif model_type == GPT2:
|
||||
if FLASH_ATTENTION:
|
||||
try:
|
||||
return FlashGPT2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPT2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
# Lots of legacy models with various weight names.
|
||||
@ -543,13 +551,15 @@ def get_model(
|
||||
)
|
||||
elif model_type == GPT_NEOX:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashNeoXSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPTNeoXForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
return GPTNeoxSharded(
|
||||
@ -572,13 +582,15 @@ def get_model(
|
||||
|
||||
elif model_type == PHI:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashPhi(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashPhiForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
@ -630,13 +642,15 @@ def get_model(
|
||||
)
|
||||
if model_type == GEMMA:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemmaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
||||
@ -651,13 +665,15 @@ def get_model(
|
||||
)
|
||||
elif model_type == GEMMA2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemma2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||
@ -673,13 +689,15 @@ def get_model(
|
||||
|
||||
if model_type == COHERE:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCohere(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashCohereForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
||||
@ -695,13 +713,16 @@ def get_model(
|
||||
|
||||
if model_type == DBRX:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashDbrx(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashDbrxForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DbrxConfig,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
||||
@ -720,24 +741,30 @@ def get_model(
|
||||
if FLASH_ATTENTION:
|
||||
if config_dict.get("alibi", False):
|
||||
raise NotImplementedError("sharded is not supported for this model")
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashRWForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
)
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
||||
else:
|
||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashRWForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
)
|
||||
else:
|
||||
return RW(
|
||||
@ -799,12 +826,15 @@ def get_model(
|
||||
|
||||
if model_type == STARCODER2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashStarcoder2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashMistral(
|
||||
model_id=model_id,
|
||||
model_class=FlashStarcoder2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
@ -822,12 +852,15 @@ def get_model(
|
||||
|
||||
if model_type == QWEN2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashQwen2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashMistral(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
||||
@ -874,34 +907,43 @@ def get_model(
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == IDEFICS2:
|
||||
if FLASH_ATTENTION:
|
||||
return Idefics2(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Idefics2ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == "paligemma":
|
||||
if FLASH_ATTENTION:
|
||||
return PaliGemma(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=PaliGemmaForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
batch_class=PaliGemmaBatch,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == LLAVA_NEXT:
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_class=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -466,6 +466,7 @@ class FlashSantacoderModel(nn.Module):
|
||||
class FlashSantacoderForCausalLM(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
config.transpose = config.architectures[0].startswith("GPT2")
|
||||
self.transformer = FlashSantacoderModel(config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="transformer.wte", weights=weights
|
||||
|
@ -822,19 +822,9 @@ class FlashCausalLM(Model):
|
||||
trust_remote_code: bool = False,
|
||||
lora_adapter_ids: Optional[list] = [],
|
||||
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
|
||||
config_class: PreTrainedTokenizerBase = AutoConfig,
|
||||
default_dtype=torch.float16,
|
||||
# self,
|
||||
# model_id: str,
|
||||
# model_class,
|
||||
# tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
|
||||
# num_layers: int,
|
||||
# num_kv_heads: int,
|
||||
# head_size: int,
|
||||
# dtype: torch.dtype,
|
||||
# device: torch.device,
|
||||
# rank: int = 0,
|
||||
# world_size: int = 1,
|
||||
# sliding_window: Optional[int] = None,
|
||||
aliases=None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -868,7 +858,7 @@ class FlashCausalLM(Model):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
config = config_class.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
@ -881,7 +871,9 @@ class FlashCausalLM(Model):
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
weights = Weights(
|
||||
filenames, device, dtype, process_group=self.process_group, aliases=aliases
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
|
@ -1,100 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||
FlashDbrxForCausalLM,
|
||||
DbrxConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashDbrx(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashDBRX is only available on GPU")
|
||||
|
||||
try:
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=True,
|
||||
from_slow=False,
|
||||
)
|
||||
except:
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=True,
|
||||
from_slow=False,
|
||||
)
|
||||
except:
|
||||
# FIXME: change back to model id once the tokenizer.json is merged
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(
|
||||
"Xenova/dbrx-instruct-tokenizer",
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=True,
|
||||
from_slow=False,
|
||||
)
|
||||
|
||||
config = DbrxConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashDbrxForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashDbrx, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,83 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashGemma(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
# TODO hardcoded
|
||||
prefix = ""
|
||||
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,83 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import PretrainedConfig, AutoTokenizer
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashGemma2(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = PretrainedConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
# TODO hardcoded
|
||||
prefix = ""
|
||||
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma2, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,82 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
||||
from transformers.models.gpt2 import GPT2Tokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||
FlashGPT2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashGPT2(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGPT2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = FlashGPT2ForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGPT2, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,82 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashNeoXSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashNeoXSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.gpt_neox.layers),
|
||||
num_kv_heads=model.gpt_neox.num_heads,
|
||||
head_size=model.gpt_neox.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,111 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashPhi(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashPhiForCausalLM(config, weights)
|
||||
if speculator:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(speculator).exists() and Path(speculator).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
speculator, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
speculator, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(speculator) / "config.json")
|
||||
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashPhi, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,88 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
FlashMistral,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashQwen2(FlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = Qwen2ForCausalLM(config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
@ -1,91 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashRWSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
},
|
||||
)
|
||||
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashRWForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashRWSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=model.transformer.cache_size,
|
||||
head_size=model.transformer.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -1,99 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
import json
|
||||
import os
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashSantacoderSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
config.transpose = config.architectures[0].startswith("GPT2")
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashSantacoderForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashSantacoderSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=1,
|
||||
head_size=model.transformer.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
@ -1,51 +0,0 @@
|
||||
import torch
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class Idefics2(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
size={"longest_edge": 448, "shortest_edge": 378},
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=Idefics2ForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.text_model.model.layers),
|
||||
model.text_model.model.num_key_value_heads,
|
||||
model.text_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.text_model, "max_past", None)
|
@ -1,46 +0,0 @@
|
||||
import torch
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class LlavaNext(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.language_model.model.layers),
|
||||
model.language_model.model.num_key_value_heads,
|
||||
model.language_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.language_model, "max_past", None)
|
@ -77,32 +77,6 @@ class PaliGemmaBatch(VlmCausalLMBatch):
|
||||
|
||||
|
||||
class PaliGemma(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config_cls=AutoConfig,
|
||||
model_cls=PaliGemmaForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self):
|
||||
return PaliGemmaBatch
|
||||
|
@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
FlashMistral,
|
||||
)
|
||||
from transformers import AutoProcessor
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -240,9 +241,34 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
|
||||
class VlmCausalLM(FlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
*,
|
||||
processor_class=AutoProcessor,
|
||||
processor_kwargs=None,
|
||||
batch_class=VlmCausalLMBatch,
|
||||
revision,
|
||||
trust_remote_code: bool,
|
||||
**kwargs,
|
||||
):
|
||||
if processor_kwargs is None:
|
||||
processor_kwargs = {}
|
||||
self.processor = processor_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**processor_kwargs,
|
||||
)
|
||||
self.batch_class = batch_class
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
return VlmCausalLMBatch
|
||||
return self.batch_class
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.text_model, "max_past", None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user