Fix tests.

Nicolas Patry 2024-04-27 11:33:24 +02:00
parent 80c23bdd38
commit 5373fc4707
7 changed files with 17 additions and 4 deletions


@@ -162,7 +162,7 @@ Options:
           This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.
 
           [env: WAITING_SERVED_RATIO=]
-          [default: 1.2]
+          [default: 0.3]
 
 ```
 ## MAX_BATCH_PREFILL_TOKENS
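For readers unfamiliar with the option documented above: the waiting-served ratio controls when the server folds waiting requests into a new prefill batch. A minimal sketch of what the new 0.3 default means, assuming only the semantics described in the launcher docs (the real scheduler lives in the Rust router; all names below are illustrative, not TGI's API):

```python
# Illustrative only: NOT the actual TGI router, which is written in Rust.
# Shows what a waiting-served ratio threshold of 0.3 means for batching.
def should_start_prefill(
    num_waiting: int,
    num_running: int,
    waiting_served_ratio: float = 0.3,  # new default documented above
) -> bool:
    """Consider pausing running requests to batch in waiting ones once
    the waiting/running ratio crosses the configured threshold."""
    if num_running == 0:
        return num_waiting > 0
    return num_waiting / num_running >= waiting_served_ratio
```

A lower default (0.3 vs. 1.2) makes the server start a new prefill batch with proportionally fewer queued requests.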


@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -788,7 +793,9 @@ class FlashCausalLM(Model):
 
         if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
 
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory

@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 class FlashLlama(FlashCausalLM):
     def __init__(
         self,


@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 
 tracer = trace.get_tracer(__name__)


@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 
 tracer = trace.get_tracer(__name__)


@@ -19,6 +19,7 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 
 tracer = trace.get_tracer(__name__)


@@ -1,5 +1,6 @@
 import torch
 
+
 def is_xpu_available():
     try:
         import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
 
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
 IS_XPU_SYSTEM = is_xpu_available()
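These module-level flags are what the model files above import. For context, a minimal sketch of the typical consumption pattern; the CUDA/ROCm branch mirrors the `if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM` check in the `FlashCausalLM` hunk above, while the XPU branch is an assumption based on PyTorch's `torch.xpu` API (which needs `intel_extension_for_pytorch`):

```python
import torch

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

# Pick a torch device per detected backend. CUDA and ROCm both surface
# as "cuda" in PyTorch; XPU is Intel's backend via its PyTorch extension.
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    device = torch.device("cuda")
elif IS_XPU_SYSTEM:
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
```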