Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00.
commit ab1ec3e27e (parent 6fc959765b)

Update style.

This commit only reformats the recently added XPU-support code: one long import and one long attribute access are wrapped across lines, and blank lines are added around top-level definitions, consistent with black-style formatting.
@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
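The three flags imported above are plain module-level booleans (their definitions appear in the import_utils.py hunks at the end of this diff). As a minimal sketch of how such flags are typically consumed, here is a hypothetical helper that is not part of this commit, with local stand-ins so it runs on its own:

import torch

# Stand-ins for the flags imported above; in TGI they come from
# text_generation_server.utils.import_utils (see the last two hunks).
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()


def default_device() -> torch.device:
    # Hypothetical helper, not part of this commit. ROCm builds of
    # PyTorch expose HIP through the torch.cuda namespace, so CUDA
    # and ROCm share a branch.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        return torch.device("cuda")
    if IS_XPU_SYSTEM:
        return torch.device("xpu")
    return torch.device("cpu")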
@@ -788,7 +793,9 @@ class FlashCausalLM(Model):
 
         if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
 
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
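The wrapped expression budgets how much of the card the server may still claim: MEMORY_FRACTION (imported from text_generation_server.utils.dist in the first hunk) is the fraction of total GPU memory the process is allowed to use, so (1 - MEMORY_FRACTION) of the total is held back from the free memory reported by the driver. A minimal standalone sketch of the same arithmetic, with MEMORY_FRACTION hard-coded here for illustration:

import torch

MEMORY_FRACTION = 0.9  # illustrative value; TGI imports this from its dist utils


def usable_free_memory(device: torch.device) -> int:
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes).
    total_free_memory, _ = torch.cuda.mem_get_info(device)
    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    # Hold back the share of the card the process must not touch,
    # and never report a negative budget.
    return max(0, int(total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory))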
@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 class FlashLlama(FlashCausalLM):
     def __init__(
         self,
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
 )
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
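The four hunks above only adjust blank lines around the IS_XPU_SYSTEM imports that the XPU change introduced into the individual model files. For context, a hedged sketch of the selection pattern such constructors typically follow (illustrative names and values, not text from this diff): prefer CUDA/ROCm, then XPU, and fall back to CPU.

import torch

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

# Illustrative device/dtype selection: CUDA/ROCm first, then XPU,
# then CPU with float32, since float16 compute is a GPU feature.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
elif IS_XPU_SYSTEM:
    device = torch.device("xpu")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32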
@@ -1,5 +1,6 @@
 import torch
 
+
 def is_xpu_available():
     try:
         import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
 
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
 IS_XPU_SYSTEM = is_xpu_available()
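The two hunks above span almost all of import_utils.py. Reconstructed from them, the module after this commit plausibly reads as follows; the except clause falls in the two lines elided between the hunks and is an assumption, marked below:

import torch


def is_xpu_available():
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except ImportError:  # assumed: elided between the two hunks above
        return False

    return hasattr(torch, "xpu") and torch.xpu.is_available()


IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()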