fix: adjust lints and ignore specific rules

drbh 2024-07-25 14:50:18 +00:00
parent 72c97676fd
commit a10f4010d7
6 changed files with 39 additions and 48 deletions

View File

@@ -2,10 +2,6 @@ import os
 import math
 import torch
 from torch import nn
-# Inverse dim formula to find dim based on number of rotations
 from text_generation_server.utils.import_utils import SYSTEM
 if SYSTEM == "cuda":

View File

@@ -69,12 +69,13 @@ class TensorParallelHead(SuperLayer):
         # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
         if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
-            pass
+            # Local variable `quantize` is assigned to but never used
+            quantize = None # noqa F841
         # See above, exl2 LM head can be quantized or not.
         elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
-            pass
+            quantize = None # noqa F841
         else:
-            pass
+            quantize = config.quantize # noqa F841
         return TensorParallelHead(
             get_linear(weight, bias=None),
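The hunk above keeps deliberately unused `quantize` assignments and silences Ruff/Flake8 rule F841 on those lines only. A minimal, self-contained sketch of that per-line suppression pattern (the function and names below are illustrative, not from the TGI source):

    def describe(value: int) -> str:
        # The flag below is assigned and never read (kept as a breadcrumb);
        # the trailing comment silences only F841, all other checks still apply.
        is_negative = value < 0  # noqa: F841
        return f"value={value}"


    print(describe(-3))  # prints "value=-3"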

View File

@@ -1,3 +1,6 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
 import torch
 import enum
 import os
@@ -298,12 +301,10 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
-    @classmethod
-    def from_str(cls, model_type: str) -> "ModelType":
-        for model in cls:
-            if model.value["type"] == model_type:
-                return model
-        raise ValueError(f"Unknown model type {model_type}")
+__GLOBALS = locals()
+for data in ModelType:
+    __GLOBALS[data.name] = data.value["type"]
 def get_model(
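For context: the hunk above swaps the `ModelType.from_str` lookup back to injecting each member's "type" string as a module-level name via `locals()`, which is why the file now carries the `# ruff: noqa: F821` header; the injected names only exist at runtime, so the linter would otherwise flag every use as an undefined name. A minimal, standalone sketch of the pattern (the two enum members are examples, not the full TGI list):

    import enum


    class ModelType(enum.Enum):
        LLAMA = {"type": "llama", "name": "Llama"}
        GPT2 = {"type": "gpt2", "name": "GPT-2"}


    # At module scope, locals() is the module's global namespace, so writing
    # into it creates module-level constants: LLAMA == "llama", GPT2 == "gpt2".
    __GLOBALS = locals()
    for data in ModelType:
        __GLOBALS[data.name] = data.value["type"]

    # Dispatch code can then compare the config's raw string directly.
    model_type = "gpt2"
    # The injected name is invisible to static analysis, hence the noqa.
    if model_type == GPT2:  # noqa: F821
        print("dispatching to the GPT-2 branch")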
@@ -492,10 +493,7 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
-    # convert model_type to ModelType enum
-    model_type = ModelType.from_str(model_type)
-    if model_type == ModelType.DEEPSEEK_V2:
+    if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
@@ -528,7 +526,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.MAMBA:
+    elif model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
@@ -552,8 +550,8 @@
         )
     if (
-        model_type == ModelType.GPT_BIGCODE
-        or model_type == ModelType.GPT2
+        model_type == GPT_BIGCODE
+        or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
@@ -583,7 +581,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.BLOOM:
+    if model_type == BLOOM:
         return CausalLM(
             model_id=model_id,
             model_class=BloomForCausalLM,
@@ -594,7 +592,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=BloomCausalLMBatch,
         )
-    elif model_type == ModelType.MPT:
+    elif model_type == MPT:
         return CausalLM(
             model_id=model_id,
             model_class=MPTForCausalLM,
@@ -605,7 +603,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=CausalLMBatchKeysLast,
         )
-    elif model_type == ModelType.GPT2:
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
@@ -640,7 +638,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GPT_NEOX:
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
@@ -677,7 +675,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.PHI:
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -716,11 +714,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    elif (
-        model_type == ModelType.LLAMA
-        or model_type == ModelType.BAICHUAN
-        or model_type == ModelType.PHI3
-    ):
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -744,7 +738,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.GEMMA:
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -769,7 +763,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GEMMA2:
+    elif model_type == GEMMA2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -795,7 +789,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.COHERE:
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -819,7 +813,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.DBRX:
+    if model_type == DBRX:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -846,7 +840,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type in ["RefinedWeb", "RefinedWebModel", ModelType.FALCON]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
@@ -894,7 +888,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.MISTRAL:
+    if model_type == MISTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -918,7 +912,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.MIXTRAL:
+    if model_type == MIXTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -942,7 +936,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.STARCODER2:
+    if model_type == STARCODER2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -968,7 +962,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.QWEN2:
+    if model_type == QWEN2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -992,7 +986,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.OPT:
+    if model_type == OPT:
         return CausalLM(
             model_id=model_id,
             model_class=OPTForCausalLM,
@@ -1003,7 +997,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.T5:
+    if model_type == T5:
         return Seq2SeqLM(
             model_id=model_id,
             model_class=T5ForConditionalGeneration,
@@ -1019,7 +1013,7 @@ def get_model(
                 ]
             },
         )
-    if model_type == ModelType.IDEFICS:
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -1031,7 +1025,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.IDEFICS2:
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1048,7 +1042,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.PALIGEMMA:
+    if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1066,7 +1060,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.LLAVA_NEXT:
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,

View File

@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
         ]
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]
         # Update tensors in-place to allow incremental garbage collection

View File

@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
         image_hidden_states = self.image_hidden_states[keep_indices]
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]
         # Update tensors in-place to allow incremental garbage collection

View File

@@ -253,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
         ]
         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [
                 [t for t in layer] for layer in self.past_key_values
             ]
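The three hunks above (CausalLMBatch, IdeficsCausalLMBatch, Seq2SeqLMBatch) all restore an exact-type check on past_key_values. One behavioral difference worth noting: `type(x) is tuple` matches only plain tuples, while `isinstance` also accepts tuple subclasses such as named tuples. A small illustration of the distinction, using hypothetical names rather than TGI code:

    from collections import namedtuple

    # A tuple subclass standing in for any structured cache entry.
    LayerCache = namedtuple("LayerCache", ["key", "value"])

    plain = ("k", "v")
    structured = LayerCache(key="k", value="v")

    print(type(plain) is tuple)           # True  -> would be converted to a list
    print(type(structured) is tuple)      # False -> left alone by the exact check
    print(isinstance(structured, tuple))  # True  -> isinstance would also match it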