From 05c13c89defeb5c67114f1ad567643f71b29d4d1 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Tue, 30 Jul 2024 10:05:38 +0000
Subject: [PATCH] Remove useless modification

Signed-off-by: yuanwu
---
 .../models/flash_causal_lm.py | 117 +++---------
 1 file changed, 15 insertions(+), 102 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 72ceca6b..86d9b4c8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -10,12 +10,7 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import (
-    PreTrainedTokenizerBase,
-    AutoConfig,
-    AutoTokenizer,
-    GenerationConfig,
-)
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -24,11 +19,6 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -696,97 +686,20 @@ class FlashCausalLMBatch(Batch):
 class FlashCausalLM(Model):
     def __init__(
         self,
-        model_id: str,
-        model_class,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        lora_adapter_ids: Optional[list] = [],
-        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
-        config_class: PreTrainedTokenizerBase = AutoConfig,
-        default_dtype=torch.bfloat16,
-        aliases=None,
-        # Used for Santacoder override of config
-        num_kv_heads: Optional[int] = None,
-        # Deepseek V2 uses different QK and V dims.
-        head_size: Optional[int] = None,
-        skip_special_tokens: bool = True,
+        model: torch.nn.Module,
+        tokenizer: PreTrainedTokenizerBase,
+        num_layers: int,
+        num_kv_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        rank: int = 0,
+        world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
-
-        # Create model
-        world_size = int(os.getenv("WORLD_SIZE", "1"))
-        rank = int(os.getenv("RANK", "0"))
-        dtype = torch.bfloat16 if dtype is None else dtype
-        device = torch.device("hpu")
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        try:
-            generation_config = GenerationConfig.from_pretrained(
-                model_id, revision=revision, trust_remote_code=trust_remote_code
-            )
-            if isinstance(generation_config.eos_token_id, (list, set)):
-                # TODO Huge hack
-                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
-        except Exception:
-            pass
-
-        config = config_class.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype)
-
-
-
-        prefix = ""
-        model = model_class(prefix, config, weights)
-
-        # VLM models define the config we care about in their text_config
-        text_config = getattr(config, "text_config", None)
-        if text_config is not None:
-            config = text_config
-
-
-        self.num_layers = config.num_hidden_layers
-        # Validation is done in the model itself
-        if num_kv_heads is None:
-            num_kv_heads = getattr(config, "num_key_value_heads", None)
-            # GPT-2 workaround
-            if num_kv_heads is None:
-                num_kv_heads = getattr(config, "n_head", None)
-        if num_kv_heads is None:
-            raise ValueError("Cannot get the number of key/value heads")
-        self.num_kv_heads = (
-            num_kv_heads // self.process_group.size()
-            if num_kv_heads > 1
-            else num_kv_heads
-        )
-        assert self.num_kv_heads > 0
-
-        if head_size is None:
-            # Some models use GQA and different sizes for o_proj
-            # and q_proj, that allows for that.
-            if hasattr(config, "head_dim"):
-                self.head_size = config.head_dim
-            else:
-                self.head_size = config.hidden_size // config.num_attention_heads
-        else:
-            self.head_size = head_size
-
-        self.cuda_graphs = {}
-        self.kv_cache = []
+        self.num_layers = num_layers
+        self.num_kv_heads = num_kv_heads
+        self.head_size = head_size
 
         self.cuda_graphs = {}
 
@@ -798,7 +711,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=None,
+            sliding_window=sliding_window,
         )
 
     @property
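
Note for reviewers (not part of the patch): with this change FlashCausalLM no longer loads anything itself; the backend builds the model, tokenizer, and shape metadata and injects them. A minimal caller-side sketch of the new call pattern follows. The model id, the bare nn.Module, and the getattr fallbacks are illustrative placeholders, not code introduced here.

import os

import torch
from transformers import AutoConfig, AutoTokenizer

from text_generation_server.models.flash_causal_lm import FlashCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model id

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side="left", truncation_side="left"
)

# This fork targets Habana Gaudi; fall back to CPU when no HPU stack is loaded.
device = torch.device("hpu" if hasattr(torch, "hpu") else "cpu")
dtype = torch.bfloat16

# Placeholder: real code builds the per-architecture flash-attention model
# from the safetensors weights here (e.g. FlashLlamaForCausalLM).
model = torch.nn.Module()

llm = FlashCausalLM(
    model=model,
    tokenizer=tokenizer,
    num_layers=config.num_hidden_layers,
    # Fall back to MHA when the config has no separate KV-head count.
    num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
    head_size=config.hidden_size // config.num_attention_heads,
    dtype=dtype,
    device=device,
    rank=int(os.getenv("RANK", "0")),
    world_size=int(os.getenv("WORLD_SIZE", "1")),
    sliding_window=getattr(config, "sliding_window", None),
)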