mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00

commit f08a1a50b7: add parallelization
parent 8ddbdea45b
.github/workflows/build.yaml (vendored, 3 lines changed)

@@ -213,12 +213,13 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
+          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
+          pytest -s -vv -n 2 --dist loadfile integration-tests
 
   stop-runner:
     name: Stop self-hosted EC2 runner
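The two workflow changes above are the heart of the commit: pytest-xdist is installed and the integration suite is run with "-n 2 --dist loadfile", i.e. two worker processes with all tests from a given file kept on the same worker. That grouping matters because these integration tests boot a model server from an expensive scoped fixture; keeping a whole file on one worker avoids launching the same server on both workers. A minimal sketch of the pattern, with illustrative names rather than code from the repository:

import pytest

@pytest.fixture(scope="module")
def model_server():
    # Stands in for the expensive launcher fixture: with --dist loadfile,
    # every test in this file runs on the same xdist worker, so this setup
    # executes once per file instead of being repeated across workers.
    server = {"url": "http://localhost:8080"}  # hypothetical handle
    yield server

def test_generate(model_server):
    assert model_server["url"].startswith("http")

Invoked the same way as in the workflow: pytest -s -vv -n 2 --dist loadfile integration-tests.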
@@ -1,3 +1,4 @@
+import sys
 import subprocess
 import contextlib
 import pytest
@@ -7,6 +8,7 @@ import docker
 import json
 import math
 import time
+import random
 
 from docker.errors import NotFound
 from typing import Optional, List, Dict
@@ -205,10 +207,12 @@ def launcher(event_loop):
     def local_launcher(
         model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
     ):
-        port = 9999
-        master_port = 19999
+        port = random.randint(8000, 10_000)
+        master_port = random.randint(10_000, 20_000)
 
-        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
+        shard_uds_path = (
+            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
+        )
 
         args = [
             "text-generation-launcher",
@@ -236,7 +240,7 @@ def launcher(event_loop):
         process.wait(60)
 
         launcher_output = process.stdout.read().decode("utf-8")
-        print(launcher_output)
+        print(launcher_output, file=sys.stderr)
 
         process.stdout.close()
         process.stderr.close()
@@ -245,7 +249,7 @@ def launcher(event_loop):
     def docker_launcher(
         model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
     ):
-        port = 9999
+        port = random.randint(8000, 10_000)
 
         args = ["--model-id", model_id, "--env"]
 
@@ -298,7 +302,7 @@ def launcher(event_loop):
             pass
 
         container_output = container.logs().decode("utf-8")
-        print(container_output)
+        print(container_output, file=sys.stderr)
 
         container.remove()
 
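With two workers launching servers concurrently, the old fixed ports (9999 and 19999) would collide, hence the switch to random.randint ranges and a shard UDS path keyed by model name, shard count, and quantization. A random draw still leaves a small collision window when two workers happen to pick the same number; below is a sketch of a stricter variant that binds each candidate port before returning it. A helper of this kind is an assumption for illustration, not part of the diff:

import random
import socket

def pick_free_port(lo: int = 8000, hi: int = 10_000) -> int:
    # Draw candidate ports like the fixture does, but bind each candidate
    # first so two parallel workers can never settle on the same port.
    while True:
        port = random.randint(lo, hi)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
                return port  # free right now; closing releases it for the launcher
            except OSError:
                continue  # already taken, draw again

There is still a window between closing the probe socket and the launcher binding the port, but same-draw collisions between workers are ruled out.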
@@ -1,92 +1,4 @@
 [
-  {
-    "details": {
-      "best_of_sequences": null,
-      "finish_reason": "length",
-      "generated_tokens": 10,
-      "prefill": [
-        {
-          "id": 1,
-          "logprob": null,
-          "text": "<s>"
-        },
-        {
-          "id": 4321,
-          "logprob": -8.6875,
-          "text": "Test"
-        },
-        {
-          "id": 2009,
-          "logprob": -11.5546875,
-          "text": "request"
-        }
-      ],
-      "seed": null,
-      "tokens": [
-        {
-          "id": 363,
-          "logprob": -1.5322266,
-          "special": false,
-          "text": " for"
-        },
-        {
-          "id": 847,
-          "logprob": -2.5585938,
-          "special": false,
-          "text": " /"
-        },
-        {
-          "id": 2754,
-          "logprob": -2.265625,
-          "special": false,
-          "text": "api"
-        },
-        {
-          "id": 29914,
-          "logprob": -0.034088135,
-          "special": false,
-          "text": "/"
-        },
-        {
-          "id": 29894,
-          "logprob": -0.96240234,
-          "special": false,
-          "text": "v"
-        },
-        {
-          "id": 29896,
-          "logprob": -0.36816406,
-          "special": false,
-          "text": "1"
-        },
-        {
-          "id": 29914,
-          "logprob": -0.013191223,
-          "special": false,
-          "text": "/"
-        },
-        {
-          "id": 16418,
-          "logprob": -3.15625,
-          "special": false,
-          "text": "projects"
-        },
-        {
-          "id": 29914,
-          "logprob": -0.43774414,
-          "special": false,
-          "text": "/"
-        },
-        {
-          "id": 29896,
-          "logprob": -1.9443359,
-          "special": false,
-          "text": "1"
-        }
-      ]
-    },
-    "generated_text": "for /api/v1/projects/1"
-  },
   {
     "details": {
       "best_of_sequences": null,
@@ -263,6 +175,94 @@
     },
     "generated_text": "for /api/v1/projects/1"
   },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 1,
+          "logprob": null,
+          "text": "<s>"
+        },
+        {
+          "id": 4321,
+          "logprob": -8.6875,
+          "text": "Test"
+        },
+        {
+          "id": 2009,
+          "logprob": -11.5546875,
+          "text": "request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 363,
+          "logprob": -1.5322266,
+          "special": false,
+          "text": " for"
+        },
+        {
+          "id": 847,
+          "logprob": -2.5585938,
+          "special": false,
+          "text": " /"
+        },
+        {
+          "id": 2754,
+          "logprob": -2.265625,
+          "special": false,
+          "text": "api"
+        },
+        {
+          "id": 29914,
+          "logprob": -0.034088135,
+          "special": false,
+          "text": "/"
+        },
+        {
+          "id": 29894,
+          "logprob": -0.96240234,
+          "special": false,
+          "text": "v"
+        },
+        {
+          "id": 29896,
+          "logprob": -0.36816406,
+          "special": false,
+          "text": "1"
+        },
+        {
+          "id": 29914,
+          "logprob": -0.013191223,
+          "special": false,
+          "text": "/"
+        },
+        {
+          "id": 16418,
+          "logprob": -3.15625,
+          "special": false,
+          "text": "projects"
+        },
+        {
+          "id": 29914,
+          "logprob": -0.43774414,
+          "special": false,
+          "text": "/"
+        },
+        {
+          "id": 29896,
+          "logprob": -1.9443359,
+          "special": false,
+          "text": "1"
+        }
+      ]
+    },
+    "generated_text": "for /api/v1/projects/1"
+  },
   {
     "details": {
       "best_of_sequences": null,
@@ -16,7 +16,7 @@
       "id": 926,
       "logprob": -4.3554688,
       "special": false,
-      "text": "To"
+      "text": " To"
     },
     {
       "id": 18295,
@@ -16,7 +16,7 @@
       "id": 16017,
       "logprob": -1.3505859,
       "special": false,
-      "text": "blue"
+      "text": " blue"
     },
     {
       "id": 20495,
@@ -1,58 +1,4 @@
 [
-  {
-    "details": {
-      "best_of_sequences": null,
-      "finish_reason": "eos_token",
-      "generated_tokens": 6,
-      "prefill": [
-        {
-          "id": 0,
-          "logprob": null,
-          "text": "<pad>"
-        }
-      ],
-      "seed": null,
-      "tokens": [
-        {
-          "id": 259,
-          "logprob": -1.3789062,
-          "special": false,
-          "text": ""
-        },
-        {
-          "id": 39261,
-          "logprob": -0.36279297,
-          "special": false,
-          "text": "Because"
-        },
-        {
-          "id": 609,
-          "logprob": -1.0966797,
-          "special": false,
-          "text": " it"
-        },
-        {
-          "id": 339,
-          "logprob": -0.8276367,
-          "special": false,
-          "text": " is"
-        },
-        {
-          "id": 16017,
-          "logprob": -1.6845703,
-          "special": false,
-          "text": " blue"
-        },
-        {
-          "id": 1,
-          "logprob": -0.72753906,
-          "special": true,
-          "text": "</s>"
-        }
-      ]
-    },
-    "generated_text": "Because it is blue"
-  },
   {
     "details": {
       "best_of_sequences": null,
@@ -71,7 +17,7 @@
           "id": 259,
           "logprob": -1.3798828,
           "special": false,
-          "text": ""
+          "text": " "
         },
         {
           "id": 39261,
@@ -125,7 +71,7 @@
           "id": 259,
           "logprob": -1.3789062,
           "special": false,
-          "text": ""
+          "text": " "
         },
         {
           "id": 39261,
@@ -179,7 +125,61 @@
           "id": 259,
           "logprob": -1.3789062,
           "special": false,
-          "text": ""
+          "text": " "
+        },
+        {
+          "id": 39261,
+          "logprob": -0.36279297,
+          "special": false,
+          "text": "Because"
+        },
+        {
+          "id": 609,
+          "logprob": -1.0966797,
+          "special": false,
+          "text": " it"
+        },
+        {
+          "id": 339,
+          "logprob": -0.8276367,
+          "special": false,
+          "text": " is"
+        },
+        {
+          "id": 16017,
+          "logprob": -1.6845703,
+          "special": false,
+          "text": " blue"
+        },
+        {
+          "id": 1,
+          "logprob": -0.72753906,
+          "special": true,
+          "text": "</s>"
+        }
+      ]
+    },
+    "generated_text": "Because it is blue"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 6,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": "<pad>"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 259,
+          "logprob": -1.3789062,
+          "special": false,
+          "text": " "
         },
         {
           "id": 39261,
@@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
                 sha: None,
                 pipeline_tag: None,
             },
-            false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
+            false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
                 tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                 HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
             }),
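The unwrap_or to unwrap_or_else switch is not cosmetic: unwrap_or takes an already evaluated value, so the block above, warning included, ran even when get_model_info succeeded, while unwrap_or_else takes a closure that only runs on the fallback path. A rough Python analogue of the two behaviours, purely illustrative (none of these names are in the repository):

import logging

def fallback():
    logging.warning("Could not retrieve model info from the Hugging Face hub.")
    return {"model_id": "local", "sha": None, "pipeline_tag": None}

def info_or(result, default):
    # Like unwrap_or: the default was already built (and its warning logged)
    # by the caller, regardless of whether it is needed.
    return result if result is not None else default

def info_or_else(result, make_default):
    # Like unwrap_or_else: the fallback is only built when actually needed.
    return result if result is not None else make_default()

hub_info = {"model_id": "bigscience/bloom"}
info_or(hub_info, fallback())      # warns even though the lookup succeeded
info_or_else(hub_info, fallback)   # fallback never runs, no spurious warning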
@@ -56,7 +56,7 @@ class BLOOM(CausalLM):
         quantize: Optional[str] = None,
     ):
         super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+            model_id=model_id, revision=revision, quantize=quantize
         )
 
     @property
@@ -111,7 +111,6 @@ class BLOOMSharded(BLOOM):
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
             rank=rank,
             world_size=world_size,
         )
@@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            # offsets.append(None)
-            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -102,7 +100,7 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
-        for i, r in enumerate(pb.requests):
+        for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             offsets.append(0)
             token_offsets.append(input_len)
@@ -452,7 +450,6 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -486,7 +483,6 @@ class CausalLM(Model):
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )
 
     @property
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
 
-            offsets.append(None)
-            token_offsets.append(None)
+            offsets.append(0)
+            token_offsets.append(input_length)
 
             all_input_ids.append(tokenized_input)
 
@@ -394,7 +394,6 @@ class FlashCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -410,7 +409,7 @@ class FlashCausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
-        )
+        ).to(device)
 
         super(FlashCausalLM, self).__init__(
             model=model,
@@ -418,7 +417,6 @@ class FlashCausalLM(Model):
             requires_padding=False,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )
 
     @property
@@ -66,7 +66,7 @@ class FlashLlama(FlashCausalLM):
         self.load_weights(model, filenames, quantize, device, dtype)
 
         super(FlashCausalLM, self).__init__(
-            model=model,
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -191,7 +191,7 @@ class FlashLlamaSharded(FlashLlama):
         )
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
-            model=model,
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -75,7 +75,7 @@ class FlashNeoXSharded(FlashNeoX):
         )
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
-            model=model,
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -69,12 +69,11 @@ class FlashSantacoder(FlashCausalLM):
         )
 
         super(FlashCausalLM, self).__init__(
-            model=model,
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
         )
 
     @staticmethod
@@ -215,14 +214,13 @@ class FlashSantacoderSharded(FlashSantacoder):
         )
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
-            model=model,
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
-            decode_buffer=1,
         )
 
     @staticmethod
@@ -18,20 +18,15 @@ class Model(ABC):
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
-        decode_buffer: int = 4,
         rank: int = 0,
         world_size: int = 1,
     ):
-        if decode_buffer < 1:
-            raise ValueError("decode_buffer must be >= 1")
-
-        self.model = model.eval().to(device)
+        self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
-        self.decode_buffer = decode_buffer
         self.rank = rank
         self.world_size = world_size
         self.check_initialized()
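Taken together with the model-class hunks above, this moves device placement out of the base class: Model.__init__ now only calls eval(), and each subclass appends .to(device) where it is safe to do so, since sharded checkpoints and bitsandbytes 8-bit models are already materialized on the right device and should not be blindly moved. A condensed sketch of the resulting contract, simplified rather than copied from the file:

import torch
from torch import nn

class Model:
    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model.eval()  # eval only; placement is the subclass's job
        self.device = device

class TinyCausalLM(Model):
    def __init__(self, device: torch.device):
        net = nn.Linear(8, 8)  # stand-in for a transformer loaded from a checkpoint
        super().__init__(net.to(device), device)  # move explicitly, as the diff does

model = TinyCausalLM(torch.device("cpu"))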
@ -61,12 +56,6 @@ class Model(ABC):
|
|||||||
) -> Tuple[str, int, int]:
|
) -> Tuple[str, int, int]:
|
||||||
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
|
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
|
||||||
|
|
||||||
# Compatibility layer for old None values.
|
|
||||||
if prefix_offset is None:
|
|
||||||
prefix_offset = 0
|
|
||||||
if read_offset is None:
|
|
||||||
read_offset = 0
|
|
||||||
|
|
||||||
# The prefix text is necessary only to defeat cleanup algorithms in the decode
|
# The prefix text is necessary only to defeat cleanup algorithms in the decode
|
||||||
# which decide to add a space or not depending on the surrounding ids.
|
# which decide to add a space or not depending on the surrounding ids.
|
||||||
prefix_text = self.tokenizer.decode(
|
prefix_text = self.tokenizer.decode(
|
||||||
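The None-compatibility branch can go because the batch code above now seeds offsets with 0 and token_offsets with the real input length instead of None. Those two offsets drive the incremental decoding trick described in the comments: decode a short prefix window and a longer window, and emit only the suffix the longer window adds, so the tokenizer's space cleanup cannot corrupt streamed text. A self-contained sketch of that idea, a simplification of the method rather than the file's exact code:

from typing import List, Tuple

def decode_token(
    tokenizer, all_input_ids: List[int], prefix_offset: int, read_offset: int
) -> Tuple[str, int, int]:
    # Decode the prefix window alone, then prefix + unread ids; the newly
    # generated text is whatever the longer decode adds beyond the prefix.
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # A complete piece of text was produced; advance both offsets.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise we are mid multi-byte character: emit nothing yet.
    return "", prefix_offset, read_offset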
@@ -52,7 +52,7 @@ class SantaCoder(CausalLM):
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=True,  # required
-        )
+        ).to(device)
 
         super(CausalLM, self).__init__(
             model=model,
@@ -60,7 +60,6 @@ class SantaCoder(CausalLM):
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
         )
 
     def decode(self, generated_ids: List[int]) -> str:
@@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            # offsets.append(None)
-            # token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -123,7 +121,7 @@ class Seq2SeqLMBatch(Batch):
             .repeat(len(pb.requests))
             .view(-1, 1)
         )
-        for i, r in enumerate(pb.requests):
+        for _ in pb.requests:
             offsets.append(0)
             token_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
@@ -505,7 +503,6 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -535,7 +532,6 @@ class Seq2SeqLM(Model):
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )
 
     @property