nicer diff x2

Author: fxmarty
Date:   2024-05-17 08:55:37 +00:00
Commit: 3ded96fb4c (parent 8d7f18f41e)

6 changed files with 2 additions and 8 deletions


@@ -10,7 +10,6 @@ from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,


@@ -318,7 +318,6 @@ class BaseFlashMistral(FlashCausalLM):
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
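This BaseFlashMistral hunk (and the GalacticaSharded and Mamba hunks further down) keep the same initialization pattern: initialize_torch_distributed() returns the process group plus this rank's index and world size, and the rank pins the model to its own CUDA device. A minimal sketch of that pattern, assuming a CPU fallback and float16/float32 defaults that are not shown in this diff:

import torch

def pick_device_and_dtype(rank: int):
    # Each tensor-parallel rank binds to its own GPU when CUDA is available.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}"), torch.float16
    # Assumed fallback: CPU with float32, since half precision on CPU is often unsupported.
    return torch.device("cpu"), torch.float32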


@@ -21,7 +21,6 @@ from text_generation_server.utils import (
Weights,
)
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA]
@@ -172,7 +171,6 @@ class GalacticaSharded(CausalLM):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")


@@ -22,6 +22,7 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models.vlm_causal_lm import split
import re
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
@@ -576,7 +577,6 @@ class IdeficsCausalLM(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
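The IMAGES pattern in the first IdeficsCausalLM hunk above matches Markdown image syntax, capturing the URL in group 1 and an optional quoted title in group 2. A small usage sketch (the sample prompt is illustrative, not from the repository):

import re

# Same pattern as in the hunk above: Markdown image syntax ![alt](url "title").
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

prompt = "Look at this: ![a cat](https://example.com/cat.png) and describe it."
# findall returns (url, title) tuples; the title group is empty when absent.
print(IMAGES.findall(prompt))  # [('https://example.com/cat.png', '')]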


@@ -412,7 +412,6 @@ class Mamba(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, world_size = initialize_torch_distributed()
if world_size > 1:
raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
@@ -476,7 +475,7 @@ class Mamba(Model):
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS='{CUDA_GRAPHS}').")
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None
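The second Mamba hunk only rewords the log line, dropping the single quotes around the interpolated CUDA_GRAPHS value. For context, a rough sketch of the control flow that line sits in; the graph-capture helper and the CUDA_GRAPHS default below are placeholders, not TGI's actual API:

from loguru import logger

# Placeholder value; in TGI this comes from configuration / environment.
CUDA_GRAPHS = []

def warmup_decode_graphs(model, batch):
    if CUDA_GRAPHS:
        try:
            # Hypothetical helper standing in for the real graph-capture code.
            model.capture_decode_graph(batch)
        except Exception:
            # logger.exception records the traceback along with the message.
            logger.exception("Decode cuda graph warmup failed")
    else:
        logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")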


@@ -17,7 +17,6 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
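The tracer = trace.get_tracer(__name__) context line in this last hunk is the standard OpenTelemetry idiom for a module-level tracer. A minimal sketch of how such a tracer is typically used (the span name and function below are illustrative, not taken from this diff):

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

def generate_token(batch):
    # Work inside the span shows up as a timed trace in the configured backend.
    with tracer.start_as_current_span("generate_token"):
        ...  # model forward pass and token selection would happen here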