Refactor dead code.

Nicolas Patry 2024-07-02 11:13:51 +00:00
parent 245d3de948
commit b28946d695
3 changed files with 109 additions and 146 deletions

server/text_generation_server/models/__init__.py

@@ -56,8 +56,12 @@ try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
-    from text_generation_server.models.flash_llama import (
-        FlashLlama,
+    # from text_generation_server.models.flash_llama import (
+    #     FlashLlama,
+    # )
+    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+        FlashLlamaForCausalLM,
     )
     from text_generation_server.models.flash_qwen2 import (
         FlashQwen2,
@@ -81,7 +85,9 @@ try:
     from text_generation_server.models.llava_next import LlavaNext
     from text_generation_server.models.idefics2 import Idefics2
     from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+        FlashMistralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -97,7 +103,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
-    __all__.append(FlashLlama)
+    # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
@@ -599,9 +605,10 @@ def get_model(
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
-            return FlashLlama(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
@@ -743,12 +750,14 @@ def get_model(
    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashMistral(
-                model_id,
-                revision,
+                model_id=model_id,
+                model_class=FlashMistralForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
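
To make the net effect of these hunks concrete, here is a minimal, illustrative sketch of how a Llama-style checkpoint would now be constructed through the generic FlashCausalLM. The model id is a placeholder and the keyword arguments simply mirror the new signature shown in this commit; treat it as an assumption-laden example, not code from the repository.

# Illustrative only; model_id is a placeholder, defaults follow the new signature.
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)

llama = FlashCausalLM(
    model_id="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    model_class=FlashLlamaForCausalLM,    # architecture-specific module class
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,                # resolves to default_dtype (float16) on CUDA
    trust_remote_code=False,
)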

server/text_generation_server/models/flash_causal_lm.py

@@ -10,7 +10,12 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import (
+    PreTrainedTokenizerBase,
+    AutoConfig,
+    AutoTokenizer,
+    GenerationConfig,
+)
 from typing import Iterable, Optional, Tuple, List, Type, Dict

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@@ -21,6 +26,12 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+    hub,
+)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -803,25 +814,88 @@ class FlashCausalLM(Model):
     def __init__(
         self,
         model_id: str,
-        model: torch.nn.Module,
-        tokenizer: PreTrainedTokenizerBase,
-        num_layers: int,
-        num_kv_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        rank: int = 0,
-        world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
+        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        default_dtype=torch.float16,
+        # self,
+        # model_id: str,
+        # model_class,
+        # tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        # num_layers: int,
+        # num_kv_heads: int,
+        # head_size: int,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # rank: int = 0,
+        # world_size: int = 1,
+        # sliding_window: Optional[int] = None,
     ):
-        self.num_layers = num_layers
-        self.num_kv_heads = num_kv_heads
-        self.head_size = head_size
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            raise NotImplementedError(f"{model_class} is only available on GPU")
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
+
+        config = AutoConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
+            weights._set_gptq_params(model_id, revision)
+
+        prefix = ""
+        model = model_class(prefix, config, weights)
+        torch.distributed.barrier(group=self.process_group)
+
+        self.num_layers = config.num_hidden_layers
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_size = config.hidden_size // config.num_attention_heads

         self.cuda_graphs = {}
         self.kv_cache = []

-        super(FlashCausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -830,7 +904,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=sliding_window,
+            sliding_window=config.sliding_window,
         )

     @property

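Since the refactored __init__ now derives the attention geometry from the Hugging Face config instead of taking num_layers, num_kv_heads and head_size as arguments, a quick worked example with hypothetical Llama-7B-style values (not taken from this diff) shows what those attributes end up holding:

# Hypothetical config values, for illustration only.
num_hidden_layers = 32      # becomes self.num_layers
num_key_value_heads = 32    # becomes self.num_kv_heads
hidden_size = 4096
num_attention_heads = 32

head_size = hidden_size // num_attention_heads
assert head_size == 128     # becomes self.head_size
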
server/text_generation_server/models/flash_mistral.py

@@ -1,24 +1,7 @@
 import torch
-import torch.distributed

-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Dict, List

 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import set_sliding_window
-from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-    FlashMistralForCausalLM,
-    MistralConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)

 ADAPTER_LAYERS = [
@@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
 ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}

-class BaseFlashMistral(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
-        model_id: str,
-        config_cls=AutoConfig,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        tokenizer_class=AutoTokenizer,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashMistral is only available on GPU")
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = config_cls.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        # Set context windows
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = model_cls(prefix, config, weights)
-
-        self.cuda_graphs = {}
-
-        torch.distributed.barrier(group=self.process_group)
-        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
-        super().__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-            sliding_window=config.sliding_window,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.model.layers),
-            model.model.num_key_value_heads,
-            model.model.head_size,
-        )
+class FlashMistral(FlashCausalLM):

     @property
     def supports_adapter_loading(self) -> bool:
         return True
@@ -183,25 +85,3 @@ class BaseFlashMistral(FlashCausalLM):
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
-
-
-class FlashMistral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMistral, self).__init__(
-            config_cls=MistralConfig,
-            model_cls=FlashMistralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
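
Putting the surviving hunks together, the post-refactor flash_mistral.py reduces to roughly the sketch below. The adapter-layer mapping code between the shown hunks is unchanged and omitted here, and all weight loading now happens in FlashCausalLM.__init__; this is a reconstruction from the visible context lines, not a verbatim copy of the file.

# Reconstructed sketch based only on the context lines visible above.
from text_generation_server.models import FlashCausalLM

ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashMistral(FlashCausalLM):
    @property
    def supports_adapter_loading(self) -> bool:
        return True

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL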