mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add requirements to docker
This commit is contained in:
parent
3f2fce87e7
commit
7338e0097f
@ -188,7 +188,7 @@ COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install ".[bnb, accelerate]" --no-cache-dir
|
||||
pip install ".[bnb, accelerate, ct2]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
|
@ -17,6 +17,7 @@ grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.19.0", optional = true }
|
||||
bitsandbytes = { version = "^0.38.1", optional = true }
|
||||
ctranslate2 = { version = "^3.17.1", optional = true }
|
||||
safetensors = "0.3.1"
|
||||
loguru = "^0.6.0"
|
||||
opentelemetry-api = "^1.15.0"
|
||||
@ -32,6 +33,7 @@ einops = "^0.6.1"
|
||||
[tool.poetry.extras]
|
||||
accelerate = ["accelerate"]
|
||||
bnb = ["bitsandbytes"]
|
||||
ct2 = ["ctranslate2"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
grpcio-tools = "^1.51.1"
|
||||
|
@ -89,6 +89,7 @@ def download_weights(
|
||||
auto_convert: bool = True,
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
@ -168,6 +169,7 @@ def download_weights(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
architecture = config.architectures[0]
|
||||
|
||||
|
@ -2,7 +2,9 @@ import torch
|
||||
import inspect
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import (
|
||||
@ -75,8 +77,9 @@ class CT2CausalLM(Model):
|
||||
# " sampling based / non-greedy next_token"
|
||||
# " of code only working in float16.")
|
||||
# Start CT2 - conversion
|
||||
out_dir = f"./ct2-{model_id.replace('/','_')}-{ct2_compute_type}"
|
||||
if not os.path.exists(os.path.join(out_dir, "model.bin")):
|
||||
out_dir = Path(HUGGINGFACE_HUB_CACHE) / \
|
||||
f"ct2models-{model_id.replace('/','--')}--{ct2_compute_type}"
|
||||
if not os.path.exists(out_dir / "model.bin"):
|
||||
ex = ""
|
||||
try:
|
||||
converter = ctranslate2.converters.TransformersConverter(
|
||||
@ -95,9 +98,9 @@ class CT2CausalLM(Model):
|
||||
)
|
||||
except Exception as ex:
|
||||
pass
|
||||
if not os.path.exists(os.path.join(out_dir, "model.bin")) or ex:
|
||||
if not os.path.exists(out_dir / "model.bin") or ex:
|
||||
raise ValueError(
|
||||
f"conversion for {model_id} failed with ctranslate2: Error {ex}"
|
||||
f"conversion with ctranslate2 for {model_id} failed : Error {ex}"
|
||||
)
|
||||
|
||||
# Start CT2
|
||||
@ -108,10 +111,11 @@ class CT2CausalLM(Model):
|
||||
class DummyModel(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision,
|
||||
trust_remote_code=trust_remote_code)
|
||||
|
||||
model = DummyModel()
|
||||
self.vocab_size = model.config.vocab_size
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
@ -165,7 +169,7 @@ class CT2CausalLM(Model):
|
||||
# sampling_temperature=0,
|
||||
# )
|
||||
# # create fake logits from greedy token
|
||||
# logits = torch.full((len(tokens_in), 1, self.vocab_size), -10, dtype=torch.float16, device="cuda")
|
||||
# logits = torch.full((len(tokens_in), 1, self.model.config.vocab_size), -10, dtype=torch.float16, device="cuda")
|
||||
# for i, seq in enumerate(ids):
|
||||
# token = seq.sequences_ids[0]
|
||||
# logits[i, 0, token] = 10
|
||||
|
Loading…
Reference in New Issue
Block a user