Simple updates.

Nicolas Patry 2024-10-24 11:39:02 +02:00
parent cacaba64c3
commit 199973cc3c
3 changed files with 27 additions and 10 deletions


@@ -123,9 +123,15 @@ impl ShardedClient {
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+        // Take the minimum value
+        // Different shards hold different parts of the vocab and might yield
+        // different available block sizes.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch


@@ -119,15 +119,19 @@ impl ShardedClient {
                 ))
             })
             .collect();
-        // Take the minimum value
        let results = join_all(futures)
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+        // Take the minimum value
+        // Different shards hold different parts of the vocab and might yield
+        // different available block sizes.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
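
Both client files receive the same change: rather than asserting that every shard reports identical warmup numbers and returning the first tuple, the router now keeps the smallest one. Below is a minimal sketch of that reduction in Python, with hypothetical values and assumed field names (the real code is the Rust above). Tuple comparison is lexicographic in Python as in Rust, so the shard with the smallest first field, the available block budget, decides the result; in the Rust version a None first field additionally sorts before any Some.

# Hypothetical per-shard warmup results, shaped like the Rust tuples above:
# (supported total tokens, max_input_tokens, max_total_tokens); names assumed.
shard_results = [
    (4096, 1024, 2048),  # shard 0
    (3968, 1024, 2048),  # shard 1 ends up with fewer free blocks
]

# Tuples compare lexicographically, so min() keeps the entry whose first
# field (the per-shard budget) is smallest, matching the Rust .min() call.
assert min(shard_results) == (3968, 1024, 2048)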


@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from typing import Optional
+from typing import Optional, Tuple
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -475,7 +475,9 @@ class Mamba(Model):
     def batch_type(self) -> Type[MambaBatch]:
         return MambaBatch
 
-    def warmup(self, batch) -> Optional[int]:
+    def warmup(
+        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Tuple[Optional[int], int, int]:
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
@@ -489,7 +491,12 @@ class Mamba(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
-        return None
+        if max_total_tokens is None:
+            max_total_tokens = min(self.tokenizer.model_max_length, 4096)
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def cuda_graph_warmup(self, batch_size: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
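
Mamba skips block accounting during warmup, so the new arguments are only normalized and echoed back alongside a None budget. A short sketch of that defaulting path, with a hypothetical helper name and values: a tokenizer that reports a very large model_max_length gets capped at 4096 total tokens, and one token of headroom is reserved for generation.

def default_token_limits(model_max_length, max_input_tokens=None, max_total_tokens=None):
    # Hypothetical helper mirroring the defaulting in the diff above:
    # cap the total budget at 4096 and keep one token of room for generation.
    if max_total_tokens is None:
        max_total_tokens = min(model_max_length, 4096)
    if max_input_tokens is None:
        max_input_tokens = max_total_tokens - 1
    return None, max_input_tokens, max_total_tokens

# A tokenizer reporting an effectively unbounded model_max_length is capped,
# so the warmup triple comes back as (None, 4095, 4096).
assert default_token_limits(10**9) == (None, 4095, 4096)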