Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)

fix: adjust lints and ignore specific rules

commit a10f4010d7 (parent 72c97676fd)
@@ -2,10 +2,6 @@ import os
 import math
 import torch
 from torch import nn
-
-# Inverse dim formula to find dim based on number of rotations
-
-
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "cuda":
@@ -69,12 +69,13 @@ class TensorParallelHead(SuperLayer):
 
         # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
         if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
-            pass
+            # Local variable `quantize` is assigned to but never used
+            quantize = None  # noqa F841
         # See above, exl2 LM head can be quantized or not.
         elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
-            pass
+            quantize = None  # noqa F841
         else:
-            pass
+            quantize = config.quantize  # noqa F841
 
         return TensorParallelHead(
             get_linear(weight, bias=None),
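Note on the hunk above: F841 is the Pyflakes/ruff rule for "local variable is assigned to but never used". The change keeps the `quantize = ...` assignments (instead of the placeholder `pass`) and silences the rule per line with a trailing noqa comment. A minimal sketch of that suppression pattern, with illustrative names rather than the TGI source:

def choose_quantize(method: str) -> None:
    # Sketch only: the point is the per-line noqa suppression.
    if method in ["gptq", "awq", "eetq", "marlin"]:
        # These methods do not quantize the LM head, so the value is intentionally unused.
        quantize = None  # noqa: F841
    else:
        quantize = method  # noqa: F841
    # Nothing reads `quantize`; without the noqa comments a linter run
    # (e.g. `ruff check`) would report F841 on both assignments.


choose_quantize("gptq")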
@@ -1,3 +1,6 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
+
 import torch
 import enum
 import os
@@ -298,12 +301,10 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
 
-    @classmethod
-    def from_str(cls, model_type: str) -> "ModelType":
-        for model in cls:
-            if model.value["type"] == model_type:
-                return model
-        raise ValueError(f"Unknown model type {model_type}")
+
+__GLOBALS = locals()
+for data in ModelType:
+    __GLOBALS[data.name] = data.value["type"]
 
 
 def get_model(
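Note on the two hunks above: the file-level `# ruff: noqa: F821` directive disables ruff's undefined-name rule for the whole module because the model-type constants used later in `get_model` (DEEPSEEK_V2, MAMBA, GPT2, ...) are only created at import time by the `__GLOBALS = locals()` loop, which copies each enum member's "type" string into the module namespace; a static linter cannot see those bindings. A self-contained sketch of the pattern with a cut-down enum (illustrative members, not the full ModelType table):

# ruff: noqa: F821
import enum


class ModelType(enum.Enum):
    LLAMA = {"type": "llama", "name": "Llama"}
    MAMBA = {"type": "ssm", "name": "Mamba"}


# At module level locals() is the module namespace, so this binds
# LLAMA = "llama" and MAMBA = "ssm" as module-level names at import time.
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

# Later code can compare plain strings, as the diff now does with
# `model_type == DEEPSEEK_V2`; ruff would flag these names as F821
# without the directive above.
model_type = "ssm"
assert model_type == MAMBA
assert model_type != LLAMA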
@@ -492,10 +493,7 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
 
-    # convert model_type to ModelType enum
-    model_type = ModelType.from_str(model_type)
-
-    if model_type == ModelType.DEEPSEEK_V2:
+    if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
@@ -528,7 +526,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.MAMBA:
+    elif model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
@@ -552,8 +550,8 @@ def get_model(
         )
 
     if (
-        model_type == ModelType.GPT_BIGCODE
-        or model_type == ModelType.GPT2
+        model_type == GPT_BIGCODE
+        or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
@@ -583,7 +581,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.BLOOM:
+    if model_type == BLOOM:
         return CausalLM(
             model_id=model_id,
             model_class=BloomForCausalLM,
@@ -594,7 +592,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=BloomCausalLMBatch,
         )
-    elif model_type == ModelType.MPT:
+    elif model_type == MPT:
         return CausalLM(
             model_id=model_id,
             model_class=MPTForCausalLM,
@@ -605,7 +603,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=CausalLMBatchKeysLast,
         )
-    elif model_type == ModelType.GPT2:
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
@@ -640,7 +638,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GPT_NEOX:
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
@@ -677,7 +675,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == ModelType.PHI:
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -716,11 +714,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif (
-        model_type == ModelType.LLAMA
-        or model_type == ModelType.BAICHUAN
-        or model_type == ModelType.PHI3
-    ):
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -744,7 +738,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.GEMMA:
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -769,7 +763,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GEMMA2:
+    elif model_type == GEMMA2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -795,7 +789,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.COHERE:
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -819,7 +813,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.DBRX:
+    if model_type == DBRX:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -846,7 +840,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel", ModelType.FALCON]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
@@ -894,7 +888,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.MISTRAL:
+    if model_type == MISTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -918,7 +912,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.MIXTRAL:
+    if model_type == MIXTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -942,7 +936,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.STARCODER2:
+    if model_type == STARCODER2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -968,7 +962,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.QWEN2:
+    if model_type == QWEN2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -992,7 +986,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.OPT:
+    if model_type == OPT:
         return CausalLM(
             model_id=model_id,
             model_class=OPTForCausalLM,
@@ -1003,7 +997,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == ModelType.T5:
+    if model_type == T5:
         return Seq2SeqLM(
             model_id=model_id,
             model_class=T5ForConditionalGeneration,
@@ -1019,7 +1013,7 @@ def get_model(
                 ]
             },
         )
-    if model_type == ModelType.IDEFICS:
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -1031,7 +1025,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.IDEFICS2:
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1048,7 +1042,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.PALIGEMMA:
+    if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1066,7 +1060,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
-    if model_type == ModelType.LLAVA_NEXT:
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]
 
         # Update tensors in-place to allow incremental garbage collection
@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
         image_hidden_states = self.image_hidden_states[keep_indices]
 
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]
 
         # Update tensors in-place to allow incremental garbage collection
@@ -253,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [
                 [t for t in layer] for layer in self.past_key_values
             ]
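Note on the last three hunks: `isinstance(self.past_key_values[0], tuple)` and `type(self.past_key_values[0]) is tuple` are not equivalent. isinstance also accepts tuple subclasses (named tuples, tuple-derived cache containers), while the `type(...) is tuple` form matches only a plain tuple; writing the exact-type check with `is` also avoids the E721 "do not compare types" warning that a `type(x) == tuple` comparison triggers. A small sketch of the difference, using an invented cache type rather than anything from TGI:

from typing import NamedTuple


class LayerCache(NamedTuple):
    # Invented stand-in for a per-layer (key, value) cache entry.
    key: list
    value: list


plain = ([1.0], [2.0])              # plain tuple, like legacy past_key_values entries
wrapped = LayerCache([1.0], [2.0])  # tuple subclass

assert isinstance(plain, tuple) and isinstance(wrapped, tuple)  # isinstance matches both
assert type(plain) is tuple                                     # exact check: passes
assert type(wrapped) is not tuple                               # exact check: subclass excluded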