Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-08-01 04:40:17 +00:00

Commit faa5b52fdc: Merge branch 'main' into gptq-cuda-kernels

.github/workflows/load_test.yaml (vendored, 12 lines changed)

@@ -15,11 +15,11 @@ jobs:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      AWS_REGION: eu-central-1
+      EC2_AMI_ID: ami-0ab09c07cfd194259
       EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
-      EC2_SECURITY_GROUP: sg-04d472c808f365022
+      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
+      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
     outputs:
       label: ${{ steps.start-ec2-runner.outputs.label }}
       ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}

@@ -90,7 +90,7 @@ jobs:
       - load-tests
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
+      AWS_REGION: eu-central-1
     if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials

@@ -105,4 +105,4 @@ jobs:
           mode: stop
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
           label: ${{ needs.start-runner.outputs.label }}
           ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

Cargo.lock (generated, 8 lines changed)

@@ -2848,7 +2848,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "average",
  "clap",

@@ -2868,7 +2868,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "futures",
  "grpc-metadata",

@@ -2884,7 +2884,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "clap",
  "ctrlc",

@@ -2900,7 +2900,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "async-stream",
  "axum",

@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.9.0"
+version = "0.9.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

@@ -10,7 +10,7 @@
             "name": "Apache 2.0",
             "url": "https://www.apache.org/licenses/LICENSE-2.0"
         },
-        "version": "0.9.0"
+        "version": "0.9.1"
     },
     "paths": {
         "/": {

@@ -197,6 +197,10 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
 
+    /// The IP address to listen on
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+
     /// The port to listen on.
     #[clap(default_value = "3000", long, short, env)]
     port: u16,

@@ -874,6 +878,8 @@ fn spawn_webserver(
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         args.max_waiting_tokens.to_string(),
+        "--hostname".to_string(),
+        args.hostname.to_string(),
         "--port".to_string(),
         args.port.to_string(),
         "--master-shard-uds-path".to_string(),

@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;

@@ -40,6 +41,8 @@ struct Args {
     max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]

@@ -68,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }
 
-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration

@@ -82,6 +85,7 @@ fn main() -> Result<(), std::io::Error> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        hostname,
         port,
         master_shard_uds_path,
         tokenizer_name,

@@ -146,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 

@@ -189,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
 
             // Warmup model
             tracing::info!("Warming up model");

@@ -210,11 +210,16 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");
 
-            // Binds on localhost
-            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
+            let addr = match hostname.parse() {
+                Ok(ip) => SocketAddr::new(ip, port),
+                Err(_) => {
+                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
+                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
+                }
+            };
 
             // Run server
             server::run(

@@ -241,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }

@@ -323,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}

@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(

@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {

@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }
 
 /// Shutdown signal handler

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.9.0"
+version = "0.9.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 

@@ -14,7 +14,7 @@ def test_convert_files():
     local_st_files = [
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
     ]
-    convert_files(local_pt_files, local_st_files)
+    convert_files(local_pt_files, local_st_files, discard_names=[])
 
     found_st_files = weight_files(model_id)
 

@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch

@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
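
The two Cache hunks above pair batch deletion with a CUDA cache flush. Dropping the last Python reference only returns the tensors to PyTorch's caching allocator; torch.cuda.empty_cache() additionally releases the unused cached blocks back to the driver, so the freed VRAM becomes visible to other allocations and to monitoring tools. A minimal sketch of the pattern (a simplified stand-in, not the Cache class from this repository):

    import torch


    class BatchCache:
        """In-memory batch store that releases GPU memory when an entry is dropped."""

        def __init__(self):
            self.cache = {}

        def set(self, batch_id, batch):
            self.cache[batch_id] = batch

        def delete(self, batch_id):
            # Drop the only reference so the batch tensors become collectable ...
            batch = self.cache.pop(batch_id, None)
            if batch is not None:
                del batch
                # ... then hand the now-unused cached blocks back to the driver.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()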

@@ -160,8 +160,26 @@ def download_weights(
             p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
             for p in local_pt_files
         ]
+        try:
+            from transformers import AutoConfig
+            import transformers
+
+            config = AutoConfig.from_pretrained(
+                model_id,
+                revision=revision,
+            )
+            architecture = config.architectures[0]
+
+            class_ = getattr(transformers, architecture)
+
+            # Name for this varible depends on transformers version.
+            discard_names = getattr(class_, "_tied_weights_keys", [])
+            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
+
+        except Exception as e:
+            discard_names = []
         # Convert pytorch weights to safetensors
-        utils.convert_files(local_pt_files, local_st_files)
+        utils.convert_files(local_pt_files, local_st_files, discard_names)
 
 
 @app.command()

@@ -1,3 +1,4 @@
+import torch
 import grpc
 
 from google.rpc import status_pb2, code_pb2

@@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
 
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))

@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0
 
         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,

@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.softmax_scale = self.head_size**-0.5
 
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,

@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(

@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(

@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size

@@ -72,8 +73,17 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
 
         qweight = qweight.to(weights.device)
         qzeros = qzeros.to(weights.device)

@@ -102,7 +112,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()

@@ -211,7 +220,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.softmax_scale = self.head_size ** (-0.5)

@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias

@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())
 
-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )

@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder
 
         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()
 
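
The hunks above for BloomAttention, FlashLlamaAttention, FlashNeoxAttention, FlashRWAttention, FlashMQAttention, MultiheadAttention, GPTNeoXAttention and OPTAttention (and T5Attention below) all make the same substitution: the bare assert on the head count becomes a ValueError that spells out why sharding failed, since each tensor-parallel rank keeps num_heads // num_shards heads and an uneven split would leave heads unassigned. A standalone sketch of that guard (world_size stands in for weights.process_group.size(); this helper is illustrative, not part of the repository):

    def shard_heads(num_heads: int, world_size: int) -> int:
        """Return the per-rank head count, refusing uneven tensor-parallel splits."""
        if num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` "
                f"(got `num_heads`: {num_heads} and `num_shards`: {world_size})"
            )
        return num_heads // world_size


    # 32 heads across 4 shards -> 8 heads per GPU; 32 across 5 shards raises.
    assert shard_heads(32, 4) == 8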

@@ -19,6 +19,8 @@ import math
 import warnings
 from typing import Optional, Tuple, Union
 
+from loguru import logger
+
 import torch
 import torch.distributed
 from torch import nn

@@ -246,6 +248,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()
 

@@ -1001,12 +1008,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         super().__init__(config)
         self.model_dim = config.d_model
 
-        try:
-            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
-        except RuntimeError:
-            self.shared = TensorParallelEmbedding(
-                prefix="encoder.embed_tokens", weights=weights
-            )
+        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False

@@ -638,6 +638,7 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,

@@ -725,12 +726,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
         del batch
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:

@@ -775,16 +775,20 @@ class FlashCausalLM(Model):
         # Allocate blocks to this batch
         CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            raise e
 
         if prefill:
             next_token_logits = (

@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
 
         config.quantize = quantize
 

@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={
+                "shared.weight": [
+                    "encoder.embed_tokens.weight",
+                    "decoder.embed_tokens.weight",
+                ]
+            },
         )
 
         model = T5ForConditionalGeneration(config, weights)
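
The FlashRWSharded and T5Sharded hunks above pass an aliases table to Weights so that tied parameters stored under one name can also be looked up under another: a checkpoint that only stores transformer.word_embeddings.weight can still answer a request for lm_head.weight, and shared.weight can stand in for the encoder and decoder embedding tables. A simplified sketch of how such an alias table can be consulted during name resolution (hypothetical helper; the Weights routing code itself is not part of this diff):

    from typing import Dict, List, Set


    def resolve_name(name: str, stored: Set[str], aliases: Dict[str, List[str]]) -> str:
        """Map a requested tensor name to the name actually present in the checkpoint."""
        if name in stored:
            return name
        for stored_name, alias_names in aliases.items():
            if name in alias_names and stored_name in stored:
                return stored_name
        raise KeyError(f"weight {name} not found and no alias covers it")


    stored = {"transformer.word_embeddings.weight"}
    aliases = {"transformer.word_embeddings.weight": ["lm_head.weight"]}
    print(resolve_name("lm_head.weight", stored, aliases))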

@@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):

@@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         self.model.warmup(batch, request.max_total_tokens)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.WarmupResponse()
 
     async def Prefill(self, request, context):

@@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             batch = batches[0]
 

@@ -4,11 +4,56 @@ import os
 
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file, _remove_duplicate_names, load_file
-from typing import List
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict
 
 
-def convert_file(pt_file: Path, sf_file: Path):
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: List[str] = None,
+    discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])]
+        )
+        if not complete_names:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mecanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
     """
     Convert a pytorch file to a safetensors file
     This will remove duplicate tensors from the file.

@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
     loaded = torch.load(pt_file, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
-    to_removes = _remove_duplicate_names(loaded)
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
 
     metadata = {"format": "pt"}
     for kept_name, to_remove_group in to_removes.items():

@@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
-def convert_files(pt_files: List[Path], sf_files: List[Path]):
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
     assert len(pt_files) == len(sf_files)
 
     N = len(pt_files)

@@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
 
     for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
         start = datetime.datetime.now()
-        convert_file(pt_file, sf_file)
+        convert_file(pt_file, sf_file, discard_names)
         elapsed = datetime.datetime.now() - start
         logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
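
The vendored _remove_duplicate_names above groups tensors that share storage (tied weights, such as an LM head reusing the embedding matrix), keeps one complete copy per group, preferring a name that is not in discard_names, and marks the remaining names for removal before the safetensors file is written. A much simplified sketch of the same idea, grouping by storage pointer instead of using the safetensors helpers (the function names here are illustrative, not the repository's API):

    from collections import defaultdict
    from typing import Dict, List, Set

    import torch


    def find_tied_groups(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
        """Group parameter names whose tensors point at the same storage."""
        by_ptr = defaultdict(set)
        for name, tensor in state_dict.items():
            by_ptr[tensor.data_ptr()].add(name)
        return [names for names in by_ptr.values() if len(names) > 1]


    def names_to_drop(state_dict: Dict[str, torch.Tensor], discard_names: List[str]) -> List[str]:
        """For each tied group keep one name, preferably one not marked for discard."""
        drop = []
        for group in find_tied_groups(state_dict):
            keep = (sorted(group - set(discard_names)) or sorted(group))[0]
            drop.extend(name for name in sorted(group) if name != keep)
        return drop


    embedding = torch.zeros(10, 4)
    tied = {"embed.weight": embedding, "lm_head.weight": embedding}
    print(names_to_drop(tied, discard_names=["lm_head.weight"]))  # ['lm_head.weight']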

@@ -178,13 +178,25 @@ class SuperLayer(nn.Module):
 
 
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather
 
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False
 
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":

@@ -194,13 +206,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)
 
+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
 

@@ -281,7 +294,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
 
         process_group = weights.process_group
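
The TensorParallelHead changes above decide at load time whether the LM head is really sharded: with a single shard, or when the vocabulary dimension does not divide evenly across shards (get_sharded asserts divisibility, as the Weights hunks that follow show), every rank loads the full weight and forward() skips the logits gather. A condensed sketch of that load-time decision, reusing the Weights accessors that appear in this diff (the helper name is illustrative):

    def load_lm_head_weight(weights, prefix: str):
        """Return (weight, should_gather) for a tensor-parallel LM head."""
        if weights.process_group.size() == 1:
            # Single shard: the local output already covers the full vocabulary.
            return weights.get_tensor(f"{prefix}.weight"), False
        try:
            # Vocab divides evenly: keep this rank's slice and gather logits later.
            return weights.get_sharded(f"{prefix}.weight", dim=0), True
        except AssertionError:
            # Uneven vocab size: load the whole head everywhere and skip the gather.
            return weights.get_tensor(f"{prefix}.weight"), False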

@@ -1,8 +1,9 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch
 
+
 class Weights:
     def __init__(
         self,

@@ -68,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()

@@ -80,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:

@@ -97,29 +94,57 @@
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
 
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_triton_kernel = False
             if self.process_group.size() > 1:

@@ -155,8 +180,17 @@ class Weights:
             else:
                 g_idx = None
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
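
Both GPTQ paths above now read the quantization metadata from the checkpoint first and fall back to the GPTQ_BITS / GPTQ_GROUPSIZE environment variables only when the gptq_bits and gptq_groupsize tensors are missing (older quantized checkpoints), re-raising the original SafetensorError if the variables are absent as well. The fallback in isolation (get_tensor stands in for the Weights accessor; this helper is illustrative, not the repository's API):

    import os


    def gptq_params(get_tensor):
        """Read GPTQ bits/groupsize from the checkpoint, else from GPTQ_* env vars."""
        try:
            bits = get_tensor("gptq_bits").item()
            groupsize = get_tensor("gptq_groupsize").item()
        except Exception as original_error:
            try:
                bits = int(os.environ["GPTQ_BITS"])
                groupsize = int(os.environ["GPTQ_GROUPSIZE"])
            except KeyError:
                # Neither the checkpoint nor the environment defines the parameters.
                raise original_error
        return int(bits), int(groupsize)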