Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00
Using both values from the config as they might not be correct. (#2817)
* Using both values from the config as they might not be correct.
* Fixing max_position_embeddings for falcon.
* Simple attempt to fix the healthcheck block allocation.
* Much simpler solution.
* Default value for Backend start_health.
Parent: a2d878fa0f
Commit: 82c24f7420
@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
@@ -217,8 +217,8 @@ impl Health for ShardedClient {
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -241,7 +241,7 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
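Combined with the smaller liveness request above (truncate: 1, no special tokens), the healthcheck prefill should only ever need a single KV-cache slot, so pinning it to the reserved block 0 with slots: vec![0] replaces the old 16-slot range. A small sketch of that arithmetic follows; BLOCK_SIZE and slots_needed are illustrative assumptions, not values or helpers from the repository.

// Illustrative sketch only: BLOCK_SIZE and slots_needed are assumptions,
// not code from text-generation-inference.
const BLOCK_SIZE: usize = 16; // assumed block size for illustration

// A 1-token prompt with no special tokens occupies exactly one slot.
fn slots_needed(prompt_tokens: usize, special_tokens: usize) -> usize {
    prompt_tokens + special_tokens
}

fn main() {
    let needed = slots_needed(1, 0);
    assert_eq!(needed, 1);
    // The whole healthcheck prefill fits in the reserved block 0 ...
    assert!(needed <= BLOCK_SIZE);
    // ... so a single slot id is enough, instead of the previous 0..16 range.
    let slots: Vec<u64> = (0..needed as u64).collect();
    assert_eq!(slots, vec![0]);
}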
@@ -33,6 +33,13 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
 
     async fn health(&self, current_health: bool) -> bool;
+
+    /// The state of the health on startup
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool {
+        false
+    }
 }
 
 /// Inference struct
@@ -75,7 +82,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
 
         Self {
             validation,
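With the flag seeded from the backend, the first /health call no longer starts from a hard-coded false; the v2 and v3 backends report healthy immediately because they warm up before serving. Below is a minimal, self-contained sketch of that flow; the synchronous trait, WarmedUpBackend, and the inline health-refresh logic are illustrative stand-ins, not the router's actual types.

// Minimal, self-contained sketch (not the router's actual code) of how the
// seeded health flag and health(current_health) interact. The real Backend
// trait is async; WarmedUpBackend is an illustrative name.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

trait Backend {
    /// Health probe; receives the last known health state.
    fn health(&self, current_health: bool) -> bool;

    /// Health state on startup: false by default, true for backends
    /// that already ran a warmup phase before serving traffic.
    fn start_health(&self) -> bool {
        false
    }
}

/// Stand-in for a v2/v3-style backend that warms up before the router starts.
struct WarmedUpBackend;

impl Backend for WarmedUpBackend {
    fn health(&self, _current_health: bool) -> bool {
        // A real backend would probe the shards here.
        true
    }

    fn start_health(&self) -> bool {
        // Warmup already exercised the model, so report healthy immediately.
        true
    }
}

fn main() {
    let backend = WarmedUpBackend;

    // Seed the shared flag from the backend instead of hard-coding `false`.
    let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
    assert!(backend_health.load(Ordering::SeqCst));

    // A /health request refreshes the flag from the backend.
    let current = backend_health.load(Ordering::SeqCst);
    let healthy = backend.health(current);
    backend_health.store(healthy, Ordering::SeqCst);
    assert!(healthy);
}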
@@ -78,6 +78,7 @@ class RWConfig(PretrainedConfig):
         self.alibi = False
         self.rotary = True
         self.rope_theta = rope_theta
+        self.max_position_embeddings = 2048
 
         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
 
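This is the change the commit title refers to: neither the tokenizer's model_max_length nor the model config alone is guaranteed to be correct, so max_total_tokens is now capped by both, in addition to the KV-cache size. A worked sketch with made-up numbers; BLOCK_SIZE, the free-block count, and the sentinel model_max_length are illustrative assumptions, not values from the repository.

// Worked example with illustrative numbers only; none of these values come
// from the repository. Tokenizers often report a huge sentinel when
// model_max_length is unset, which used to inflate max_total_tokens.
fn main() {
    const BLOCK_SIZE: u64 = 16; // assumed block size
    let num_blocks: u64 = 10_000; // assumed free KV-cache blocks
    let model_max_length: u64 = 1_000_000_000; // unset-tokenizer sentinel
    let max_position_embeddings: u64 = 2_048; // e.g. the falcon default above

    // Old cap: only the KV-cache size and the tokenizer limit.
    let old = (num_blocks * BLOCK_SIZE).min(model_max_length);
    // New cap: the model config bounds it as well.
    let new = old.min(max_position_embeddings);

    assert_eq!(old, 160_000);
    assert_eq!(new, 2_048);
}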