Finish removal of FlashMistral: replace remaining FlashMistral usages with FlashCausalLM.

This commit is contained in:
Nicolas Patry 2024-07-03 15:19:06 +00:00
parent f5ff9b5742
commit e2edf2beb2
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
2 changed files with 7 additions and 9 deletions

View File

@ -97,7 +97,6 @@ try:
LlavaNextForConditionalGeneration,
)
# from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
@ -128,7 +127,6 @@ except ImportError as e:
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(IDEFICSSharded)
# __all__.append(FlashMistral)
MAMBA_AVAILABLE = True
try:
@ -838,7 +836,7 @@ def get_model(
if model_type == MIXTRAL:
if FLASH_ATTENTION:
return FlashMistral(
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
@ -862,7 +860,7 @@ def get_model(
if model_type == STARCODER2:
if FLASH_ATTENTION:
return FlashMistral(
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
@ -888,7 +886,7 @@ def get_model(
if model_type == QWEN2:
if FLASH_ATTENTION:
return FlashMistral(
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,

View File

@ -9,9 +9,9 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import (
FlashMistral,
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from transformers import AutoProcessor
@ -240,7 +240,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch
class VlmCausalLM(FlashMistral):
class VlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,