Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 22:32:07 +00:00)

Simple updates.

parent cacaba64c3
commit 199973cc3c
@@ -123,9 +123,15 @@ impl ShardedClient {
             .await
             .into_iter()
             .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
@@ -119,15 +119,19 @@ impl ShardedClient {
                 ))
             })
             .collect();
-        // Take the minimum value
         let results = join_all(futures)
            .await
            .into_iter()
            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
 
-        let first = results.first().expect("Expect at least 1 warmup result");
-        assert!(results.iter().all(|&item| item == *first));
-        Ok(*first)
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
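Both hunks apply the same change to ShardedClient::warmup: instead of asserting that every shard reports an identical warmup result and returning the first tuple, the reduction now takes the minimum across shards, since different shards hold different parts of the vocabulary and may end up with different available block sizes. Note that Rust's min() on (Option<u32>, u32, u32) compares tuples lexicographically (with None ordering below any Some(_)), so it returns the entire tuple reported by the most constrained shard, not a field-wise minimum. A small Python illustration of that reduction; the concrete numbers are made up:

    # Hypothetical per-shard warmup results, shaped like the Rust tuples
    # (Option<u32>, u32, u32) but with the Option flattened to a plain int.
    shard_results = [
        (81920, 4095, 4096),  # shard 0 warmed up with more free blocks
        (79872, 4095, 4096),  # shard 1 reports a smaller block budget
    ]

    # Python, like Rust, compares tuples lexicographically, so min() picks
    # the whole tuple from the most constrained shard rather than taking a
    # field-wise minimum.
    assert min(shard_results) == (79872, 4095, 4096)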
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from typing import Optional
+from typing import Optional, Union
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -475,7 +475,9 @@ class Mamba(Model):
     def batch_type(self) -> Type[MambaBatch]:
         return MambaBatch
 
-    def warmup(self, batch) -> Optional[int]:
+    def warmup(
+        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Union[Optional[int], Optional[int], Optional[int]]:
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
@@ -489,7 +491,12 @@ class Mamba(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        return None
+        if max_total_tokens is None:
+            max_total_tokens = min(self.tokenizer.model_max_length, 4096)
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def cuda_graph_warmup(self, batch_size: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
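Two notes on the new warmup contract, with a sketch below. First, Union[Optional[int], Optional[int], Optional[int]] deduplicates to plain Optional[int], while the method now returns a 3-tuple; Tuple[Optional[int], int, int] would describe the return value accurately. Second, the defaulting logic caps max_total_tokens at the tokenizer's advertised model_max_length or 4096, whichever is smaller, and reserves one token of headroom for max_input_tokens. A standalone sketch of that logic; the function name and the bare model_max_length parameter are assumptions introduced so the snippet runs without the Mamba class:

    from typing import Optional, Tuple

    def resolve_token_limits(
        model_max_length: int,
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ) -> Tuple[Optional[int], int, int]:
        # Mirrors the defaulting the hunk adds to Mamba.warmup().
        if max_total_tokens is None:
            # Cap the total budget at the tokenizer's context length,
            # never exceeding 4096 tokens.
            max_total_tokens = min(model_max_length, 4096)
        if max_input_tokens is None:
            # Leave room for at least one generated token.
            max_input_tokens = max_total_tokens - 1
        # The leading None mirrors the literal None the method returns in
        # its first slot.
        return None, max_input_tokens, max_total_tokens

    # e.g. a tokenizer advertising a 2048-token context, no explicit limits:
    assert resolve_token_limits(2048, None, None) == (None, 2047, 2048)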