Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Finish removal.
Commit e2edf2beb2 (parent f5ff9b5742)
@@ -97,7 +97,6 @@ try:
         LlavaNextForConditionalGeneration,
     )
 
-    # from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
         FlashSantacoderForCausalLM,
     )
@@ -128,7 +127,6 @@ except ImportError as e:
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
     __all__.append(IDEFICSSharded)
-    # __all__.append(FlashMistral)
 
 MAMBA_AVAILABLE = True
 try:
@@ -838,7 +836,7 @@ def get_model(
 
     if model_type == MIXTRAL:
         if FLASH_ATTENTION:
-            return FlashMistral(
+            return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
                 revision=revision,
@@ -862,7 +860,7 @@ def get_model(
 
     if model_type == STARCODER2:
         if FLASH_ATTENTION:
-            return FlashMistral(
+            return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
                 revision=revision,
@@ -888,7 +886,7 @@ def get_model(
 
     if model_type == QWEN2:
         if FLASH_ATTENTION:
-            return FlashMistral(
+            return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
                 revision=revision,
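The three get_model hunks above (MIXTRAL, STARCODER2, QWEN2) apply the same change: the dedicated FlashMistral wrapper is gone and the generic FlashCausalLM is constructed directly, with the architecture-specific module passed in as model_class. One way to read the three branches together is as a table-driven dispatch; the sketch below is illustrative only (the mapping dict and the trailing keyword arguments are assumptions, and the surrounding names such as FLASH_ATTENTION, model_id and revision come from the enclosing get_model code):

    # Sketch, not the actual layout of the file: condenses the three hunks
    # above into one lookup. Only the model_class-based FlashCausalLM call
    # is taken from this diff; everything else is assumed.
    FLASH_MODEL_CLASSES = {
        MIXTRAL: FlashMixtralForCausalLM,
        STARCODER2: FlashStarcoder2ForCausalLM,
        QWEN2: Qwen2ForCausalLM,
    }

    if model_type in FLASH_MODEL_CLASSES and FLASH_ATTENTION:
        return FlashCausalLM(
            model_id=model_id,
            model_class=FLASH_MODEL_CLASSES[model_type],
            revision=revision,
            # Remaining keyword arguments (quantize, dtype, trust_remote_code,
            # ...) are assumed to be forwarded unchanged from get_model.
        )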
@@ -9,9 +9,9 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.flash_mistral import (
-    FlashMistral,
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
 )
 from transformers import AutoProcessor
 
@@ -240,7 +240,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
 
-class VlmCausalLM(FlashMistral):
+class VlmCausalLM(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
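With the import and base-class changes above, the vision-language path no longer touches flash_mistral at all. A rough sketch of the resulting class layout, assuming the constructor simply forwards its arguments to the base class (only the imports and the FlashCausalLM base come from this diff; the constructor body and **kwargs pass-through are illustrative):

    from text_generation_server.models.flash_causal_lm import (
        FlashCausalLMBatch,
        FlashCausalLM,
    )

    class VlmCausalLMBatch(FlashCausalLMBatch):
        # Batch construction for image+text inputs; unchanged by this commit.
        ...

    class VlmCausalLM(FlashCausalLM):
        # Vision-language models now build directly on the generic
        # flash-attention causal LM instead of the removed FlashMistral wrapper.
        def __init__(self, model_id: str, **kwargs):
            super().__init__(model_id=model_id, **kwargs)  # assumed pass-through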