Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fixing legacy and CPU configs.
parent  2f243a1a15
commit  f48b6109fd
launcher/src/main.rs

@@ -21,10 +21,38 @@ use tracing_subscriber::EnvFilter;
 
 mod env_runtime;
 
+#[derive(Deserialize)]
+struct RawConfig {
+    max_position_embeddings: Option<usize>,
+    n_positions: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
-    max_seq_len: Option<usize>,
 }
 
+impl From<RawConfig> for Config {
+    fn from(other: RawConfig) -> Self {
+        if other.max_position_embeddings.is_some() {
+            Config {
+                max_position_embeddings: other.max_position_embeddings,
+            }
+        } else if other.max_seq_len.is_some() {
+            Config {
+                max_position_embeddings: other.max_seq_len,
+            }
+        } else if other.n_positions.is_some() {
+            Config {
+                max_position_embeddings: other.n_positions,
+            }
+        } else {
+            Config {
+                max_position_embeddings: None,
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
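The new RawConfig mirrors the different field names that model config.json files use for the same limit: modern configs expose max_position_embeddings, while some legacy families use max_seq_len (e.g. MPT-style configs) or n_positions (GPT-2-style configs), and the From impl collapses them in that order of precedence. Below is a minimal Python sketch of the same precedence, purely for illustration; the launcher's actual code is the Rust above, and the function name here is made up.

from typing import Optional


def resolve_max_position_embeddings(raw_config: dict) -> Optional[int]:
    # Same precedence as the From<RawConfig> impl above:
    # max_position_embeddings, then max_seq_len, then n_positions.
    for key in ("max_position_embeddings", "max_seq_len", "n_positions"):
        if raw_config.get(key) is not None:
            return raw_config[key]
    return None


# Hypothetical config.json contents, shown inline for illustration.
assert resolve_max_position_embeddings({"max_position_embeddings": 32768}) == 32768
assert resolve_max_position_embeddings({"max_seq_len": 2048}) == 2048
assert resolve_max_position_embeddings({"n_positions": 1024}) == 1024
assert resolve_max_position_embeddings({}) is None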
launcher/src/main.rs

@@ -1309,33 +1337,30 @@ fn main() -> Result<(), LauncherError> {
         };
 
         let content = std::fs::read_to_string(filename)?;
-        let config: Config = serde_json::from_str(&content)?;
+        let config: RawConfig = serde_json::from_str(&content)?;
+        let config: Config = config.into();
 
         // Quantization usually means you're even more RAM constrained.
         let max_default = 4096;
 
-        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
-            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    max_default
-                } else {
-                    max_position_embeddings
-                }
-            }
-            _ => {
-                return Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )));
-            }
-        };
-        Ok(max_position_embeddings)
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                Ok(max_default)
+            } else {
+                Ok(max_position_embeddings)
+            }
+        } else {
+            Err(Box::new(LauncherError::ArgumentValidation(
+                "no max defined".to_string(),
+            )))
+        }
     };
     let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
 
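With the conversion in place, the sizing logic only has to look at config.max_position_embeddings: if the config defines no maximum it errors out with "no max defined" (the caller then falls back to 4096 via unwrap_or), and otherwise the launcher never defaults past 4096 tokens, logging a hint when the model supports more and the user has not set any limits explicitly. A small Python sketch of that sizing decision, illustrative only (the function name is made up; the real logic is the Rust above):

def capped_max_position_embeddings(max_position_embeddings, max_default=4096):
    # The launcher caps its *default* at 4096 to save VRAM; users can still raise
    # the limits explicitly with --max-input-tokens, --max-total-tokens and
    # --max-batch-prefill-tokens (the info log above tells them how).
    if max_position_embeddings is None:
        raise ValueError("no max defined")
    return min(max_position_embeddings, max_default)


assert capped_max_position_embeddings(2048) == 2048
assert capped_max_position_embeddings(32768) == 4096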
server/text_generation_server/models/__init__.py

@@ -472,14 +472,26 @@ def get_model(
             )
     elif model_type == GPT2:
         if FLASH_ATTENTION:
-            return FlashGPT2(
-                model_id,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
+            try:
+                return FlashGPT2(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    speculator=speculator,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
+            except RuntimeError as e:
+                # Lots of legacy models with various weight names.
+                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+                return CausalLM(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    speculator=speculator,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
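Many older GPT-2 checkpoints ship weights under names the flash implementation does not recognize, so constructing FlashGPT2 can raise a RuntimeError even when FLASH_ATTENTION is available; the change above catches that and retries with the generic CausalLM path instead of failing model loading outright. A stripped-down sketch of this try-the-fast-path-then-fall-back pattern; the helper and its signature are hypothetical, not TGI's API:

import logging

logger = logging.getLogger(__name__)


def load_with_fallback(fast_loader, slow_loader, *args, **kwargs):
    # Try the optimized implementation first; fall back to the generic one
    # when the checkpoint layout is not supported.
    try:
        return fast_loader(*args, **kwargs)
    except RuntimeError as e:
        logger.warning(f"Couldn't load flash variant: {e}")
        return slow_loader(*args, **kwargs)


# Mirrors the call above, e.g. (hypothetical usage):
# load_with_fallback(FlashGPT2, CausalLM, model_id, revision, quantize=quantize)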
server/text_generation_server/server.py

@@ -14,13 +14,20 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.pali_gemma import PaliGemmaBatch
-from text_generation_server.models.vlm_causal_lm import (
-    VlmCausalLMBatch,
-)
+
+try:
+    from text_generation_server.models.pali_gemma import PaliGemmaBatch
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+
+    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
+except (ImportError, NotImplementedError):
+    VLM_BATCH_TYPES = set()
+
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 from text_generation_server.models.globals import set_model_id
 
 
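On CPU-only or non-flash installs, importing the VLM batch classes can itself raise ImportError or NotImplementedError, which previously broke the server at import time. Wrapping the imports in try/except and collapsing the classes into a single VLM_BATCH_TYPES set lets the batch_type checks further down degrade to an always-false membership test instead. The snippet below is the same guarded-import pattern reduced to its essentials; it runs anywhere, since a missing text_generation_server package simply lands in the except branch:

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
except (ImportError, NotImplementedError):
    # No VLM support in this environment (e.g. a CPU build): fall back to an
    # empty set, so `batch_type in VLM_BATCH_TYPES` is simply always False.
    VLM_BATCH_TYPES = set()

print("VLM batch types available:", VLM_BATCH_TYPES)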
server/text_generation_server/server.py

@@ -96,11 +103,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             except ImportError:
                 pass
 
-        if self.model.batch_type in {
-            IdeficsCausalLMBatch,
-            VlmCausalLMBatch,
-            PaliGemmaBatch,
-        }:  # Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type in VLM_BATCH_TYPES
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
server/text_generation_server/server.py

@@ -121,11 +126,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Prefill(self, request, context):
         start = time.time_ns()
-        if self.model.batch_type in {
-            IdeficsCausalLMBatch,
-            VlmCausalLMBatch,
-            PaliGemmaBatch,
-        }:  # Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type in VLM_BATCH_TYPES
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,