Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-07-31 04:10:16 +00:00

Merge branch 'main' into gptq-cuda-kernels
Commit faa5b52fdc
.github/workflows/load_test.yaml (vendored): 12 changed lines
@@ -15,11 +15,11 @@ jobs:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      AWS_REGION: eu-central-1
+      EC2_AMI_ID: ami-0ab09c07cfd194259
       EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
-      EC2_SECURITY_GROUP: sg-04d472c808f365022
+      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
+      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
     outputs:
       label: ${{ steps.start-ec2-runner.outputs.label }}
       ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
@@ -90,7 +90,7 @@ jobs:
       - load-tests
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
+      AWS_REGION: eu-central-1
     if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials
@@ -105,4 +105,4 @@ jobs:
           mode: stop
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
           label: ${{ needs.start-runner.outputs.label }}
-          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
Cargo.lock (generated): 8 changed lines
@@ -2848,7 +2848,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "average",
  "clap",
@@ -2868,7 +2868,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2884,7 +2884,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2900,7 +2900,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "async-stream",
  "axum",
@@ -8,7 +8,7 @@ members = [
 ]

 [workspace.package]
-version = "0.9.0"
+version = "0.9.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.9.0"
+    "version": "0.9.1"
   },
   "paths": {
     "/": {
@@ -197,6 +197,10 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,

+    /// The IP address to listen on
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+
     /// The port to listen on.
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
@@ -874,6 +878,8 @@ fn spawn_webserver(
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         args.max_waiting_tokens.to_string(),
+        "--hostname".to_string(),
+        args.hostname.to_string(),
         "--port".to_string(),
         args.port.to_string(),
         "--master-shard-uds-path".to_string(),
@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -40,6 +41,8 @@ struct Args {
     max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
@@ -68,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }

-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -82,6 +85,7 @@ fn main() -> Result<(), std::io::Error> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        hostname,
         port,
         master_shard_uds_path,
         tokenizer_name,
@@ -146,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

@@ -189,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
            // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;

            // Warmup model
            tracing::info!("Warming up model");
@@ -210,11 +210,16 @@ fn main() -> Result<(), std::io::Error> {
                    max_batch_total_tokens,
                )
                .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
            tracing::info!("Connected");

            // Binds on localhost
-            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
+            let addr = match hostname.parse() {
+                Ok(ip) => SocketAddr::new(ip, port),
+                Err(_) => {
+                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
+                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
+                }
+            };

            // Run server
            server::run(
@@ -241,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                ngrok_username,
                ngrok_password,
            )
-            .await;
+            .await?;
            Ok(())
        })
 }
@@ -323,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}
@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
            .serve(app.into_make_service())
            //Wait until all requests are finished to shut down
            .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
            .serve(app.into_make_service())
            // Wait until all requests are finished to shut down
            .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }

 /// Shutdown signal handler
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.9.0"
+version = "0.9.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -14,7 +14,7 @@ def test_convert_files():
     local_st_files = [
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
     ]
-    convert_files(local_pt_files, local_st_files)
+    convert_files(local_pt_files, local_st_files, discard_names=[])

     found_st_files = weight_files(model_id)

@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar

 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     def clear(self):
         keys = list(self.cache.keys())
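The cache now drops the batch reference first and only then calls torch.cuda.empty_cache(), so the caching allocator can return the freed blocks to the driver. A minimal sketch of the effect, assuming a CUDA device is present; the large tensor here is only a stand-in for a cached batch:

    import torch

    def report(tag: str) -> None:
        # allocated = memory backing live tensors; reserved = blocks held by the caching allocator
        print(
            f"{tag}: allocated={torch.cuda.memory_allocated() / 1e6:.1f} MB, "
            f"reserved={torch.cuda.memory_reserved() / 1e6:.1f} MB"
        )

    if torch.cuda.is_available():
        report("start")
        batch_like = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB stand-in for a cached batch
        report("after alloc")
        del batch_like                    # drops the last reference, lowers "allocated"
        report("after del")
        torch.cuda.empty_cache()          # returns cached blocks to the driver, lowers "reserved"
        report("after empty_cache")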
@@ -160,8 +160,26 @@ def download_weights(
            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
+        try:
+            from transformers import AutoConfig
+            import transformers
+
+            config = AutoConfig.from_pretrained(
+                model_id,
+                revision=revision,
+            )
+            architecture = config.architectures[0]
+
+            class_ = getattr(transformers, architecture)
+
+            # Name for this varible depends on transformers version.
+            discard_names = getattr(class_, "_tied_weights_keys", [])
+            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
+
+        except Exception as e:
+            discard_names = []
        # Convert pytorch weights to safetensors
-        utils.convert_files(local_pt_files, local_st_files)
+        utils.convert_files(local_pt_files, local_st_files, discard_names)


 @app.command()
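download_weights now looks up the model class in transformers to find which tied-weight keys can safely be discarded before the safetensors conversion. A small standalone sketch of that lookup, assuming transformers is installed and network access to the Hub; the helper name and the "gpt2" example are illustrative:

    from typing import List, Optional

    from transformers import AutoConfig
    import transformers

    def tied_weight_names(model_id: str, revision: Optional[str] = None) -> List[str]:
        """Collect weight names tied to other tensors, i.e. safe to drop when saving."""
        config = AutoConfig.from_pretrained(model_id, revision=revision)
        class_ = getattr(transformers, config.architectures[0])
        # Attribute name depends on the transformers version, so check both.
        names = list(getattr(class_, "_tied_weights_keys", []) or [])
        names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []) or [])
        return names

    print(tied_weight_names("gpt2"))  # typically includes "lm_head.weight", which shares storage with the embedding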
@@ -1,3 +1,4 @@
+import torch
 import grpc

 from google.rpc import status_pb2, code_pb2
@@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
        self.beta = 1.0

        process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
        self.num_heads = self.num_heads // process_group.size()
        self.query_key_value = TensorParallelColumnLinear.load(
            config=config,
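This hunk, and the matching ones in the Llama, NeoX, RW, SantaCoder, MPT, OPT and T5 attention modules below, replaces a bare assert with an explicit ValueError when the head count cannot be split evenly across tensor-parallel shards. A minimal sketch of the pattern in isolation; the helper name and world_size argument are illustrative, not part of the repository:

    def split_heads(num_heads: int, world_size: int) -> int:
        """Return the per-shard head count, refusing uneven splits with an actionable error."""
        if num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` "
                f"(got `num_heads`: {num_heads} and `num_shards`: {world_size})"
            )
        return num_heads // world_size

    print(split_heads(32, 4))   # 8 heads per shard
    print(split_heads(32, 2))   # 16 heads per shard
    # split_heads(30, 4) raises ValueError instead of failing later with a shape mismatch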
@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):

        self.softmax_scale = self.head_size**-0.5

+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.query_key_value = TensorParallelColumnLinear.load_multi(
            config,
@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_heads = self.num_heads // weights.process_group.size()

        self.rotary_emb = PositionRotaryEmbedding.load(
@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
            dim=self.head_size, base=10000.0, device=weights.device
        )
        self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_heads = self.num_heads // weights.process_group.size()

        self.query_key_value = TensorParallelColumnLinear.load(
@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
    FastLayerNorm,
    get_linear,
 )
+from safetensors import SafetensorError

 def load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
@@ -72,8 +73,17 @@ def _load_multi_mqa_gptq(
        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)

        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e

        qweight = qweight.to(weights.device)
        qzeros = qzeros.to(weights.device)
@@ -102,7 +112,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
    if any("c_attn" in k for k in weights.routing.keys()):
        slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
        shape = slice_.get_shape()
@@ -211,7 +220,11 @@ class FlashMQAttention(torch.nn.Module):
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_heads = self.num_heads // weights.process_group.size()

        self.softmax_scale = self.head_size ** (-0.5)
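When a GPTQ checkpoint does not carry the gptq_bits / gptq_groupsize tensors, the loader now falls back to the GPTQ_BITS and GPTQ_GROUPSIZE environment variables and only re-raises the original error if those are missing too. A minimal standalone sketch of that fallback order; the read_metadata_tensor callable stands in for the safetensors lookup and is hypothetical:

    import os
    from typing import Callable, Tuple

    def gptq_params(read_metadata_tensor: Callable[[str], int]) -> Tuple[int, int]:
        """Prefer metadata stored in the checkpoint, fall back to environment variables."""
        try:
            return read_metadata_tensor("gptq_bits"), read_metadata_tensor("gptq_groupsize")
        except KeyError as err:  # stand-in for safetensors.SafetensorError
            try:
                return int(os.environ["GPTQ_BITS"]), int(os.environ["GPTQ_GROUPSIZE"])
            except (KeyError, ValueError):
                raise err

    def missing(name: str) -> int:
        raise KeyError(name)  # simulates a checkpoint with no gptq_* tensors

    # Checkpoint without metadata, values supplied via the environment instead:
    os.environ.update(GPTQ_BITS="4", GPTQ_GROUPSIZE="128")
    print(gptq_params(missing))  # -> (4, 128)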
@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.n_heads = self.n_heads // weights.process_group.size()
        self.Wqkv = load_col(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())

-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_attention_heads = (
            self.num_attention_heads // weights.process_group.size()
        )
@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
        self.is_decoder = is_decoder

        process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.num_heads = self.num_heads // process_group.size()
        self.embed_dim = self.embed_dim // process_group.size()

@@ -19,6 +19,8 @@ import math
 import warnings
 from typing import Optional, Tuple, Union

+from loguru import logger
+
 import torch
 import torch.distributed
 from torch import nn
@@ -246,6 +248,11 @@ class T5Attention(nn.Module):
        self.o = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o", weights=weights, bias=False
        )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
        self.n_heads = self.n_heads // process_group.size()
        self.inner_dim = self.inner_dim // process_group.size()

@@ -1001,12 +1008,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
        super().__init__(config)
        self.model_dim = config.d_model

-        try:
-            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
-        except RuntimeError:
-            self.shared = TensorParallelEmbedding(
-                prefix="encoder.embed_tokens", weights=weights
-            )
+        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
@@ -638,6 +638,7 @@ class FlashCausalLMBatch(Batch):
        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
+            del b

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
@@ -725,12 +726,11 @@ class FlashCausalLM(Model):
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                f"prefill tokens. "
                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
        del batch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
@@ -775,16 +775,20 @@ class FlashCausalLM(Model):
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            raise e

        if prefill:
            next_token_logits = (
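The warmup path now turns an out-of-memory failure on the largest possible batch into a RuntimeError that tells the operator which launcher flags to lower, chaining the original exception with "from e"; generate_token similarly drops the batch before re-raising so its cache blocks are released. A minimal sketch of the exception-chaining pattern; the forward callable and token counts here are placeholders, not the repository classes:

    def warmup(forward, max_total_tokens: int) -> None:
        """Try the largest batch once; surface an actionable error instead of a bare OOM."""
        try:
            forward(max_total_tokens)
        except Exception as e:
            raise RuntimeError(
                f"Not enough memory to handle {max_total_tokens} total tokens. "
                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
            ) from e  # keeps the original traceback attached for debugging

    def oom_forward(n_tokens: int) -> None:
        # Placeholder forward pass that always "runs out of memory".
        raise MemoryError("CUDA out of memory (simulated)")

    try:
        warmup(oom_forward, max_total_tokens=16000)
    except RuntimeError as err:
        print(err)
        print("caused by:", repr(err.__cause__))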
@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )

        config.quantize = quantize

@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={
+                "shared.weight": [
+                    "encoder.embed_tokens.weight",
+                    "decoder.embed_tokens.weight",
+                ]
+            },
        )

        model = T5ForConditionalGeneration(config, weights)
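The aliases map lets one stored tensor answer for several parameter names, which is what makes the earlier try/except around the T5 shared embedding unnecessary: after the safetensors conversion drops tied duplicates, "shared.weight" may only exist on disk as "encoder.embed_tokens.weight". A minimal sketch of alias-aware name resolution; the routing dict and helper are illustrative, not the Weights API:

    from typing import Dict, List

    def resolve(name: str, routing: Dict[str, str], aliases: Dict[str, List[str]]) -> str:
        """Return the on-disk tensor name for `name`, trying registered aliases if needed."""
        if name in routing:
            return name
        for alias in aliases.get(name, []):
            if alias in routing:
                return alias
        raise KeyError(f"weight {name} does not exist")

    routing = {"encoder.embed_tokens.weight": "model-00001.safetensors"}  # what survived conversion
    aliases = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}

    print(resolve("shared.weight", routing, aliases))  # -> encoder.embed_tokens.weight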
@@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
@@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )
        self.model.warmup(batch, request.max_total_tokens)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
        return generate_pb2.WarmupResponse()

    async def Prefill(self, request, context):
@@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

        if len(batches) > 1:
            batch = self.model.batch_type.concatenate(batches)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
        else:
            batch = batches[0]

@@ -4,11 +4,56 @@ import os

 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file, _remove_duplicate_names, load_file
-from typing import List
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict


-def convert_file(pt_file: Path, sf_file: Path):
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: List[str] = None,
+    discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])]
+        )
+        if not complete_names:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mecanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
     """
     Convert a pytorch file to a safetensors file
     This will remove duplicate tensors from the file.
@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
    loaded = torch.load(pt_file, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
-    to_removes = _remove_duplicate_names(loaded)
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
@@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
        raise RuntimeError(f"The output tensors do not match for key {k}")


-def convert_files(pt_files: List[Path], sf_files: List[Path]):
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
    assert len(pt_files) == len(sf_files)

    N = len(pt_files)
@@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):

    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
        start = datetime.datetime.now()
-        convert_file(pt_file, sf_file)
+        convert_file(pt_file, sf_file, discard_names)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
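Duplicate (shared-storage) tensors have to be dropped before safetensors will accept a state dict, and discard_names steers which copy gets dropped when several names cover the same storage. A simplified illustration on a toy state dict, using data_ptr() to detect sharing instead of the private safetensors helpers the real code relies on (so views of a larger tensor are not handled here):

    import torch
    from collections import defaultdict
    from typing import Dict, List

    def remove_shared_duplicates(
        state_dict: Dict[str, torch.Tensor], discard_names: List[str]
    ) -> Dict[str, List[str]]:
        """Group tensor names by storage and pick one name to keep per group."""
        by_storage = defaultdict(list)
        for name, tensor in state_dict.items():
            by_storage[tensor.data_ptr()].append(name)

        to_remove = {}
        for names in by_storage.values():
            if len(names) < 2:
                continue
            preferred = sorted(set(names) - set(discard_names)) or sorted(names)
            keep = preferred[0]
            to_remove[keep] = [n for n in names if n != keep]
        return to_remove

    embedding = torch.zeros(10, 4)
    state_dict = {"wte.weight": embedding, "lm_head.weight": embedding}  # tied weights
    print(remove_shared_duplicates(state_dict, discard_names=["lm_head.weight"]))
    # -> {'wte.weight': ['lm_head.weight']}: the tied head is dropped, the embedding is kept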
@@ -178,13 +178,25 @@ class SuperLayer(nn.Module):


 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
+        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

        # GPTQ doesn't quantize heads (nor embeddings)
        if config.quantize == "gptq":
@@ -194,13 +206,14 @@ class TensorParallelHead(SuperLayer):
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
+            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
            return super().forward(input)

+        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

@@ -281,7 +294,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group
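TensorParallelHead now records at load time whether the lm_head was actually sharded: if the vocabulary is not divisible by the number of shards, every rank loads the full tensor and forward skips the all-gather. A minimal sketch of that decision, using plain tensors instead of the Weights loader; the helper name is illustrative:

    import torch
    from typing import Tuple

    def load_head(weight: torch.Tensor, world_size: int, rank: int) -> Tuple[torch.Tensor, bool]:
        """Shard the vocab dimension when it divides evenly, otherwise keep the full weight."""
        vocab_size = weight.shape[0]
        if world_size > 1 and vocab_size % world_size == 0:
            block = vocab_size // world_size
            return weight[rank * block : (rank + 1) * block], True   # gather logits later
        return weight, False                                          # no gather needed

    full = torch.randn(32000, 128)
    shard, should_gather = load_head(full, world_size=4, rank=1)
    print(shard.shape, should_gather)   # torch.Size([8000, 128]) True

    odd = torch.randn(32003, 128)
    shard, should_gather = load_head(odd, world_size=4, rank=1)
    print(shard.shape, should_gather)   # torch.Size([32003, 128]) False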
@@ -1,8 +1,9 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch


 class Weights:
    def __init__(
        self,
@@ -68,7 +69,7 @@ class Weights:
        tensor = tensor.to(device=self.device)
        return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()
@@ -80,10 +81,6 @@ class Weights:
        start = rank * block_size
        stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
@@ -97,29 +94,57 @@ class Weights:
        tensor = tensor.to(device=self.device)
        return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "gptq":
            try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
            except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )

-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            use_triton_kernel = False
            if self.process_group.size() > 1:
@@ -155,8 +180,17 @@ class Weights:
            else:
                g_idx = None

-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
        else:
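Splitting get_sharded into get_partial_sharded plus a divisibility check lets TensorParallelEmbedding deliberately take an uneven slice (handling the remainder itself) while every other caller still gets the hard assertion. A small sketch of the two behaviours on a plain tensor; the function names mirror the diff but operate on an in-memory tensor rather than a safetensors slice:

    import torch

    def get_partial_sharded(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
        """Take this rank's block along `dim`; any remainder is left for the caller to handle."""
        block_size = tensor.shape[dim] // world_size
        start, stop = rank * block_size, (rank + 1) * block_size
        return tensor[start:stop] if dim == 0 else tensor[:, start:stop]

    def get_sharded(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
        """Same slice, but refuse sizes that do not split evenly across the shards."""
        size = tensor.shape[dim]
        assert size % world_size == 0, (
            f"The size {size} is not compatible with sharding on {world_size} shards"
        )
        return get_partial_sharded(tensor, dim, rank, world_size)

    vocab = torch.randn(32003, 128)  # 32003 is not divisible by 4
    print(get_partial_sharded(vocab, dim=0, rank=0, world_size=4).shape)  # torch.Size([8000, 128])
    # get_sharded(vocab, dim=0, rank=0, world_size=4) would raise AssertionError instead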