Merge branch 'main' into gptq-cuda-kernels

Félix Marty 2023-07-12 18:47:30 +02:00
commit faa5b52fdc
28 changed files with 311 additions and 87 deletions


@@ -15,11 +15,11 @@ jobs:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      AWS_REGION: eu-central-1
+      EC2_AMI_ID: ami-0ab09c07cfd194259
       EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
-      EC2_SECURITY_GROUP: sg-04d472c808f365022
+      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
+      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
     outputs:
       label: ${{ steps.start-ec2-runner.outputs.label }}
       ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
@@ -90,7 +90,7 @@ jobs:
       - load-tests
     runs-on: ubuntu-latest
     env:
-      AWS_REGION: us-east-1
+      AWS_REGION: eu-central-1
     if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials
@@ -105,4 +105,4 @@ jobs:
           mode: stop
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
           label: ${{ needs.start-runner.outputs.label }}
           ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

Cargo.lock

@@ -2848,7 +2848,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "average",
  "clap",
@@ -2868,7 +2868,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2884,7 +2884,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2900,7 +2900,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "0.9.0"
+version = "0.9.1"
 dependencies = [
  "async-stream",
  "axum",


@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.9.0"
+version = "0.9.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"


@@ -10,7 +10,7 @@
             "name": "Apache 2.0",
             "url": "https://www.apache.org/licenses/LICENSE-2.0"
         },
-        "version": "0.9.0"
+        "version": "0.9.1"
     },
     "paths": {
         "/": {


@@ -197,6 +197,10 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
 
+    /// The IP address to listen on
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+
     /// The port to listen on.
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
@@ -874,6 +878,8 @@ fn spawn_webserver(
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         args.max_waiting_tokens.to_string(),
+        "--hostname".to_string(),
+        args.hostname.to_string(),
         "--port".to_string(),
         args.port.to_string(),
         "--master-shard-uds-path".to_string(),


@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -40,6 +41,8 @@ struct Args {
     max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
@@ -68,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }
 
-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -82,6 +85,7 @@ fn main() -> Result<(), std::io::Error> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        hostname,
         port,
         master_shard_uds_path,
         tokenizer_name,
@@ -146,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
@@ -189,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
 
             // Warmup model
             tracing::info!("Warming up model");
@@ -210,11 +210,16 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");
 
-            // Binds on localhost
-            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
+            let addr = match hostname.parse() {
+                Ok(ip) => SocketAddr::new(ip, port),
+                Err(_) => {
+                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
+                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
+                }
+            };
 
             // Run server
             server::run(
@@ -241,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }
@@ -323,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}


@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+
+    Ok(())
 }
 
 /// Shutdown signal handler


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.9.0"
+version = "0.9.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]


@@ -14,7 +14,7 @@ def test_convert_files():
     local_st_files = [
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
     ]
-    convert_files(local_pt_files, local_st_files)
+    convert_files(local_pt_files, local_st_files, discard_names=[])
 
     found_st_files = weight_files(model_id)


@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
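The `empty_cache()` call above matters because deleting the Python reference to a batch only returns its tensors to PyTorch's caching allocator; nothing goes back to the CUDA driver until `torch.cuda.empty_cache()` runs. A minimal standalone sketch of that behaviour (not part of this commit; assumes a CUDA device is present):

import torch

if torch.cuda.is_available():
    x = torch.empty(256, 1024, 1024, device="cuda")  # roughly 1 GiB of fp32 scratch data
    print(torch.cuda.memory_reserved())              # allocator now holds the backing blocks
    del x                                            # drops the Python reference only
    print(torch.cuda.memory_reserved())              # still reserved by the caching allocator
    torch.cuda.empty_cache()                         # hand the cached blocks back to the driver
    print(torch.cuda.memory_reserved())              # reserved memory shrinks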


@@ -160,8 +160,26 @@ def download_weights(
             p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
             for p in local_pt_files
         ]
+        try:
+            from transformers import AutoConfig
+            import transformers
+
+            config = AutoConfig.from_pretrained(
+                model_id,
+                revision=revision,
+            )
+            architecture = config.architectures[0]
+
+            class_ = getattr(transformers, architecture)
+
+            # Name for this varible depends on transformers version.
+            discard_names = getattr(class_, "_tied_weights_keys", [])
+            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
+        except Exception as e:
+            discard_names = []
+
         # Convert pytorch weights to safetensors
-        utils.convert_files(local_pt_files, local_st_files)
+        utils.convert_files(local_pt_files, local_st_files, discard_names)
 
 
 @app.command()
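For a concrete sense of what the new `discard_names` lookup produces, here is a hedged sketch using an example model id (not one this diff names); which attribute exists, `_tied_weights_keys` or `_keys_to_ignore_on_load_missing`, depends on the installed transformers version, as the comment in the hunk notes:

from transformers import AutoConfig
import transformers

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
architecture = config.architectures[0]                         # e.g. "BloomForCausalLM"
class_ = getattr(transformers, architecture)

discard_names = list(getattr(class_, "_tied_weights_keys", []) or [])
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []) or [])
print(discard_names)  # for tied-embedding models this typically includes "lm_head.weight"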


@@ -1,3 +1,4 @@
+import torch
 import grpc
 
 from google.rpc import status_pb2, code_pb2
@@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
 
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))


@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0
 
         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,


@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.softmax_scale = self.head_size**-0.5
 
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,


@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(


@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(


@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError
 
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
@@ -72,8 +73,17 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
 
         qweight = qweight.to(weights.device)
         qzeros = qzeros.to(weights.device)
@@ -102,7 +112,6 @@ def _load_multi_mqa_gptq(
 
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -211,7 +220,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.softmax_scale = self.head_size ** (-0.5)
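The fallback added above means a GPTQ checkpoint that lacks the `gptq_bits` / `gptq_groupsize` tensors can still be loaded if the values are supplied through the environment. A minimal sketch; the values 4 and 128 are assumptions for a typical 4-bit, group-size-128 quantization, not something this diff prescribes:

import os

# Set before the shard process loads the weights; the except-branch above
# reads these variables when the metadata tensors are missing from the checkpoint.
os.environ["GPTQ_BITS"] = "4"
os.environ["GPTQ_GROUPSIZE"] = "128"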


@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias


@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())
 
-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )

View File

@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder
 
         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()


@@ -19,6 +19,8 @@ import math
 import warnings
 from typing import Optional, Tuple, Union
 
+from loguru import logger
+
 import torch
 import torch.distributed
 from torch import nn
@@ -246,6 +248,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()
@@ -1001,12 +1008,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         super().__init__(config)
         self.model_dim = config.d_model
 
-        try:
-            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
-        except RuntimeError:
-            self.shared = TensorParallelEmbedding(
-                prefix="encoder.embed_tokens", weights=weights
-            )
+        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False


@@ -638,6 +638,7 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -725,12 +726,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
 
         del batch
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
@@ -775,16 +775,20 @@ class FlashCausalLM(Model):
         # Allocate blocks to this batch
         CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            raise e
 
         if prefill:
             next_token_logits = (


@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
 
         config.quantize = quantize


@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={
+                "shared.weight": [
+                    "encoder.embed_tokens.weight",
+                    "decoder.embed_tokens.weight",
+                ]
+            },
         )
 
         model = T5ForConditionalGeneration(config, weights)
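The `aliases` mapping above lets the loader fall back to an equivalent tensor name when the canonical one is absent from the safetensors files (here, T5's tied `shared.weight`). A hypothetical, self-contained sketch of that kind of resolution, not the actual `Weights` implementation:

from typing import Dict, Iterable, List

aliases: Dict[str, List[str]] = {
    "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"],
}

def resolve(name: str, available: Iterable[str]) -> str:
    """Return `name` if present, otherwise the first known alias that is."""
    available = set(available)
    if name in available:
        return name
    for alias in aliases.get(name, []):
        if alias in available:
            return alias
    raise KeyError(f"{name} not found (aliases tried: {aliases.get(name, [])})")

print(resolve("shared.weight", ["encoder.embed_tokens.weight"]))  # encoder.embed_tokens.weight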


@@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
@@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         self.model.warmup(batch, request.max_total_tokens)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.WarmupResponse()
 
     async def Prefill(self, request, context):
@@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             batch = batches[0]


@@ -4,11 +4,56 @@ import os
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file, _remove_duplicate_names, load_file
-from typing import List
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict
 
 
-def convert_file(pt_file: Path, sf_file: Path):
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: List[str] = None,
+    discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])]
+        )
+        if not complete_names:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mecanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
     """
     Convert a pytorch file to a safetensors file
     This will remove duplicate tensors from the file.
@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
     loaded = torch.load(pt_file, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
-    to_removes = _remove_duplicate_names(loaded)
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
 
     metadata = {"format": "pt"}
     for kept_name, to_remove_group in to_removes.items():
@@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
-def convert_files(pt_files: List[Path], sf_files: List[Path]):
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
     assert len(pt_files) == len(sf_files)
 
     N = len(pt_files)
@@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
     for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
         start = datetime.datetime.now()
-        convert_file(pt_file, sf_file)
+        convert_file(pt_file, sf_file, discard_names)
         elapsed = datetime.datetime.now() - start
         logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
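To see what the reimplemented `_remove_duplicate_names` decides for a model with tied embeddings, here is a small sketch with illustrative tensor names; it assumes the helper above is importable from the module it is defined in:

import torch
from text_generation_server.utils.convert import _remove_duplicate_names  # helper shown above

wte = torch.randn(10, 4)
state_dict = {
    "transformer.wte.weight": wte,  # embedding matrix
    "lm_head.weight": wte,          # tied head: same underlying storage
}

# Telling the converter to discard the head keeps the embedding copy on disk.
to_removes = _remove_duplicate_names(state_dict, discard_names=["lm_head.weight"])
print(dict(to_removes))  # {'transformer.wte.weight': ['lm_head.weight']}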


@@ -178,13 +178,25 @@ class SuperLayer(nn.Module):
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather
 
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False
 
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -194,13 +206,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)
 
+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
@@ -281,7 +294,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
 
         process_group = weights.process_group
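The effect of the new `should_gather` flag is easiest to see with concrete numbers. A sketch of the decision (not the real `Weights` API), assuming a GPT-2-style vocabulary of 50257 rows:

def should_gather_for(vocab_size: int, world_size: int) -> bool:
    # Mirrors TensorParallelHead.load above: shard the head (and all-gather the
    # logits later) only when every rank can hold an equal slice of the vocab dim.
    if world_size == 1:
        return False                     # single shard: plain matmul, nothing to gather
    return vocab_size % world_size == 0  # otherwise the full head is loaded on every rank

print(should_gather_for(50257, 4))  # False: 50257 rows do not split evenly over 4 shards
print(should_gather_for(50264, 4))  # True: each shard holds 12566 rows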


@@ -1,8 +1,9 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch
 
+
 class Weights:
     def __init__(
         self,
@@ -68,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -80,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -97,29 +94,57 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
 
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
 
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_triton_kernel = False
             if self.process_group.size() > 1:
@@ -155,8 +180,17 @@ class Weights:
             else:
                 g_idx = None
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else: