fix: adjust lints and ignore specific rules

drbh 2024-07-25 14:50:18 +00:00
parent 72c97676fd
commit a10f4010d7
6 changed files with 39 additions and 48 deletions

View File

@@ -2,10 +2,6 @@ import os
import math
import torch
from torch import nn
# Inverse dim formula to find dim based on number of rotations
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":

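For context on the "Inverse dim formula" comment retained above: it refers to the helper used by NTK/YaRN-style rotary scaling, which inverts the RoPE wavelength to find the dimension pair that completes a given number of rotations over the context window. A minimal sketch of that formula, assuming the standard YaRN-style helper rather than the exact code in this file:

import math

def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Solves base**(2*d/dim) == max_position_embeddings / (2*pi*num_rotations) for d,
    # i.e. which rotary dimension pair rotates num_rotations times over the context.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )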
View File

@@ -69,12 +69,13 @@ class TensorParallelHead(SuperLayer):
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
pass
# Local variable `quantize` is assigned to but never used
quantize = None # noqa F841
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
pass
quantize = None # noqa F841
else:
pass
quantize = config.quantize # noqa F841
return TensorParallelHead(
get_linear(weight, bias=None),

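For readers unfamiliar with the directives added above: a trailing noqa comment suppresses a specific lint rule on that single line (here F841, "local variable is assigned to but never used"), as opposed to the file-wide form used in the next file. A minimal illustration of the line-level form, not code from this repository; note that ruff's documented spelling uses a colon, as in `# noqa: F841`:

def example():
    # Without the trailing directive, ruff/flake8 would flag this line as F841.
    unused = 42  # noqa: F841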
View File

@@ -1,3 +1,6 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import enum
import os
@@ -298,12 +301,10 @@ class ModelType(enum.Enum):
"multimodal": True,
}
@classmethod
def from_str(cls, model_type: str) -> "ModelType":
for model in cls:
if model.value["type"] == model_type:
return model
raise ValueError(f"Unknown model type {model_type}")
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
def get_model(
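The file-wide `# ruff: noqa: F821` exists because the comparisons later in get_model use names such as MISTRAL or GPT2 that are created at import time by the `__GLOBALS = locals()` loop above, so a static linter cannot see them. A minimal sketch of that pattern, simplified for illustration:

import enum

class ModelType(enum.Enum):
    MISTRAL = {"type": "mistral", "name": "Mistral"}

__GLOBALS = locals()
for data in ModelType:
    # Exposes each member as a module-level string constant, e.g. MISTRAL == "mistral".
    __GLOBALS[data.name] = data.value["type"]

assert MISTRAL == "mistral"  # noqa: F821  (undefined to static analysis, defined at runtime)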
@@ -492,10 +493,7 @@ def get_model(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
# convert model_type to ModelType enum
model_type = ModelType.from_str(model_type)
if model_type == ModelType.DEEPSEEK_V2:
if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
head_size = max(
config_dict.get("qk_nope_dim", 128)
@@ -528,7 +526,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == ModelType.MAMBA:
elif model_type == MAMBA:
return Mamba(
model_id,
revision,
@@ -552,8 +550,8 @@ def get_model(
)
if (
model_type == ModelType.GPT_BIGCODE
or model_type == ModelType.GPT2
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION:
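A reading aid for the combined condition above: Python's `and` binds tighter than `or`, so it selects GPT_BIGCODE unconditionally, or GPT2 only when the model id starts with "bigcode/". Illustration only, not code from this diff:

a, b, c = True, True, False
# "a or b and c" parses as "a or (b and c)": with these values it is True,
# whereas "(a or b) and c" would be False.
assert (a or b and c) is True
assert ((a or b) and c) is False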
@@ -583,7 +581,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.BLOOM:
if model_type == BLOOM:
return CausalLM(
model_id=model_id,
model_class=BloomForCausalLM,
@@ -594,7 +592,7 @@ def get_model(
trust_remote_code=trust_remote_code,
batch_class=BloomCausalLMBatch,
)
elif model_type == ModelType.MPT:
elif model_type == MPT:
return CausalLM(
model_id=model_id,
model_class=MPTForCausalLM,
@@ -605,7 +603,7 @@ def get_model(
trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
)
elif model_type == ModelType.GPT2:
elif model_type == GPT2:
if FLASH_ATTENTION:
try:
return FlashCausalLM(
@@ -640,7 +638,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == ModelType.GPT_NEOX:
elif model_type == GPT_NEOX:
if FLASH_ATTENTION:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
@@ -677,7 +675,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == ModelType.PHI:
elif model_type == PHI:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -716,11 +714,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif (
model_type == ModelType.LLAMA
or model_type == ModelType.BAICHUAN
or model_type == ModelType.PHI3
):
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
@@ -744,7 +738,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.GEMMA:
if model_type == GEMMA:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -769,7 +763,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == ModelType.GEMMA2:
elif model_type == GEMMA2:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -795,7 +789,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.COHERE:
if model_type == COHERE:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -819,7 +813,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.DBRX:
if model_type == DBRX:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -846,7 +840,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type in ["RefinedWeb", "RefinedWebModel", ModelType.FALCON]:
if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
if sharded:
if FLASH_ATTENTION:
if config_dict.get("alibi", False):
@@ -894,7 +888,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.MISTRAL:
if model_type == MISTRAL:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -918,7 +912,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.MIXTRAL:
if model_type == MIXTRAL:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -942,7 +936,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.STARCODER2:
if model_type == STARCODER2:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -968,7 +962,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.QWEN2:
if model_type == QWEN2:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -992,7 +986,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.OPT:
if model_type == OPT:
return CausalLM(
model_id=model_id,
model_class=OPTForCausalLM,
@@ -1003,7 +997,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == ModelType.T5:
if model_type == T5:
return Seq2SeqLM(
model_id=model_id,
model_class=T5ForConditionalGeneration,
@@ -1019,7 +1013,7 @@ def get_model(
]
},
)
if model_type == ModelType.IDEFICS:
if model_type == IDEFICS:
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
@@ -1031,7 +1025,7 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == ModelType.IDEFICS2:
if model_type == IDEFICS2:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
@@ -1048,7 +1042,7 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == ModelType.PALIGEMMA:
if model_type == PALIGEMMA:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
@@ -1066,7 +1060,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == ModelType.LLAVA_NEXT:
if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,

View File

@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
]
# Ensure that past_key_values tensors can be updated in-place
if isinstance(self.past_key_values[0], tuple):
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection

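The change above (mirrored in the two files that follow) guards a conversion of each past_key_values layer from a tuple to a list, since tuples are immutable and their per-layer tensors could not otherwise be reassigned in place; the strict `type(...) is tuple` check also excludes tuple subclasses such as named tuples. A minimal sketch with made-up shapes, illustration only:

import torch

past_key_values = [(torch.zeros(2, 4), torch.zeros(2, 4))]  # one layer: (key, value)
if type(past_key_values[0]) is tuple:
    # Lists allow the filtered tensors to be written back slot by slot.
    past_key_values = [list(layer) for layer in past_key_values]
past_key_values[0][0] = past_key_values[0][0][:1]  # in-place slot update now permitted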
View File

@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if isinstance(self.past_key_values[0], tuple):
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection

View File

@@ -253,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
]
# Ensure that past_key_values tensors can be updated in-place
if isinstance(self.past_key_values[0], tuple):
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [
[t for t in layer] for layer in self.past_key_values
]