mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Fix tests.

parent 80c23bdd38
commit 5373fc4707
@@ -162,7 +162,7 @@ Options:
           This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.
 
           [env: WAITING_SERVED_RATIO=]
-          [default: 1.2]
+          [default: 0.3]
 
 ```
 ## MAX_BATCH_PREFILL_TOKENS
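The documentation hunk above tracks the launcher's new `waiting_served_ratio` default (0.3 instead of 1.2): roughly, the ratio of queued to in-flight requests above which the router pauses decoding to prefill the waiting queries. A minimal, purely illustrative Python sketch of that heuristic (the real logic lives in the Rust router; the function name here is hypothetical):

```python
# Illustrative sketch only: the actual scheduling happens in the Rust router,
# and `should_prefill_waiting` is a hypothetical helper, not a repo function.
def should_prefill_waiting(
    num_waiting: int, num_running: int, waiting_served_ratio: float = 0.3
) -> bool:
    if num_running == 0:
        # Nothing is being served, so any waiting request can be prefilled.
        return num_waiting > 0
    # Pause decoding to prefill only once the queue is large enough
    # relative to the requests already in flight.
    return num_waiting >= waiting_served_ratio * num_running


# 3 waiting vs 10 running clears the new 0.3 default but not the old 1.2 one.
print(should_prefill_waiting(3, 10))                            # True
print(should_prefill_waiting(3, 10, waiting_served_ratio=1.2))  # False
```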
@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -788,14 +793,16 @@ class FlashCausalLM(Model):
 
         if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
 
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
             )
         elif IS_XPU_SYSTEM:
             total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory *0.5)
+            free_memory = int(total_gpu_memory * 0.5)
         else:
             raise NotImplementedError("FlashModel is only available on GPU")
 
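The hunk above is formatting only: the `get_device_properties` call is re-wrapped and a missing space around `*` is added, while the memory accounting itself is unchanged. As a self-contained sketch of what it computes, assuming `MEMORY_FRACTION` is the share of the device the server may claim (in the real file it is imported from `text_generation_server.utils.dist`):

```python
import torch

# Assumed value for illustration; the server normally imports this from
# text_generation_server.utils.dist.
MEMORY_FRACTION = 0.9


def free_memory_for_warmup(device: torch.device) -> int:
    """Sketch of the headroom computation in FlashCausalLM above."""
    if torch.version.cuda is not None or torch.version.hip is not None:
        total_free_memory, _ = torch.cuda.mem_get_info(device)
        total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
        # Reserve (1 - MEMORY_FRACTION) of the card and never return a negative value.
        return int(max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory))
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # The XPU branch simply takes half of the device memory.
        total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
        return int(total_gpu_memory * 0.5)
    raise NotImplementedError("FlashModel is only available on GPU")
```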
@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 class FlashLlama(FlashCausalLM):
     def __init__(
         self,
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
 )
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -1,5 +1,6 @@
 import torch
 
+
 def is_xpu_available():
     try:
         import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
 
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
 IS_XPU_SYSTEM = is_xpu_available()
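The two `import_utils.py` hunks only insert the blank lines the formatter expects; the module still exposes the `IS_CUDA_SYSTEM`, `IS_ROCM_SYSTEM`, and `IS_XPU_SYSTEM` flags defined above. A small sketch of how downstream code can branch on them (the `resolve_device` helper is hypothetical, for illustration only):

```python
import torch

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)


def resolve_device(rank: int = 0) -> torch.device:
    # Hypothetical helper: the models in the repo do equivalent checks inline.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        # ROCm builds of PyTorch also expose the device as "cuda".
        return torch.device(f"cuda:{rank}")
    if IS_XPU_SYSTEM:
        return torch.device(f"xpu:{rank}")
    raise NotImplementedError("Only CUDA, ROCm and XPU devices are supported")
```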