Simple updates.

This commit is contained in:
Nicolas Patry 2024-10-24 11:39:02 +02:00
parent cacaba64c3
commit 199973cc3c
GPG Key ID: D2920555C90F704C
3 changed files with 27 additions and 10 deletions


@@ -123,9 +123,15 @@ impl ShardedClient
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch


@@ -119,15 +119,19 @@ impl ShardedClient
                 ))
             })
             .collect();
 
-        // Take the minimum value
         let results = join_all(futures)
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
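The switch in both Rust hunks from "take the first result and assert all shards agree" to `.iter().min()` leans on Rust's derived ordering: tuples compare lexicographically and `Option<u32>` orders `None` before any `Some(_)`. A minimal standalone sketch of those comparison semantics follows; the concrete values and field meanings are illustrative assumptions, not code from this commit.

// Standalone sketch (assumed example data): how `.iter().min()` combines
// per-shard warmup results of type (Option<u32>, u32, u32).
fn main() {
    let results: Vec<(Option<u32>, u32, u32)> = vec![
        (Some(16384), 4096, 8192),
        (Some(12288), 4096, 8192), // a shard reporting less headroom
        (Some(16384), 2048, 8192),
    ];

    // Tuples compare element by element, so the shard with the smallest
    // first field wins; ties fall through to the remaining fields.
    let min = results
        .iter()
        .min()
        .expect("Expect at least 1 warmup result");
    assert_eq!(*min, (Some(12288), 4096, 8192));

    // `None` sorts before any `Some`, so a shard reporting `None` in the
    // first position would dominate the minimum.
    assert!((None::<u32>, 0u32, 0u32) < (Some(0u32), 0u32, 0u32));

    println!("aggregated warmup result = {:?}", min);
}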


@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from typing import Optional
+from typing import Optional, Union
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -475,7 +475,9 @@ class Mamba(Model):
     def batch_type(self) -> Type[MambaBatch]:
         return MambaBatch
 
-    def warmup(self, batch) -> Optional[int]:
+    def warmup(
+        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Union[Optional[int], Optional[int], Optional[int]]:
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
@@ -489,7 +491,12 @@ class Mamba(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        return None
+        if max_total_tokens is None:
+            max_total_tokens = min(self.tokenizer.model_max_length, 4096)
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+
+        return None, max_input_tokens, max_total_tokens
 
     def cuda_graph_warmup(self, batch_size: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)