Update style.

This commit is contained in:
Nicolas Patry 2024-04-26 14:11:33 +00:00
parent 6fc959765b
commit ab1ec3e27e
6 changed files with 16 additions and 3 deletions

View File

@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)
@dataclass
class FlashCausalLMBatch(Batch):
@@ -788,7 +793,9 @@ class FlashCausalLM(Model):
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    total_free_memory, _ = torch.cuda.mem_get_info(self.device)
    total_gpu_memory = torch.cuda.get_device_properties(
        self.device
    ).total_memory
    free_memory = max(
        0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory

View File

@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

class FlashLlama(FlashCausalLM):
    def __init__(
        self,

View File

@@ -15,6 +15,7 @@ from text_generation_server.utils import (
    Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

tracer = trace.get_tracer(__name__)

View File

@@ -16,6 +16,7 @@ from text_generation_server.utils import (
    Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

tracer = trace.get_tracer(__name__)

View File

@@ -19,6 +19,7 @@ from text_generation_server.utils import (
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

tracer = trace.get_tracer(__name__)

View File

@@ -1,5 +1,6 @@
import torch

def is_xpu_available():
    try:
        import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
        return hasattr(torch, "xpu") and torch.xpu.is_available()

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()