Merge branch 'main' into gptq-cuda-kernels

Félix Marty 2023-07-12 18:47:30 +02:00
commit faa5b52fdc
28 changed files with 311 additions and 87 deletions

View File

@ -15,11 +15,11 @@ jobs:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
AWS_REGION: eu-central-1
EC2_AMI_ID: ami-0ab09c07cfd194259
EC2_INSTANCE_TYPE: g5.12xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-04d472c808f365022
EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
EC2_SECURITY_GROUP: sg-072f92ae3082936c6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
@ -90,7 +90,7 @@ jobs:
- load-tests
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
AWS_REGION: eu-central-1
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
@ -105,4 +105,4 @@ jobs:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

Cargo.lock (generated)
View File

@ -2848,7 +2848,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "0.9.0"
version = "0.9.1"
dependencies = [
"average",
"clap",
@ -2868,7 +2868,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "0.9.0"
version = "0.9.1"
dependencies = [
"futures",
"grpc-metadata",
@ -2884,7 +2884,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "0.9.0"
version = "0.9.1"
dependencies = [
"clap",
"ctrlc",
@ -2900,7 +2900,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "0.9.0"
version = "0.9.1"
dependencies = [
"async-stream",
"axum",

View File

@ -8,7 +8,7 @@ members = [
]
[workspace.package]
version = "0.9.0"
version = "0.9.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.9.0"
"version": "0.9.1"
},
"paths": {
"/": {

View File

@ -197,6 +197,10 @@ struct Args {
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
/// The port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
@ -874,6 +878,8 @@ fn spawn_webserver(
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
args.max_waiting_tokens.to_string(),
"--hostname".to_string(),
args.hostname.to_string(),
"--port".to_string(),
args.port.to_string(),
"--master-shard-uds-path".to_string(),

View File

@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::{server, HubModelInfo};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
@ -40,6 +41,8 @@ struct Args {
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
@ -68,7 +71,7 @@ struct Args {
ngrok_password: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
@ -82,6 +85,7 @@ fn main() -> Result<(), std::io::Error> {
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
@ -146,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.build()?
.block_on(async {
init_logging(otlp_endpoint, json_output);
@ -189,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
.map_err(RouterError::Connection)?;
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
.map_err(RouterError::Cache)?;
// Get info from the shard
let shard_info = sharded_client
.info()
.await
.expect("Unable to get shard info");
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
// Warmup model
tracing::info!("Warming up model");
@ -210,11 +210,16 @@ fn main() -> Result<(), std::io::Error> {
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
.map_err(RouterError::Warmup)?;
tracing::info!("Connected");
// Binds on localhost
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Run server
server::run(
@ -241,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
ngrok_username,
ngrok_password,
)
.await;
.await?;
Ok(())
})
}
@ -323,3 +328,19 @@ pub async fn get_model_info(
}
None
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
#[error("Axum webserver failed: {0}")]
Axum(#[from] axum::BoxError),
}

View File

@ -527,7 +527,7 @@ pub async fn run(
ngrok_domain: Option<String>,
ngrok_username: Option<String>,
ngrok_password: Option<String>,
) {
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
@ -726,8 +726,7 @@ pub async fn run(
.serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
.await?;
}
#[cfg(not(feature = "ngrok"))]
{
@ -744,9 +743,9 @@ pub async fn run(
.serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
.await?;
}
Ok(())
}
/// Shutdown signal handler

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "0.9.0"
version = "0.9.1"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -14,7 +14,7 @@ def test_convert_files():
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
convert_files(local_pt_files, local_st_files, discard_names=[])
found_st_files = weight_files(model_id)

View File

@ -1,3 +1,5 @@
import torch
from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch
@ -20,6 +22,8 @@ class Cache:
batch = self.pop(batch_id)
if batch is not None:
del batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
def clear(self):
keys = list(self.cache.keys())
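Deleting the batch only drops the Python reference; PyTorch's caching allocator still holds the freed blocks. A minimal standalone sketch of the delete-then-empty_cache pattern used here (toy class, not the real Cache):

import torch

class BatchCacheSketch:
    def __init__(self):
        self.cache = {}

    def delete(self, batch_id):
        batch = self.cache.pop(batch_id, None)
        if batch is not None:
            # Drop the last reference so the batch's CUDA tensors become collectable ...
            del batch
            # ... then ask the caching allocator to hand its unused blocks back to the
            # driver, so the memory shows up as free outside this process too.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()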

View File

@ -160,8 +160,26 @@ def download_weights(
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
try:
from transformers import AutoConfig
import transformers
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
architecture = config.architectures[0]
class_ = getattr(transformers, architecture)
# The name of this variable depends on the transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
except Exception as e:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
utils.convert_files(local_pt_files, local_st_files, discard_names)
@app.command()
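To make the effect of this lookup concrete, here is a small hedged sketch with a stand-in class rather than a real transformers model class; the two attribute names mirror the spellings probed above (newer and older transformers versions):

class FakeLlamaForCausalLM:
    # Stand-in for a transformers class whose output head is tied to the input embedding.
    _tied_weights_keys = ["lm_head.weight"]

discard_names = list(getattr(FakeLlamaForCausalLM, "_tied_weights_keys", []))
discard_names.extend(getattr(FakeLlamaForCausalLM, "_keys_to_ignore_on_load_missing", []))
print(discard_names)  # ['lm_head.weight'] -> convert_files will avoid keeping this duplicate name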

View File

@ -1,3 +1,4 @@
import torch
import grpc
from google.rpc import status_pb2, code_pb2
@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
method_name = method_name.split("/")[-1]
logger.exception(f"Method {method_name} encountered an error.")
if torch.cuda.is_available():
torch.cuda.empty_cache()
await context.abort_with_status(
rpc_status.to_status(
status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))

View File

@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
self.beta = 1.0
process_group = weights.process_group
if self.num_heads % process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {process_group.size()})"
)
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config=config,
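The same guard is repeated in the other sharded attention modules below (Llama, NeoX, Falcon/RW, SantaCoder, MPT, GPT-NeoX, OPT, T5), turning a bare assert into an explicit error. A standalone sketch of the check, with made-up numbers:

def shard_heads(num_heads: int, num_shards: int) -> int:
    # Tensor parallelism hands each shard a contiguous group of attention heads,
    # so every shard must own a whole number of them.
    if num_heads % num_shards != 0:
        raise ValueError(
            f"`num_heads` must be divisible by `num_shards` "
            f"(got `num_heads`: {num_heads} and `num_shards`: {num_shards})"
        )
    return num_heads // num_shards

print(shard_heads(32, 4))  # 8 heads per shard
# shard_heads(30, 4) raises ValueError instead of failing later with a shape mismatch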

View File

@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,

View File

@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(

View File

@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(

View File

@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
FastLayerNorm,
get_linear,
)
from safetensors import SafetensorError
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
@ -72,8 +73,17 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
try:
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
qweight = qweight.to(weights.device)
qzeros = qzeros.to(weights.device)
@ -102,7 +112,6 @@ def _load_multi_mqa_gptq(
def _load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
@ -211,7 +220,11 @@ class FlashMQAttention(torch.nn.Module):
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
assert self.num_heads % weights.process_group.size() == 0
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.softmax_scale = self.head_size ** (-0.5)
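Both this loader and the Weights helper further down now accept GPTQ checkpoints that do not serialize gptq_bits / gptq_groupsize tensors, falling back to environment variables. A hedged sketch of that fallback (read_tensor stands in for weights.get_tensor; the real code catches SafetensorError specifically):

import os

def gptq_params(read_tensor):
    try:
        # Prefer the values stored inside the checkpoint itself.
        return int(read_tensor("gptq_bits").item()), int(read_tensor("gptq_groupsize").item())
    except Exception as original_error:
        # Otherwise the operator is expected to export GPTQ_BITS / GPTQ_GROUPSIZE.
        bits, groupsize = os.getenv("GPTQ_BITS"), os.getenv("GPTQ_GROUPSIZE")
        if bits is None or groupsize is None:
            raise original_error
        return int(bits), int(groupsize)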

View File

@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = config.attn_config["attn_pdrop"]
if self.n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.n_heads = self.n_heads // weights.process_group.size()
self.Wqkv = load_col(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias

View File

@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
assert self.num_attention_heads % weights.process_group.size() == 0
if self.num_attention_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_attention_heads` must be divisible by `num_shards` "
f"(got `num_attention_heads`: {self.num_attention_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)

View File

@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
self.is_decoder = is_decoder
process_group = weights.process_group
assert self.num_heads % process_group.size() == 0
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()

View File

@ -19,6 +19,8 @@ import math
import warnings
from typing import Optional, Tuple, Union
from loguru import logger
import torch
import torch.distributed
from torch import nn
@ -246,6 +248,11 @@ class T5Attention(nn.Module):
self.o = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.o", weights=weights, bias=False
)
if self.n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.n_heads = self.n_heads // process_group.size()
self.inner_dim = self.inner_dim // process_group.size()
@ -1001,12 +1008,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super().__init__(config)
self.model_dim = config.d_model
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False

View File

@ -638,6 +638,7 @@ class FlashCausalLMBatch(Batch):
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
del b
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
@ -725,12 +726,11 @@ class FlashCausalLM(Model):
)
_, batch = self.generate_token(batch)
except Exception as e:
logger.exception(
raise RuntimeError(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
)
raise e
) from e
del batch
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
@ -775,16 +775,20 @@ class FlashCausalLM(Model):
# Allocate blocks to this batch
CACHE_MANAGER.allocate(batch)
out = self.forward(
batch.input_ids,
batch.position_ids,
batch.cu_seqlen_prefill,
batch.block_tables_tensor,
batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
batch.max_seqlen,
batch.prefill_head_indices,
)
try:
out = self.forward(
batch.input_ids,
batch.position_ids,
batch.cu_seqlen_prefill,
batch.block_tables_tensor,
batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
batch.max_seqlen,
batch.prefill_head_indices,
)
except Exception as e:
del batch
raise e
if prefill:
next_token_logits = (
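Both changes in this file follow one idea: if warmup or a forward pass fails (typically a CUDA out-of-memory error), drop the batch reference before propagating so its tensors and cache blocks can be reclaimed, and chain the original exception for context. A minimal sketch of the pattern, with hypothetical helper names:

def forward_with_cleanup(forward, batch):
    try:
        return forward(batch)
    except Exception:
        # Drop this frame's reference to the batch so its GPU memory can be
        # reclaimed sooner, then let the error travel up to the gRPC layer.
        del batch
        raise

def warmup_or_explain(generate_token, batch, max_total_tokens):
    try:
        return generate_token(batch)
    except Exception as e:
        raise RuntimeError(
            f"Not enough memory to handle {max_total_tokens} total tokens; "
            "decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
        ) from e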

View File

@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
config.quantize = quantize

View File

@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
model = T5ForConditionalGeneration(config, weights)
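FlashRWSharded and T5Sharded now pass an aliases mapping so a single on-disk tensor can serve several requested names (tied embeddings and lm_head). The Weights side of that lookup is not part of this diff, so the following is only a plausible sketch of how such a resolution could behave, with illustrative names:

def resolve(routing: dict, aliases: dict, tensor_name: str) -> str:
    """Return the on-disk key that should back `tensor_name`."""
    if tensor_name in routing:           # present as-is in some shard file
        return tensor_name
    for alias in aliases.get(tensor_name, []):
        if alias in routing:             # fall back to a tied / aliased key
            return alias
    raise RuntimeError(f"weight {tensor_name} does not exist")

routing = {"encoder.embed_tokens.weight": "model.safetensors"}  # what the checkpoint ships
aliases = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
print(resolve(routing, aliases, "shared.weight"))  # -> encoder.embed_tokens.weight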

View File

@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
self.model.warmup(batch, request.max_total_tokens)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):
@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
batch = batches[0]

View File

@ -4,11 +4,56 @@ import os
from loguru import logger
from pathlib import Path
from safetensors.torch import save_file, _remove_duplicate_names, load_file
from typing import List
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
from typing import List, Dict
from collections import defaultdict
def convert_file(pt_file: Path, sf_file: Path):
def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
preferred_names: List[str] = None,
discard_names: List[str] = None,
) -> Dict[str, List[str]]:
if preferred_names is None:
preferred_names = []
preferred_names = set(preferred_names)
if discard_names is None:
discard_names = []
discard_names = set(discard_names)
shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set(
[name for name in shared if _is_complete(state_dict[name])]
)
if not complete_names:
raise RuntimeError(
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None of them covers the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
)
keep_name = sorted(list(complete_names))[0]
# Mechanism to preferentially select keys to keep
# coming from the on-disk file to allow
# loading models saved with a different choice
# of keep_name
preferred = complete_names.difference(discard_names)
if preferred:
keep_name = sorted(list(preferred))[0]
if preferred_names:
preferred = preferred_names.intersection(complete_names)
if preferred:
keep_name = sorted(list(preferred))[0]
for name in sorted(shared):
if name != keep_name:
to_remove[keep_name].append(name)
return to_remove
def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
"""
Convert a pytorch file to a safetensors file
This will remove duplicate tensors from the file.
@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
loaded = torch.load(pt_file, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
to_removes = _remove_duplicate_names(loaded)
to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
metadata = {"format": "pt"}
for kept_name, to_remove_group in to_removes.items():
@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], sf_files: List[Path]):
def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
assert len(pt_files) == len(sf_files)
N = len(pt_files)
@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
start = datetime.datetime.now()
convert_file(pt_file, sf_file)
convert_file(pt_file, sf_file, discard_names)
elapsed = datetime.datetime.now() - start
logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
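Assuming the file above is importable as text_generation_server.utils.convert, a small usage sketch of the reimplemented helper on a tied-weight state dict (toy tensors, hypothetical key names):

import torch
from text_generation_server.utils.convert import _remove_duplicate_names  # assumed module path

embedding = torch.randn(8, 4)
# Two state-dict keys backed by the same storage, as with tied input/output embeddings.
state_dict = {"model.embed_tokens.weight": embedding, "lm_head.weight": embedding}

# Discarding the head name keeps the embedding key, i.e. the one transformers
# re-ties when the converted checkpoint is loaded back.
print(dict(_remove_duplicate_names(state_dict, discard_names=["lm_head.weight"])))
# {'model.embed_tokens.weight': ['lm_head.weight']}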

View File

@ -178,13 +178,25 @@ class SuperLayer(nn.Module):
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group):
def __init__(self, linear, process_group, should_gather: bool):
super().__init__(linear)
self.process_group = process_group
self.should_gather = should_gather
@staticmethod
def load(config, prefix: str, weights):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if weights.process_group.size() > 1:
try:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
except AssertionError:
# If the vocab size is not divisible by the number of shards
# just load the entire thing.
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
else:
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
# GPTQ doesn't quantize heads (nor embeddings)
if config.quantize == "gptq":
@ -194,13 +206,14 @@ class TensorParallelHead(SuperLayer):
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group,
should_gather=should_gather,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
world_size = self.process_group.size()
if world_size == 1:
if not self.should_gather:
return super().forward(input)
world_size = self.process_group.size()
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
out_dim = self.linear.weight.shape[0]
@ -281,7 +294,7 @@ class TensorParallelRowLinear(SuperLayer):
class TensorParallelEmbedding(nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
process_group = weights.process_group
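The new should_gather flag encodes one decision: shard the head (and all-gather the logits afterwards) only when the vocabulary splits evenly across ranks; otherwise every rank loads the full head and the collective is skipped. A toy decision sketch with made-up sizes:

def plan_head(vocab_size: int, world_size: int):
    if world_size == 1:
        return "full head on the single rank", False
    if vocab_size % world_size == 0:
        # Each rank keeps a slice of the vocabulary and logits are gathered later.
        return f"{vocab_size // world_size} rows per rank", True
    # e.g. 32000 rows over 3 ranks: load everything everywhere, skip the gather.
    return "full head on every rank", False

print(plan_head(32000, 4))  # ('8000 rows per rank', True)
print(plan_head(32000, 3))  # ('full head on every rank', False)

The embedding, by contrast, now uses get_partial_sharded, which tolerates an uneven split instead of refusing it.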

View File

@ -1,8 +1,9 @@
from pathlib import Path
from typing import List, Dict, Optional
from safetensors import safe_open
from safetensors import safe_open, SafetensorError
import torch
class Weights:
def __init__(
self,
@ -68,7 +69,7 @@ class Weights:
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
def get_partial_sharded(self, tensor_name: str, dim: int):
filename, tensor_name = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
@ -80,10 +81,6 @@ class Weights:
start = rank * block_size
stop = (rank + 1) * block_size
assert (
size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
@ -97,29 +94,57 @@ class Weights:
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
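After this change the split of responsibilities is: get_partial_sharded always returns this rank's slice, while get_sharded is a guarded wrapper that first insists the dimension divides evenly (the assertion that used to sit in the slicing path itself, which the embedding can no longer tolerate). A shape-only sketch using only the calls visible above:

def get_sharded_sketch(weights, tensor_name: str, dim: int):
    world_size = weights.process_group.size()
    size = weights.get_shape(tensor_name)[dim]
    # Strict path: refuse uneven splits up front, then reuse the permissive helper.
    assert (
        size % world_size == 0
    ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
    return weights.get_partial_sharded(tensor_name, dim)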
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_triton_kernel = False
if self.process_group.size() > 1:
@ -155,8 +180,17 @@ class Weights:
else:
g_idx = None
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
else: