mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: adjust lints and ignore specific rules
This commit is contained in:
parent
72c97676fd
commit
a10f4010d7
@ -2,10 +2,6 @@ import os
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
|
@ -69,12 +69,13 @@ class TensorParallelHead(SuperLayer):
|
||||
|
||||
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
|
||||
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
|
||||
pass
|
||||
# Local variable `quantize` is assigned to but never used
|
||||
quantize = None # noqa F841
|
||||
# See above, exl2 LM head can be quantized or not.
|
||||
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
|
||||
pass
|
||||
quantize = None # noqa F841
|
||||
else:
|
||||
pass
|
||||
quantize = config.quantize # noqa F841
|
||||
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None),
|
||||
|
@ -1,3 +1,6 @@
|
||||
# ruff: noqa: F821
|
||||
# the above line disables the `undefined-name` rule for the model type variables
|
||||
|
||||
import torch
|
||||
import enum
|
||||
import os
|
||||
@ -298,12 +301,10 @@ class ModelType(enum.Enum):
|
||||
"multimodal": True,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, model_type: str) -> "ModelType":
|
||||
for model in cls:
|
||||
if model.value["type"] == model_type:
|
||||
return model
|
||||
raise ValueError(f"Unknown model type {model_type}")
|
||||
|
||||
__GLOBALS = locals()
|
||||
for data in ModelType:
|
||||
__GLOBALS[data.name] = data.value["type"]
|
||||
|
||||
|
||||
def get_model(
|
||||
@ -492,10 +493,7 @@ def get_model(
|
||||
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
||||
)
|
||||
|
||||
# convert model_type to ModelType enum
|
||||
model_type = ModelType.from_str(model_type)
|
||||
|
||||
if model_type == ModelType.DEEPSEEK_V2:
|
||||
if model_type == DEEPSEEK_V2:
|
||||
if FLASH_ATTENTION:
|
||||
head_size = max(
|
||||
config_dict.get("qk_nope_dim", 128)
|
||||
@ -528,7 +526,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == ModelType.MAMBA:
|
||||
elif model_type == MAMBA:
|
||||
return Mamba(
|
||||
model_id,
|
||||
revision,
|
||||
@ -552,8 +550,8 @@ def get_model(
|
||||
)
|
||||
|
||||
if (
|
||||
model_type == ModelType.GPT_BIGCODE
|
||||
or model_type == ModelType.GPT2
|
||||
model_type == GPT_BIGCODE
|
||||
or model_type == GPT2
|
||||
and model_id.startswith("bigcode/")
|
||||
):
|
||||
if FLASH_ATTENTION:
|
||||
@ -583,7 +581,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.BLOOM:
|
||||
if model_type == BLOOM:
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=BloomForCausalLM,
|
||||
@ -594,7 +592,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
batch_class=BloomCausalLMBatch,
|
||||
)
|
||||
elif model_type == ModelType.MPT:
|
||||
elif model_type == MPT:
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=MPTForCausalLM,
|
||||
@ -605,7 +603,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
batch_class=CausalLMBatchKeysLast,
|
||||
)
|
||||
elif model_type == ModelType.GPT2:
|
||||
elif model_type == GPT2:
|
||||
if FLASH_ATTENTION:
|
||||
try:
|
||||
return FlashCausalLM(
|
||||
@ -640,7 +638,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == ModelType.GPT_NEOX:
|
||||
elif model_type == GPT_NEOX:
|
||||
if FLASH_ATTENTION:
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
GPTNeoXConfig,
|
||||
@ -677,7 +675,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == ModelType.PHI:
|
||||
elif model_type == PHI:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -716,11 +714,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif (
|
||||
model_type == ModelType.LLAMA
|
||||
or model_type == ModelType.BAICHUAN
|
||||
or model_type == ModelType.PHI3
|
||||
):
|
||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||
print(f">>> model_type: {model_type}")
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
@ -744,7 +738,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if model_type == ModelType.GEMMA:
|
||||
if model_type == GEMMA:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -769,7 +763,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == ModelType.GEMMA2:
|
||||
elif model_type == GEMMA2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -795,7 +789,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.COHERE:
|
||||
if model_type == COHERE:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -819,7 +813,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.DBRX:
|
||||
if model_type == DBRX:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -846,7 +840,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type in ["RefinedWeb", "RefinedWebModel", ModelType.FALCON]:
|
||||
if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
|
||||
if sharded:
|
||||
if FLASH_ATTENTION:
|
||||
if config_dict.get("alibi", False):
|
||||
@ -894,7 +888,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.MISTRAL:
|
||||
if model_type == MISTRAL:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -918,7 +912,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.MIXTRAL:
|
||||
if model_type == MIXTRAL:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -942,7 +936,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.STARCODER2:
|
||||
if model_type == STARCODER2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -968,7 +962,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.QWEN2:
|
||||
if model_type == QWEN2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -992,7 +986,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.OPT:
|
||||
if model_type == OPT:
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=OPTForCausalLM,
|
||||
@ -1003,7 +997,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == ModelType.T5:
|
||||
if model_type == T5:
|
||||
return Seq2SeqLM(
|
||||
model_id=model_id,
|
||||
model_class=T5ForConditionalGeneration,
|
||||
@ -1019,7 +1013,7 @@ def get_model(
|
||||
]
|
||||
},
|
||||
)
|
||||
if model_type == ModelType.IDEFICS:
|
||||
if model_type == IDEFICS:
|
||||
if FLASH_ATTENTION:
|
||||
return IDEFICSSharded(
|
||||
model_id,
|
||||
@ -1031,7 +1025,7 @@ def get_model(
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == ModelType.IDEFICS2:
|
||||
if model_type == IDEFICS2:
|
||||
if FLASH_ATTENTION:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
@ -1048,7 +1042,7 @@ def get_model(
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == ModelType.PALIGEMMA:
|
||||
if model_type == PALIGEMMA:
|
||||
if FLASH_ATTENTION:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
@ -1066,7 +1060,7 @@ def get_model(
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == ModelType.LLAVA_NEXT:
|
||||
if model_type == LLAVA_NEXT:
|
||||
if FLASH_ATTENTION:
|
||||
return VlmCausalLM(
|
||||
model_class=LlavaNextForConditionalGeneration,
|
||||
|
@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
|
||||
]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if isinstance(self.past_key_values[0], tuple):
|
||||
if type(self.past_key_values[0]) is tuple:
|
||||
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
||||
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
|
@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
image_hidden_states = self.image_hidden_states[keep_indices]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if isinstance(self.past_key_values[0], tuple):
|
||||
if type(self.past_key_values[0]) is tuple:
|
||||
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
||||
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
|
@ -253,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if isinstance(self.past_key_values[0], tuple):
|
||||
if type(self.past_key_values[0]) is tuple:
|
||||
self.past_key_values = [
|
||||
[t for t in layer] for layer in self.past_key_values
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user