Merge branch 'huggingface:main' into bnb-4bit
commit 88d753d79b

Cargo.lock (generated, 8 changed lines)
@@ -2893,7 +2893,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "0.9.2"
+version = "0.9.3"
 dependencies = [
  "average",
  "clap",
@@ -2913,7 +2913,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.9.2"
+version = "0.9.3"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2929,7 +2929,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "0.9.2"
+version = "0.9.3"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2945,7 +2945,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "0.9.2"
+version = "0.9.3"
 dependencies = [
  "async-stream",
  "axum",
@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.9.2"
+version = "0.9.3"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
Dockerfile (15 changed lines)
@@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
 # Build specific version of flash attention
 RUN make build-flash-attention
 
+# Build Flash Attention v2 CUDA kernels
+FROM kernel-builder as flash-att-v2-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att-v2 Makefile
+
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder
 
@@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
@@ -63,6 +63,8 @@ to power LLMs api-inference widgets.
 - [Starcoder](https://huggingface.co/bigcode/starcoder)
 - [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
 - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
+- [MPT](https://huggingface.co/mosaicml/mpt-30b)
+- [Llama V2](https://huggingface.co/meta-llama)
 
 Other architectures are supported on a best effort basis using:
 
@@ -132,6 +134,10 @@ print(text)
 You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
 The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
 
+### Using on private models or gated models
+
+You can use `HUGGING_FACE_HUB_TOKEN` environment variable to set the token used by `text-generation-inference` to give access to protected ressources.
+
 ### Distributed Tracing
 
 `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
@@ -211,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y
 ### CUDA Kernels
 
 The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
-the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
+the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
 
 Be aware that the official Docker image has them enabled by default.
 
@@ -10,7 +10,7 @@
     "name": "Apache 2.0",
     "url": "https://www.apache.org/licenses/LICENSE-2.0"
   },
-  "version": "0.9.2"
+  "version": "0.9.3"
 },
 "paths": {
   "/": {
@@ -193,8 +193,8 @@ struct Args {
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
 
     /// This setting defines how many tokens can be passed before forcing the waiting
     /// queries to be put on the batch (if the size of the batch allows for it).
@@ -276,17 +276,9 @@ struct Args {
     #[clap(long, env)]
     ngrok_authtoken: Option<String>,
 
-    /// ngrok domain name where the axum webserver will be available at
+    /// ngrok edge
     #[clap(long, env)]
-    ngrok_domain: Option<String>,
-
-    /// ngrok basic auth username
-    #[clap(long, env)]
-    ngrok_username: Option<String>,
-
-    /// ngrok basic auth password
-    #[clap(long, env)]
-    ngrok_password: Option<String>,
+    ngrok_edge: Option<String>,
 
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
@@ -378,12 +370,6 @@ fn shard_manager(
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
-    // Use cuda allocator. It leads to less memory fragmentation
-    envs.push((
-        "PYTORCH_CUDA_ALLOC_CONF".into(),
-        "backend:cudaMallocAsync".into(),
-    ));
-
     // Torch Distributed Env vars
     envs.push(("RANK".into(), rank.to_string().into()));
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -437,7 +423,7 @@ fn shard_manager(
     }
 
     // Start process
-    tracing::info!("Starting shard {rank}");
+    tracing::info!("Starting shard");
     let mut p = match Command::new("text-generation-server")
         .args(shard_args)
         .envs(envs)
@@ -502,17 +488,17 @@ fn shard_manager(
         if shutdown.load(Ordering::SeqCst) {
             p.kill().unwrap();
             let _ = p.wait();
-            tracing::info!("Shard {rank} terminated");
+            tracing::info!("Shard terminated");
             return;
         }
 
         // Shard is ready
         if uds.exists() && !ready {
-            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
+            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
         } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {rank} to be ready...");
+            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
         }
         sleep(Duration::from_millis(100));
@@ -869,8 +855,6 @@ fn spawn_webserver(
         args.max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         args.max_batch_prefill_tokens.to_string(),
-        "--max-batch-total-tokens".to_string(),
-        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -887,6 +871,12 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Model optional max batch total tokens
+    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
+        router_args.push("--max-batch-total-tokens".to_string());
+        router_args.push(max_batch_total_tokens.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
@@ -911,26 +901,11 @@ fn spawn_webserver(
 
     // Ngrok
     if args.ngrok {
-        let authtoken = args.ngrok_authtoken.ok_or_else(|| {
-            tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
-            LauncherError::WebserverCannotStart
-        })?;
-
         router_args.push("--ngrok".to_string());
         router_args.push("--ngrok-authtoken".to_string());
-        router_args.push(authtoken);
-
-        if let Some(domain) = args.ngrok_domain {
-            router_args.push("--ngrok-domain".to_string());
-            router_args.push(domain);
-        }
-
-        if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
-            router_args.push("--ngrok-username".to_string());
-            router_args.push(username);
-            router_args.push("--ngrok-password".to_string());
-            router_args.push(password);
-        }
+        router_args.push(args.ngrok_authtoken.unwrap());
+        router_args.push("--ngrok-edge".to_string());
+        router_args.push(args.ngrok_edge.unwrap());
     }
 
     // Copy current process env
@@ -1008,7 +983,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
 
 fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
-    let args = Args::parse();
+    let args: Args = Args::parse();
 
     // Filter events with LOG_LEVEL
     let env_filter =
@@ -1045,18 +1020,7 @@ fn main() -> Result<(), LauncherError> {
             args.max_batch_prefill_tokens, args.max_input_length
         )));
     }
-    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_batch_prefill_tokens, args.max_batch_total_tokens
-        )));
-    }
-    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_total_tokens, args.max_batch_total_tokens
-        )));
-    }
+
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
@@ -1074,6 +1038,35 @@ fn main() -> Result<(), LauncherError> {
         tracing::info!("Sharding model on {num_shard} processes");
     }
 
+    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
+        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_batch_prefill_tokens, max_batch_total_tokens
+            )));
+        }
+        if args.max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_total_tokens, max_batch_total_tokens
+            )));
+        }
+    }
+
+    if args.ngrok {
+        if args.ngrok_authtoken.is_none() {
+            return Err(LauncherError::ArgumentValidation(
+                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
+            ));
+        }
+
+        if args.ngrok_edge.is_none() {
+            return Err(LauncherError::ArgumentValidation(
+                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
+            ));
+        }
+    }
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
@@ -198,9 +198,10 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    /// Maximum number of tokens that the client will send
-    uint32 max_total_tokens = 2;
 }
 
 /// Empty response
-message WarmupResponse {}
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
@@ -103,8 +103,7 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
 
@@ -143,13 +142,9 @@ impl Client {
             max_tokens: 0,
         };
 
-        let request = tonic::Request::new(WarmupRequest {
-            batch: Some(batch),
-            max_total_tokens,
-        })
-        .inject_context();
-        self.stub.warmup(request).await?.into_inner();
-        Ok(())
+        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
     }
 
     /// Generate one token for each request in the given batch
@@ -95,14 +95,11 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
-            })
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -53,7 +53,7 @@ impl Infer {
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding);
+        let queue = Queue::new(requires_padding, 16);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
|
|||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
#[clap(default_value = "4096", long, env)]
|
#[clap(default_value = "4096", long, env)]
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
#[clap(default_value = "16000", long, env)]
|
#[clap(long, env)]
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: Option<u32>,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
#[clap(default_value = "0.0.0.0", long, env)]
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
@ -64,11 +64,7 @@ struct Args {
|
|||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_domain: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_username: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_password: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), RouterError> {
|
fn main() -> Result<(), RouterError> {
|
||||||
@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> {
|
|||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_domain,
|
ngrok_edge,
|
||||||
ngrok_username,
|
|
||||||
ngrok_password,
|
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> {
|
|||||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||||
}
|
}
|
||||||
if max_batch_prefill_tokens > max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
// Finally, convert to AllowOrigin
|
// Finally, convert to AllowOrigin
|
||||||
@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> {
|
|||||||
|
|
||||||
// Warmup model
|
// Warmup model
|
||||||
tracing::info!("Warming up model");
|
tracing::info!("Warming up model");
|
||||||
sharded_client
|
let max_supported_batch_total_tokens = match sharded_client
|
||||||
.warmup(
|
.warmup(max_input_length as u32, max_batch_prefill_tokens)
|
||||||
max_input_length as u32,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(RouterError::Warmup)?;
|
.map_err(RouterError::Warmup)?
|
||||||
|
{
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||||
|
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
||||||
|
);
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
max_batch_total_tokens
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
max_supported_batch_total_tokens
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
let addr = match hostname.parse() {
|
||||||
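Note: when the shards do not report a supported token budget (older, non Flash Attention models), the router falls back to the CLI value or to the largest of 16000, `max_total_tokens`, and `max_batch_prefill_tokens`. A minimal Python sketch of that fallback, with names mirroring the Rust code above for illustration only:

from typing import Optional

def resolve_max_batch_total_tokens(
    supported: Optional[int],
    cli_max_batch_total_tokens: Optional[int],
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
) -> int:
    # Flash Attention models report their own budget; it overrides any CLI value.
    if supported is not None:
        return supported
    # Older models: use the CLI value if given, otherwise a 16000 floor raised
    # so the budget is never smaller than the other limits.
    if cli_max_batch_total_tokens is not None:
        return cli_max_batch_total_tokens
    return max(16000, max_total_tokens, max_batch_prefill_tokens)
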
@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> {
|
|||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
max_supported_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> {
|
|||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_domain,
|
ngrok_edge,
|
||||||
ngrok_username,
|
|
||||||
ngrok_password,
|
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@@ -33,12 +33,12 @@ pub(crate) struct Queue {
 }
 
 impl Queue {
-    pub(crate) fn new(requires_padding: bool) -> Self {
+    pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
-        tokio::spawn(queue_task(requires_padding, queue_receiver));
+        tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
 
         Self { queue_sender }
     }
@@ -81,8 +81,12 @@ impl Queue {
 }
 
 // Background task responsible of the queue state
-async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
-    let mut state = State::new(requires_padding);
+async fn queue_task(
+    requires_padding: bool,
+    block_size: u32,
+    receiver: flume::Receiver<QueueCommand>,
+) {
+    let mut state = State::new(requires_padding, block_size);
 
     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
@@ -119,15 +123,19 @@ struct State {
 
     /// Whether the model is using padding
     requires_padding: bool,
+
+    /// Paged Attention block size
+    block_size: u32,
 }
 
 impl State {
-    fn new(requires_padding: bool) -> Self {
+    fn new(requires_padding: bool, block_size: u32) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
+            block_size,
         }
     }
 
@@ -187,10 +195,21 @@ impl State {
                 max_input_length = max_input_length.max(entry.request.input_length);
                 prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
             } else {
-                prefill_tokens += entry.request.input_length;
+                // pad to block size
+                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
+                    / self.block_size)
+                    * self.block_size;
             }
 
+            if self.requires_padding {
                 decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            } else {
+                // pad to block size
+                decode_tokens +=
+                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
+                        / self.block_size)
+                        * self.block_size;
+            }
 
             if prefill_tokens > prefill_token_budget
                 || (prefill_tokens + decode_tokens) > token_budget
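Note: with paged attention the queue now rounds each request's prefill and decode token counts up to a multiple of the block size before checking them against the budget. A small Python sketch of that rounding (illustrative only; the router does this in Rust with the same integer arithmetic):

def pad_to_block_size(tokens: int, block_size: int) -> int:
    # Ceil-divide, then scale back up: 1..block_size -> block_size, and so on.
    return ((tokens + block_size - 1) // block_size) * block_size

assert pad_to_block_size(1, 16) == 16
assert pad_to_block_size(16, 16) == 16
assert pad_to_block_size(17, 16) == 32
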
@@ -321,7 +340,7 @@ mod tests {
 
     #[test]
     fn test_append() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry, _guard) = default_entry();
 
         assert_eq!(state.next_id, 0);
@@ -337,7 +356,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
 
         assert!(state.next_batch(None, 1, 1).is_none());
         assert!(state.next_batch(Some(1), 1, 1).is_none());
@@ -345,7 +364,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -377,7 +396,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_token_budget() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -410,14 +429,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
 
         assert!(queue.next_batch(None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@@ -425,7 +444,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -458,7 +477,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -483,7 +502,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry, _) = default_entry();
         queue.append(entry);
 
@@ -524,9 +524,7 @@ pub async fn run(
     allow_origin: Option<AllowOrigin>,
     ngrok: bool,
     ngrok_authtoken: Option<String>,
-    ngrok_domain: Option<String>,
-    ngrok_username: Option<String>,
-    ngrok_password: Option<String>,
+    ngrok_edge: Option<String>,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -696,32 +694,25 @@ pub async fn run(
     #[cfg(feature = "ngrok")]
     {
         use ngrok::config::TunnelBuilder;
-        use ngrok::tunnel::UrlTunnel;
 
         let _ = addr;
 
         let authtoken =
             ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
 
-        let mut tunnel = ngrok::Session::builder()
+        let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
+
+        let tunnel = ngrok::Session::builder()
             .authtoken(authtoken)
             .connect()
             .await
             .unwrap()
-            .http_endpoint();
-
-        if let Some(domain) = ngrok_domain {
-            tunnel = tunnel.domain(domain);
-        }
-
-        if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
-            tunnel = tunnel.basic_auth(username, password);
-        }
+            .labeled_tunnel()
+            .label("edge", edge);
 
         let listener = tunnel.listen().await.unwrap();
 
         // Run server
-        tracing::info!("Ingress URL: {:?}", listener.url());
         axum::Server::builder(listener)
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
@@ -1,4 +1,5 @@
 include Makefile-flash-att
+include Makefile-flash-att-v2
 include Makefile-vllm
 
 unit-tests:
server/Makefile-flash-att-v2 (new file, 13 lines)
@@ -0,0 +1,13 @@
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+
+flash-attention-v2:
+	# Clone flash attention
+	pip install packaging
+	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+
+build-flash-attention-v2: flash-attention-v2
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
+	cd flash-attention-v2 && python setup.py build
+
+install-flash-attention-v2: build-flash-attention-v2
+	cd flash-attention-v2 && python setup.py install
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.9.2"
+version = "0.9.3"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
@@ -196,6 +196,8 @@ def quantize(
     percdamp: float = 0.01,
     act_order: bool = False,
 ):
+    if revision is None:
+        revision = "main"
     download_weights(
         model_id=model_id,
         revision=revision,
@@ -209,6 +211,7 @@ def quantize(
         bits=4,
         groupsize=128,
         output_dir=output_dir,
+        revision=revision,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
@@ -42,35 +42,10 @@ __all__ = [
     "get_model",
 ]
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires CUDA and Flash Attention kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
+FLASH_ATTENTION = True
 try:
-    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-        if not torch.cuda.is_available():
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires CUDA. No compatible CUDA devices found."
-            )
-            raise ImportError("CUDA is not available")
-
-        major, minor = torch.cuda.get_device_capability()
-        is_sm75 = major == 7 and minor == 5
-        is_sm8x = major == 8 and minor >= 0
-        is_sm90 = major == 9 and minor == 0
-
-        supported = is_sm75 or is_sm8x or is_sm90
-        if not supported:
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
-                "No compatible CUDA device found."
-            )
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            )
-
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
     from text_generation_server.models.flash_llama import (
@@ -80,13 +55,8 @@ try:
         FlashSantacoderSharded,
     )
 
-        FLASH_ATTENTION = True
-    else:
-        FLASH_ATTENTION = False
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
@@ -23,25 +23,77 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
 import dropout_layer_norm
 
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )
 
+
+class LlamaConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
@@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module):
             hidden_states += residual
             residual = hidden_states
 
-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
@@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res
 
 
+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -118,6 +192,12 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
             self.query_key_value = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -131,9 +211,10 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
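Note: for grouped-query attention, `num_key_value_heads` divides `num_heads`, and `kv_head_mapping` assigns each query head to the KV head it shares via `repeat_interleave`. A pure-Python sketch of the resulting mapping (illustrative numbers, no torch dependency):

def kv_head_mapping(num_heads: int, num_key_value_heads: int) -> list:
    num_groups = num_heads // num_key_value_heads
    # Equivalent to torch.arange(num_key_value_heads).repeat_interleave(num_groups)
    return [kv for kv in range(num_key_value_heads) for _ in range(num_groups)]

# e.g. 8 query heads sharing 2 KV heads -> [0, 0, 0, 0, 1, 1, 1, 1]
assert kv_head_mapping(8, 2) == [0, 0, 0, 0, 1, 1, 1, 1]
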
@@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
@@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module):
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -324,6 +404,7 @@ class FlashLlamaModel(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,
@@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
@@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
@@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
@@ -5,13 +5,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
         # Decode
         else:
@@ -710,14 +710,14 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

-    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+    def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER

         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
-                # Adds some wiggle room
-                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                batch.blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
@@ -727,11 +727,43 @@ class FlashCausalLM(Model):
             _, batch = self.generate_token(batch)
         except Exception as e:
             raise RuntimeError(
-                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
-                f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
+                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.cuda.synchronize(self.device)
+        peak_memory = torch.cuda.max_memory_reserved(self.device)
+
+        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        # 0.98 to add some wiggle room
+        num_blocks = (
+            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+            + batch.blocks
+        )
+
+        del CACHE_MANAGER
         del batch
+        torch.cuda.empty_cache()
+
+        CACHE_MANAGER = CacheManager(
+            num_blocks,
+            self.num_layers,
+            self.num_kv_heads,
+            self.head_size,
+            self.dtype,
+            self.device,
+        )
+
+        return int(num_blocks * BLOCK_SIZE)

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -991,7 +1023,6 @@ class FlashCausalLM(Model):

         if stopped:
             del batch
-            torch.cuda.empty_cache()
             # No need to return a batch if we know that all requests stopped
             return generations, None

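
The warmup above replaces the user-supplied `--max-batch-total-tokens` bound with a measured one: run a single prefill, read the peak reserved memory, and size the KV cache to fill whatever is left on the GPU. A rough back-of-the-envelope version of the same arithmetic, with entirely hypothetical numbers (80 GB card, 32 layers, 32 KV heads, head size 128, fp16, BLOCK_SIZE 16):

BLOCK_SIZE = 16
dtype_size = 2                                       # fp16 bytes per element
cache_block_size = BLOCK_SIZE * 32 * 128             # tokens/block * kv_heads * head_size
total_cache_size = 32 * cache_block_size * 2 * dtype_size  # layers * block * (K and V) * bytes

total_gpu_memory = 80 * 1024**3
peak_memory = 30 * 1024**3                           # peak reserved after the warmup prefill
batch_blocks = 512                                   # blocks already held by the warmup batch

num_blocks = int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + batch_blocks
print(num_blocks, num_blocks * BLOCK_SIZE)           # cache blocks, max supported total tokens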
@@ -2,13 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
+    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-        config = AutoConfig.from_pretrained(
+        config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )

@@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
@@ -58,8 +58,9 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError

-    def warmup(self, batch: B, max_total_tokens: int):
+    def warmup(self, batch: B) -> Optional[int]:
         self.generate_token(batch)
+        return None

     def decode_token(
         self,
@@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)

-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        self.model.warmup(batch, request.max_total_tokens)
+        max_supported_total_tokens = self.model.warmup(batch)

-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        return generate_pb2.WarmupResponse()
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )

     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
@@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
         else:
             batch = batches[0]

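
Taken together with the model.py change above, the warmup contract becomes: the base Model.warmup returns None, FlashCausalLM returns the token capacity it actually managed to allocate, and the Warmup RPC forwards that number to the router. A hedged, stripped-down sketch of that contract (class names and values here are illustrative, not the actual TGI classes):

from typing import Optional

class BaseModel:
    def warmup(self, batch) -> Optional[int]:
        # Default: no measured capacity, the router keeps its configured limits.
        return None

class FlashModel(BaseModel):
    def warmup(self, batch) -> Optional[int]:
        num_blocks, block_size = 6707, 16     # hypothetical profiled values
        return num_blocks * block_size        # becomes max_supported_total_tokens

print(BaseModel().warmup(None), FlashModel().warmup(None))  # None 107312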
server/text_generation_server/utils/flash_attn.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+import os
+import torch
+
+from loguru import logger
+
+if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    raise ImportError("`USE_FLASH_ATTENTION` is false.")
+
+if not torch.cuda.is_available():
+    raise ImportError("CUDA is not available")
+
+major, minor = torch.cuda.get_device_capability()
+is_sm75 = major == 7 and minor == 5
+is_sm8x = major == 8 and minor >= 0
+is_sm90 = major == 9 and minor == 0
+
+HAS_FLASH_ATTN = False
+HAS_FLASH_ATTN_V2 = False
+try:
+    try:
+        import flash_attn_2_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
+    if not (is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported for "
+            "Flash Attention V2"
+        )
+    HAS_FLASH_ATTN_V2 = True
+except ImportError as e:
+    try:
+        import flash_attn_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
+
+    if not (is_sm75 or is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported"
+        ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
+    HAS_FLASH_ATTN = True
+
+
+def attention(
+    q,
+    k,
+    v,
+    out,
+    cu_seqlens,
+    max_s,
+    softmax_scale,
+):
+    if HAS_FLASH_ATTN_V2:
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
+    if HAS_FLASH_ATTN:
+        # Flash attention v1 requires q, k and v to have the same number of heads
+        if k.shape[1] != q.shape[1]:
+            # MQA expand
+            if k.shape[1] == 1:
+                k = k.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape
+            else:
+                original_shape = k.shape
+                k = (
+                    k.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+        if v.shape[1] != q.shape[1]:
+            # MQA expand
+            if v.shape[1] == 1:
+                v = v.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape
+            else:
+                original_shape = v.shape
+                v = (
+                    v.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+
+        return flash_attn_cuda.fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            0,
+            None,
+        )
+
+    raise NotImplementedError("flash attention is not installed")
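
For the flash-attn v1 fallback, the wrapper repeats each KV head until K and V match the query head count, since the v1 kernel cannot broadcast heads itself. A CPU-only illustration of the grouped-attention reshape above, with hypothetical shapes:

import torch

num_tokens, num_q_heads, num_kv_heads, head_size = 4, 8, 2, 64
k = torch.randn(num_tokens, num_kv_heads, head_size)

original_shape = k.shape
k_expanded = (
    k.unsqueeze(2)
    .expand(-1, -1, num_q_heads // num_kv_heads, -1)
    .reshape(original_shape[0], -1, original_shape[2])
)
assert k_expanded.shape == (num_tokens, num_q_heads, head_size)
# Expanded head j maps back to KV head j // (num_q_heads // num_kv_heads).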
@@ -13,6 +13,9 @@ import transformers
 from huggingface_hub import HfApi
 import numpy as np
 import torch
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
@@ -38,7 +41,6 @@ class Quantizer(nn.Module):
         maxshrink=0.8,
         trits=False,
     ):
-
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
         self.sym = sym
@@ -600,6 +602,8 @@ def sequential(
     nsamples,
     bits,
     groupsize,
+    *,
+    hooks,
     percdamp=0.01,
     sym: bool = False,
     act_order: bool = False,
@@ -637,7 +641,7 @@ def sequential(
     layers[0] = Catcher(layers[0])
     for batch in dataloader:
         try:
-            model(batch[0])
+            model(batch[0].cuda())
         except ValueError:
             pass
     layers[0] = layers[0].module
@@ -646,6 +650,8 @@ def sequential(
     # model.model.embed_tokens = model.model.embed_tokens.cpu()
     # model.model.norm = model.model.norm.cpu()
     torch.cuda.empty_cache()
+    for hook in hooks:
+        hook.remove()

     outs = torch.zeros_like(inps)

@@ -662,10 +668,8 @@ def sequential(
         print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
         print("+==================+==============+============+===========+=======+")

-        from accelerate.hooks import remove_hook_from_submodules
-        layer = layers[i].to(dev)
-        remove_hook_from_submodules(layer)
+        layer = layers[i]
+        layer.load()
         full = find_layers(layer)
         sequential = [list(full.keys())]

@@ -677,6 +681,7 @@ def sequential(
                 gptq[name].quantizer.configure(
                     bits, perchannel=True, sym=sym, mse=False
                 )
+            pass

            def add_batch(name):
                def tmp(_, inp, out):
@@ -688,7 +693,6 @@ def sequential(
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(nsamples):
-
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
            for h in handles:
                h.remove()
@@ -714,7 +718,7 @@ def sequential(
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

-        layers[i] = layer.cpu()
+        layer.unload()
         del layer
         del gptq
         torch.cuda.empty_cache()
@@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
     return model


+def setdeepattr(module, full_name, tensor):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens[:-1]:
+        current = getattr(current, token)
+    setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens:
+        current = getattr(current, token)
+    return current
+
+
+def load_weights_pre_hook(module_name, weights, recursive=False):
+    def inner(module, args):
+        print(f"Pre hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+
+        for local_param in local_params:
+            current_tensor = getdeepattr(module, local_param)
+            if current_tensor.device == torch.device("meta"):
+                # print(f"Loading {local_param}")
+                if module_name:
+                    tensor_name = f"{module_name}.{local_param}"
+                else:
+                    tensor_name = local_param
+                tensor = weights.get_tensor(tensor_name)
+                setdeepattr(module, local_param, nn.Parameter(tensor))
+            else:
+                setdeepattr(
+                    module,
+                    local_param,
+                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                )
+
+    return inner
+
+
+def load_weights_post_hook(module_name, weights, recursive=False):
+    def inner(module, args, output):
+        print(f"Post hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for local_param in local_params:
+            # print(f"Unloading {local_param}")
+            current_tensor = getdeepattr(module, local_param)
+            setdeepattr(
+                module,
+                local_param,
+                nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+            )
+        return output
+
+    return inner
+
+
 def quantize(
     model_id: str,
     bits: int,
     groupsize: int,
     output_dir: str,
+    revision: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
     percdamp: float,
     act_order: bool,
 ):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,
-        device_map="balanced_low_0",
         trust_remote_code=trust_remote_code,
     )

+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
+    model = model.eval()
+
     print("LOADED model")
+    files = weight_files(model_id, revision, extension=".safetensors")
+    process_group, _, _ = initialize_torch_distributed()
+    weights = Weights(
+        files,
+        device=torch.device("cuda:0"),
+        dtype=torch.float16,
+        process_group=process_group,
+        aliases={"embed_tokens.weight": ["lm_head.weight"]},
+    )
+    hooks = []
+    for name, module in model.named_modules():
+
+        def load(module, name):
+            def _load():
+                load_weights_pre_hook(name, weights, recursive=True)(module, None)
+
+            return _load
+
+        def unload(module, name):
+            def _unload():
+                load_weights_post_hook(name, weights, recursive=True)(
+                    module, None, None
+                )
+
+            return _unload
+
+        module.load = load(module, name)
+        module.unload = unload(module, name)
+        hooks.append(
+            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+        )
+        hooks.append(
+            module.register_forward_hook(load_weights_post_hook(name, weights))
+        )
     model.seqlen = 2048

     dataset = "wikitext2"
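
The helpers above implement a lazy weight-streaming scheme for GPTQ quantization: the model skeleton lives on the meta device, a forward pre-hook materializes a module's tensors from the safetensors files right before it runs, and a forward hook can push them back to CPU afterwards. A self-contained toy version of the hook mechanics, with the Weights loader stubbed out by a plain dict (everything here is illustrative, not the TGI API):

import torch
from torch import nn

stored = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}  # stand-in for Weights

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Shape-only parameters: nothing is allocated until the pre-hook runs.
        self.linear = nn.Linear(4, 4, device="meta")

def load_pre_hook(module, args):
    # Materialize real tensors just before the forward pass.
    for name, tensor in stored.items():
        child, leaf = name.split(".")
        setattr(getattr(module, child), leaf, nn.Parameter(tensor))

def unload_post_hook(module, args, output):
    # The real code moves tensors back to CPU here to bound GPU memory.
    return output

m = Wrapper()
m.register_forward_pre_hook(load_pre_hook)
m.register_forward_hook(unload_post_hook)
print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 4])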
@@ -806,6 +922,7 @@ def quantize(
         groupsize,
         percdamp=percdamp,
         act_order=act_order,
+        hooks=hooks,
     )
     print(time.time() - tick)

@@ -858,7 +975,6 @@ def quantize(
     logger.info("Saved tokenizer")

     if upload_to_model_id:
-
         api = HfApi()

         api.upload_folder(
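
With the new `revision` argument and the meta-device path above, quantize() no longer needs to fit the full fp16 model in memory up front. A hedged sketch of that loading pattern using standard transformers/accelerate calls (the model id is only an example):

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("huggyllama/llama-7b")   # any causal LM id works here
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model = model.eval()

print(next(model.parameters()).device)  # meta: weights are streamed in later by the hooks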