mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
add tests
This commit is contained in:
parent
9a9244937b
commit
d69e4d2d1e
47
.github/workflows/build.yaml
vendored
47
.github/workflows/build.yaml
vendored
@ -70,6 +70,8 @@ jobs:
|
||||
# with sigstore/fulcio when running outside of PRs.
|
||||
id-token: write
|
||||
security-events: write
|
||||
outputs:
|
||||
image:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
@ -108,7 +110,19 @@ jobs:
|
||||
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
|
||||
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
|
||||
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
|
||||
# If pull request
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4.3.0
|
||||
with:
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
tags: |
|
||||
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
|
||||
# If main, release or tag
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4.3.0
|
||||
with:
|
||||
@ -129,7 +143,7 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
push: true
|
||||
platforms: 'linux/amd64'
|
||||
build-args: |
|
||||
GIT_SHA=${{ env.GITHUB_SHA }}
|
||||
@ -172,11 +186,42 @@ jobs:
|
||||
with:
|
||||
sarif_file: 'trivy-results.sarif'
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
- start-runner
|
||||
- build-and-push-image # Wait for the docker image to be built
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
env:
|
||||
DOCKER_IMAGE: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Tailscale
|
||||
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
|
||||
- name: Prepare disks
|
||||
run: |
|
||||
sudo mkfs -t ext4 /dev/nvme1n1
|
||||
sudo mkdir /data
|
||||
sudo mount /dev/nvme1n1 /data
|
||||
sudo chown -R $USER:$USER /data
|
||||
- name: Install
|
||||
run: |
|
||||
make install-integration-tests
|
||||
- name: Run tests
|
||||
run: |
|
||||
make integration-tests
|
||||
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- build-and-push-image
|
||||
- integration-tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
|
11
Makefile
11
Makefile
@ -1,6 +1,9 @@
|
||||
install-server:
|
||||
cd server && make install
|
||||
|
||||
install-integration-tests:
|
||||
cd integration-tests && pip install -r requirements.txt
|
||||
|
||||
install-router:
|
||||
cd router && cargo install --path .
|
||||
|
||||
@ -18,9 +21,15 @@ server-dev:
|
||||
router-dev:
|
||||
cd router && cargo run -- --port 8080
|
||||
|
||||
integration-tests: install-router install-launcher
|
||||
rust-tests: install-router install-launcher
|
||||
cargo test
|
||||
|
||||
integration-tests: install-integration-tests
|
||||
pytest -s -vv integration-tests
|
||||
|
||||
update-integration-tests: install-integration-tests
|
||||
pytest -s -vv --snapshot-update integration-tests
|
||||
|
||||
python-server-tests:
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
|
||||
|
||||
|
@ -253,5 +253,7 @@ make python-client-tests
|
||||
# or both server and client tests
|
||||
make python-tests
|
||||
# rust cargo tests
|
||||
make rust-tests
|
||||
# integration tests
|
||||
make integration-tests
|
||||
```
|
||||
|
@ -13,6 +13,7 @@ from text_generation.types import Response
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -26,7 +27,7 @@ def event_loop():
|
||||
def launcher(event_loop):
|
||||
@contextlib.contextmanager
|
||||
def local_launcher(
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: bool = False
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||
):
|
||||
port = 9999
|
||||
master_port = 19999
|
||||
@ -66,7 +67,7 @@ def launcher(event_loop):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def docker_launcher(
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: bool = False
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||
):
|
||||
port = 9999
|
||||
|
||||
@ -94,6 +95,10 @@ def launcher(event_loop):
|
||||
if HUGGING_FACE_HUB_TOKEN is not None:
|
||||
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
|
||||
|
||||
volumes = []
|
||||
if DOCKER_VOLUME:
|
||||
volumes = [f"{DOCKER_VOLUME}:/data"]
|
||||
|
||||
container = client.containers.run(
|
||||
DOCKER_IMAGE,
|
||||
command=args,
|
||||
@ -104,7 +109,7 @@ def launcher(event_loop):
|
||||
device_requests=[
|
||||
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
|
||||
],
|
||||
volumes=["/data:/data"],
|
||||
volumes=volumes,
|
||||
ports={"80/tcp": port},
|
||||
)
|
||||
|
||||
@ -130,6 +135,6 @@ def generate_load():
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*futures)
|
||||
return results
|
||||
return [r.generated_text for r in results]
|
||||
|
||||
return generate_load_inner
|
||||
|
@ -7,88 +7,233 @@
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 10264,
|
||||
'id': 17934,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
'text': 'Pour',
|
||||
}),
|
||||
dict({
|
||||
'id': 8821,
|
||||
'logprob': -11.3125,
|
||||
'text': ' request',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 11,
|
||||
'logprob': -2.859375,
|
||||
'special': False,
|
||||
'text': '(',
|
||||
'id': 49833,
|
||||
'logprob': -10.5625,
|
||||
'text': ' dég',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -2.34375,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 21543,
|
||||
'logprob': -0.14770508,
|
||||
'text': 'uster',
|
||||
}),
|
||||
dict({
|
||||
'id': 1587,
|
||||
'logprob': -3.25,
|
||||
'special': False,
|
||||
'text': 'get',
|
||||
'id': 447,
|
||||
'logprob': -1.9287109,
|
||||
'text': ' un',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -1.828125,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 46341,
|
||||
'logprob': -15.4609375,
|
||||
'text': ' ort',
|
||||
}),
|
||||
dict({
|
||||
'id': 35567,
|
||||
'logprob': -7.5585938,
|
||||
'text': 'olan',
|
||||
}),
|
||||
dict({
|
||||
'id': 15,
|
||||
'logprob': -0.35546875,
|
||||
'special': False,
|
||||
'logprob': -1.4003906,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 567,
|
||||
'logprob': -2.4375,
|
||||
'special': False,
|
||||
'text': ' "',
|
||||
'id': 1669,
|
||||
'logprob': -1.5673828,
|
||||
'text': ' il',
|
||||
}),
|
||||
dict({
|
||||
'id': 17,
|
||||
'logprob': -4.40625,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
'id': 11580,
|
||||
'logprob': -0.94628906,
|
||||
'text': ' faut',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -2.46875,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 3913,
|
||||
'logprob': -3.703125,
|
||||
'text': ' tout',
|
||||
}),
|
||||
dict({
|
||||
'id': 12,
|
||||
'logprob': -1.6015625,
|
||||
'id': 39261,
|
||||
'logprob': -1.5732422,
|
||||
'text': " d'abord",
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 578,
|
||||
'logprob': -1.6591797,
|
||||
'special': False,
|
||||
'text': ')',
|
||||
'text': ' le',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.14550781,
|
||||
'id': 5608,
|
||||
'logprob': -2.4492188,
|
||||
'special': False,
|
||||
'text': ';',
|
||||
'text': ' faire',
|
||||
}),
|
||||
dict({
|
||||
'id': 159570,
|
||||
'logprob': -6.6835938,
|
||||
'special': False,
|
||||
'text': ' réch',
|
||||
}),
|
||||
dict({
|
||||
'id': 810,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'au',
|
||||
}),
|
||||
dict({
|
||||
'id': 12736,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'ffer',
|
||||
}),
|
||||
dict({
|
||||
'id': 1742,
|
||||
'logprob': -2.5175781,
|
||||
'special': False,
|
||||
'text': ' au',
|
||||
}),
|
||||
dict({
|
||||
'id': 6105,
|
||||
'logprob': -2.0078125,
|
||||
'special': False,
|
||||
'text': ' bain',
|
||||
}),
|
||||
dict({
|
||||
'id': 88254,
|
||||
'logprob': -0.12695312,
|
||||
'special': False,
|
||||
'text': '-mar',
|
||||
}),
|
||||
dict({
|
||||
'id': 641,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'ie',
|
||||
}),
|
||||
dict({
|
||||
'id': 2940,
|
||||
'logprob': -3.5175781,
|
||||
'special': False,
|
||||
'text': ' avec',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '("get", ".");',
|
||||
'generated_text': ' le faire réchauffer au bain-marie avec',
|
||||
})
|
||||
# ---
|
||||
# name: test_bloom_560m_all_params
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 15,
|
||||
'logprob': None,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 1669,
|
||||
'logprob': -5.4414062,
|
||||
'text': ' il',
|
||||
}),
|
||||
dict({
|
||||
'id': 11580,
|
||||
'logprob': -2.3378906,
|
||||
'text': ' faut',
|
||||
}),
|
||||
dict({
|
||||
'id': 3913,
|
||||
'logprob': -4.3554688,
|
||||
'text': ' tout',
|
||||
}),
|
||||
dict({
|
||||
'id': 39261,
|
||||
'logprob': -2.9238281,
|
||||
'text': " d'abord",
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 408,
|
||||
'logprob': -1.9267578,
|
||||
'special': False,
|
||||
'text': ' que',
|
||||
}),
|
||||
dict({
|
||||
'id': 20288,
|
||||
'logprob': -2.9257812,
|
||||
'special': False,
|
||||
'text': " l'on",
|
||||
}),
|
||||
dict({
|
||||
'id': 22255,
|
||||
'logprob': -2.8964844,
|
||||
'special': False,
|
||||
'text': ' trouve',
|
||||
}),
|
||||
dict({
|
||||
'id': 1622,
|
||||
'logprob': -1.1083984,
|
||||
'special': False,
|
||||
'text': ' une',
|
||||
}),
|
||||
dict({
|
||||
'id': 187079,
|
||||
'logprob': -7.796875,
|
||||
'special': False,
|
||||
'text': ' posture',
|
||||
}),
|
||||
dict({
|
||||
'id': 501,
|
||||
'logprob': -5.390625,
|
||||
'special': False,
|
||||
'text': ' par',
|
||||
}),
|
||||
dict({
|
||||
'id': 8741,
|
||||
'logprob': -0.34936523,
|
||||
'special': False,
|
||||
'text': ' rapport',
|
||||
}),
|
||||
dict({
|
||||
'id': 693,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': ' à',
|
||||
}),
|
||||
dict({
|
||||
'id': 366,
|
||||
'logprob': -2.3378906,
|
||||
'special': False,
|
||||
'text': ' la',
|
||||
}),
|
||||
dict({
|
||||
'id': 36503,
|
||||
'logprob': -3.6640625,
|
||||
'special': False,
|
||||
'text': ' pratique',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique",
|
||||
})
|
||||
# ---
|
||||
# name: test_bloom_560m_load
|
||||
list([
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
])
|
||||
# ---
|
||||
|
@ -7,88 +7,133 @@
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 10264,
|
||||
'id': 17934,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
'text': 'Pour',
|
||||
}),
|
||||
dict({
|
||||
'id': 8821,
|
||||
'logprob': -11.375,
|
||||
'text': ' request',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 11,
|
||||
'logprob': -2.734375,
|
||||
'special': False,
|
||||
'text': '(',
|
||||
'id': 49833,
|
||||
'logprob': -10.5390625,
|
||||
'text': ' dég',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -1.9765625,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 21543,
|
||||
'logprob': -0.14758301,
|
||||
'text': 'uster',
|
||||
}),
|
||||
dict({
|
||||
'id': 1587,
|
||||
'logprob': -3.140625,
|
||||
'special': False,
|
||||
'text': 'get',
|
||||
'id': 447,
|
||||
'logprob': -1.9296875,
|
||||
'text': ' un',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -3.515625,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 46341,
|
||||
'logprob': -15.4453125,
|
||||
'text': ' ort',
|
||||
}),
|
||||
dict({
|
||||
'id': 35567,
|
||||
'logprob': -7.59375,
|
||||
'text': 'olan',
|
||||
}),
|
||||
dict({
|
||||
'id': 15,
|
||||
'logprob': -0.37304688,
|
||||
'special': False,
|
||||
'logprob': -1.3994141,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 567,
|
||||
'logprob': -2.6875,
|
||||
'special': False,
|
||||
'text': ' "',
|
||||
'id': 1669,
|
||||
'logprob': -1.578125,
|
||||
'text': ' il',
|
||||
}),
|
||||
dict({
|
||||
'id': 17,
|
||||
'logprob': -4.65625,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
'id': 11580,
|
||||
'logprob': -0.9453125,
|
||||
'text': ' faut',
|
||||
}),
|
||||
dict({
|
||||
'id': 5,
|
||||
'logprob': -2.28125,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
'id': 3913,
|
||||
'logprob': -3.7011719,
|
||||
'text': ' tout',
|
||||
}),
|
||||
dict({
|
||||
'id': 12,
|
||||
'logprob': -1.6640625,
|
||||
'id': 39261,
|
||||
'logprob': -1.5732422,
|
||||
'text': " d'abord",
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 578,
|
||||
'logprob': -1.6474609,
|
||||
'special': False,
|
||||
'text': ')',
|
||||
'text': ' le',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.14355469,
|
||||
'id': 5608,
|
||||
'logprob': -2.5097656,
|
||||
'special': False,
|
||||
'text': ';',
|
||||
'text': ' faire',
|
||||
}),
|
||||
dict({
|
||||
'id': 159570,
|
||||
'logprob': -6.65625,
|
||||
'special': False,
|
||||
'text': ' réch',
|
||||
}),
|
||||
dict({
|
||||
'id': 810,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'au',
|
||||
}),
|
||||
dict({
|
||||
'id': 12736,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'ffer',
|
||||
}),
|
||||
dict({
|
||||
'id': 1742,
|
||||
'logprob': -2.5859375,
|
||||
'special': False,
|
||||
'text': ' au',
|
||||
}),
|
||||
dict({
|
||||
'id': 6105,
|
||||
'logprob': -2.03125,
|
||||
'special': False,
|
||||
'text': ' bain',
|
||||
}),
|
||||
dict({
|
||||
'id': 88254,
|
||||
'logprob': -0.12695312,
|
||||
'special': False,
|
||||
'text': '-mar',
|
||||
}),
|
||||
dict({
|
||||
'id': 641,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': 'ie',
|
||||
}),
|
||||
dict({
|
||||
'id': 2940,
|
||||
'logprob': -3.5175781,
|
||||
'special': False,
|
||||
'text': ' avec',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '("get", ".");',
|
||||
'generated_text': ' le faire réchauffer au bain-marie avec',
|
||||
})
|
||||
# ---
|
||||
# name: test_bloom_560m_sharded_load
|
||||
list([
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
" le faire cuire dans de l'eau bouillante sal",
|
||||
])
|
||||
# ---
|
||||
|
195
integration-tests/models/__snapshots__/test_flash_llama.ambr
Normal file
195
integration-tests/models/__snapshots__/test_flash_llama.ambr
Normal file
@ -0,0 +1,195 @@
|
||||
# serializer version: 1
|
||||
# name: test_flash_llama
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 1,
|
||||
'logprob': None,
|
||||
'text': '<s>',
|
||||
}),
|
||||
dict({
|
||||
'id': 4321,
|
||||
'logprob': -8.6875,
|
||||
'text': 'Test',
|
||||
}),
|
||||
dict({
|
||||
'id': 2009,
|
||||
'logprob': -11.5546875,
|
||||
'text': 'request',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 363,
|
||||
'logprob': -1.5380859,
|
||||
'special': False,
|
||||
'text': ' for',
|
||||
}),
|
||||
dict({
|
||||
'id': 847,
|
||||
'logprob': -2.5917969,
|
||||
'special': False,
|
||||
'text': ' /',
|
||||
}),
|
||||
dict({
|
||||
'id': 2754,
|
||||
'logprob': -2.2773438,
|
||||
'special': False,
|
||||
'text': 'api',
|
||||
}),
|
||||
dict({
|
||||
'id': 29914,
|
||||
'logprob': -0.034362793,
|
||||
'special': False,
|
||||
'text': '/',
|
||||
}),
|
||||
dict({
|
||||
'id': 29894,
|
||||
'logprob': -0.96533203,
|
||||
'special': False,
|
||||
'text': 'v',
|
||||
}),
|
||||
dict({
|
||||
'id': 29896,
|
||||
'logprob': -0.36669922,
|
||||
'special': False,
|
||||
'text': '1',
|
||||
}),
|
||||
dict({
|
||||
'id': 29914,
|
||||
'logprob': -0.013122559,
|
||||
'special': False,
|
||||
'text': '/',
|
||||
}),
|
||||
dict({
|
||||
'id': 16418,
|
||||
'logprob': -3.1503906,
|
||||
'special': False,
|
||||
'text': 'projects',
|
||||
}),
|
||||
dict({
|
||||
'id': 29914,
|
||||
'logprob': -0.43652344,
|
||||
'special': False,
|
||||
'text': '/',
|
||||
}),
|
||||
dict({
|
||||
'id': 29896,
|
||||
'logprob': -1.9404297,
|
||||
'special': False,
|
||||
'text': '1',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': 'for /api/v1/projects/1',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_llama_all_params
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 1,
|
||||
'logprob': None,
|
||||
'text': '<s>',
|
||||
}),
|
||||
dict({
|
||||
'id': 4321,
|
||||
'logprob': -8.6875,
|
||||
'text': 'Test',
|
||||
}),
|
||||
dict({
|
||||
'id': 2009,
|
||||
'logprob': -11.5546875,
|
||||
'text': 'request',
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 5229,
|
||||
'logprob': -3.3085938,
|
||||
'special': False,
|
||||
'text': ' failed',
|
||||
}),
|
||||
dict({
|
||||
'id': 363,
|
||||
'logprob': -3.984375,
|
||||
'special': False,
|
||||
'text': ' for',
|
||||
}),
|
||||
dict({
|
||||
'id': 5641,
|
||||
'logprob': -6.53125,
|
||||
'special': False,
|
||||
'text': ' IP',
|
||||
}),
|
||||
dict({
|
||||
'id': 16428,
|
||||
'logprob': -3.1835938,
|
||||
'special': False,
|
||||
'text': ' Address',
|
||||
}),
|
||||
dict({
|
||||
'id': 29901,
|
||||
'logprob': -1.2324219,
|
||||
'special': False,
|
||||
'text': ':',
|
||||
}),
|
||||
dict({
|
||||
'id': 525,
|
||||
'logprob': -2.6855469,
|
||||
'special': False,
|
||||
'text': " '",
|
||||
}),
|
||||
dict({
|
||||
'id': 8516,
|
||||
'logprob': -7.1601562,
|
||||
'special': False,
|
||||
'text': 'None',
|
||||
}),
|
||||
dict({
|
||||
'id': 4286,
|
||||
'logprob': -2.4433594,
|
||||
'special': False,
|
||||
'text': "'.",
|
||||
}),
|
||||
dict({
|
||||
'id': 13,
|
||||
'logprob': -0.06530762,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 294,
|
||||
'logprob': -7.953125,
|
||||
'special': False,
|
||||
'text': 'as',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '''
|
||||
Test requestfailed for IP Address: 'None'.
|
||||
as
|
||||
''',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_llama_load
|
||||
list([
|
||||
'for /api/v1/projects/1',
|
||||
'for /api/v1/projects/1',
|
||||
'for /api/v1/projects/1',
|
||||
'for /api/v1/projects/1',
|
||||
])
|
||||
# ---
|
174
integration-tests/models/__snapshots__/test_flash_neox.ambr
Normal file
174
integration-tests/models/__snapshots__/test_flash_neox.ambr
Normal file
@ -0,0 +1,174 @@
|
||||
# serializer version: 1
|
||||
# name: test_flash_neox
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 50278,
|
||||
'logprob': None,
|
||||
'text': '<|prompter|>',
|
||||
}),
|
||||
dict({
|
||||
'id': 1276,
|
||||
'logprob': -8.03125,
|
||||
'text': 'What',
|
||||
}),
|
||||
dict({
|
||||
'id': 310,
|
||||
'logprob': -5.421875,
|
||||
'text': ' is',
|
||||
}),
|
||||
dict({
|
||||
'id': 247,
|
||||
'logprob': -2.1601562,
|
||||
'text': ' a',
|
||||
}),
|
||||
dict({
|
||||
'id': 1167,
|
||||
'logprob': -5.4609375,
|
||||
'text': ' mem',
|
||||
}),
|
||||
dict({
|
||||
'id': 70,
|
||||
'logprob': -0.005657196,
|
||||
'text': 'e',
|
||||
}),
|
||||
dict({
|
||||
'id': 13,
|
||||
'logprob': -7.28125,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 285,
|
||||
'logprob': -0.2980957,
|
||||
'text': ' and',
|
||||
}),
|
||||
dict({
|
||||
'id': 752,
|
||||
'logprob': -2.1679688,
|
||||
'text': ' what',
|
||||
}),
|
||||
dict({
|
||||
'id': 434,
|
||||
'logprob': -5.6210938,
|
||||
'text': "'s",
|
||||
}),
|
||||
dict({
|
||||
'id': 253,
|
||||
'logprob': -0.81103516,
|
||||
'text': ' the',
|
||||
}),
|
||||
dict({
|
||||
'id': 2892,
|
||||
'logprob': -6.6640625,
|
||||
'text': ' history',
|
||||
}),
|
||||
dict({
|
||||
'id': 3212,
|
||||
'logprob': -2.265625,
|
||||
'text': ' behind',
|
||||
}),
|
||||
dict({
|
||||
'id': 436,
|
||||
'logprob': -11.5078125,
|
||||
'text': ' this',
|
||||
}),
|
||||
dict({
|
||||
'id': 3159,
|
||||
'logprob': -2.1582031,
|
||||
'text': ' word',
|
||||
}),
|
||||
dict({
|
||||
'id': 32,
|
||||
'logprob': -0.008720398,
|
||||
'text': '?',
|
||||
}),
|
||||
dict({
|
||||
'id': 0,
|
||||
'logprob': -2.4726562,
|
||||
'text': '<|endoftext|>',
|
||||
}),
|
||||
dict({
|
||||
'id': 50281,
|
||||
'logprob': -18.265625,
|
||||
'text': '<|assistant|>',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 510,
|
||||
'logprob': -0.63183594,
|
||||
'special': False,
|
||||
'text': 'The',
|
||||
}),
|
||||
dict({
|
||||
'id': 3159,
|
||||
'logprob': -0.5390625,
|
||||
'special': False,
|
||||
'text': ' word',
|
||||
}),
|
||||
dict({
|
||||
'id': 346,
|
||||
'logprob': -0.045684814,
|
||||
'special': False,
|
||||
'text': ' "',
|
||||
}),
|
||||
dict({
|
||||
'id': 6441,
|
||||
'logprob': -0.002090454,
|
||||
'special': False,
|
||||
'text': 'mem',
|
||||
}),
|
||||
dict({
|
||||
'id': 70,
|
||||
'logprob': -1.3589859e-05,
|
||||
'special': False,
|
||||
'text': 'e',
|
||||
}),
|
||||
dict({
|
||||
'id': 3,
|
||||
'logprob': -0.0009455681,
|
||||
'special': False,
|
||||
'text': '"',
|
||||
}),
|
||||
dict({
|
||||
'id': 369,
|
||||
'logprob': -0.088012695,
|
||||
'special': False,
|
||||
'text': ' was',
|
||||
}),
|
||||
dict({
|
||||
'id': 806,
|
||||
'logprob': -0.12585449,
|
||||
'special': False,
|
||||
'text': ' first',
|
||||
}),
|
||||
dict({
|
||||
'id': 908,
|
||||
'logprob': -0.017196655,
|
||||
'special': False,
|
||||
'text': ' used',
|
||||
}),
|
||||
dict({
|
||||
'id': 275,
|
||||
'logprob': -0.49731445,
|
||||
'special': False,
|
||||
'text': ' in',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': 'The word "meme" was first used in',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_neox_load
|
||||
list([
|
||||
'The word "meme" was first used in',
|
||||
'The word "meme" was first used in',
|
||||
'The word "meme" was first used in',
|
||||
'The word "meme" was first used in',
|
||||
])
|
||||
# ---
|
@ -7,95 +7,132 @@
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 804,
|
||||
'id': 563,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
'text': 'def',
|
||||
}),
|
||||
dict({
|
||||
'id': 942,
|
||||
'logprob': -5.1367188,
|
||||
'text': ' print',
|
||||
}),
|
||||
dict({
|
||||
'id': 62,
|
||||
'logprob': -0.24450684,
|
||||
'text': '_',
|
||||
}),
|
||||
dict({
|
||||
'id': 7196,
|
||||
'logprob': -6.9609375,
|
||||
'text': 'hello',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 25,
|
||||
'logprob': -2.828125,
|
||||
'id': 1241,
|
||||
'logprob': -0.9863281,
|
||||
'special': False,
|
||||
'text': ':',
|
||||
'text': '():',
|
||||
}),
|
||||
dict({
|
||||
'id': 287,
|
||||
'logprob': -1.5703125,
|
||||
'special': False,
|
||||
'text': ' "',
|
||||
}),
|
||||
dict({
|
||||
'id': 385,
|
||||
'logprob': -0.03955078,
|
||||
'special': False,
|
||||
'text': ' +',
|
||||
}),
|
||||
dict({
|
||||
'id': 1028,
|
||||
'logprob': -1.453125,
|
||||
'special': False,
|
||||
'text': ' request',
|
||||
}),
|
||||
dict({
|
||||
'id': 13,
|
||||
'logprob': -0.796875,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
}),
|
||||
dict({
|
||||
'id': 1832,
|
||||
'logprob': -1.4296875,
|
||||
'special': False,
|
||||
'text': 'toString',
|
||||
}),
|
||||
dict({
|
||||
'id': 782,
|
||||
'logprob': -0.17871094,
|
||||
'special': False,
|
||||
'text': '());',
|
||||
}),
|
||||
dict({
|
||||
'id': 259,
|
||||
'logprob': -1.359375,
|
||||
'id': 258,
|
||||
'logprob': -0.21447754,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 294,
|
||||
'logprob': -1.6484375,
|
||||
'id': 942,
|
||||
'logprob': -0.43701172,
|
||||
'special': False,
|
||||
'text': ' }',
|
||||
'text': ' print',
|
||||
}),
|
||||
dict({
|
||||
'id': 437,
|
||||
'logprob': -1.0625,
|
||||
'id': 372,
|
||||
'logprob': -0.5361328,
|
||||
'special': False,
|
||||
'text': '("',
|
||||
}),
|
||||
dict({
|
||||
'id': 7371,
|
||||
'logprob': -0.44555664,
|
||||
'special': False,
|
||||
'text': 'Hello',
|
||||
}),
|
||||
dict({
|
||||
'id': 9956,
|
||||
'logprob': -1.2412109,
|
||||
'special': False,
|
||||
'text': ' World',
|
||||
}),
|
||||
dict({
|
||||
'id': 8657,
|
||||
'logprob': -0.7583008,
|
||||
'special': False,
|
||||
'text': '!")',
|
||||
}),
|
||||
dict({
|
||||
'id': 185,
|
||||
'logprob': -0.76171875,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 185,
|
||||
'logprob': -0.20837402,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 1018,
|
||||
'logprob': -1.2470703,
|
||||
'special': False,
|
||||
'text': 'print',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '''
|
||||
: " + request.toString());
|
||||
}
|
||||
():
|
||||
print("Hello World!")
|
||||
|
||||
|
||||
print
|
||||
''',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_santacoder_load
|
||||
list([
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
'''
|
||||
():
|
||||
print("Hello World!")
|
||||
|
||||
print
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World!")
|
||||
|
||||
print
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World!")
|
||||
|
||||
print
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World!")
|
||||
|
||||
print
|
||||
''',
|
||||
])
|
||||
# ---
|
||||
|
@ -7,83 +7,249 @@
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 1006,
|
||||
'id': 589,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
'text': 'def',
|
||||
}),
|
||||
dict({
|
||||
'id': 1459,
|
||||
'logprob': -5.6289062,
|
||||
'text': ' print',
|
||||
}),
|
||||
dict({
|
||||
'id': 81,
|
||||
'logprob': -1.6005859,
|
||||
'text': '_',
|
||||
}),
|
||||
dict({
|
||||
'id': 7656,
|
||||
'logprob': -5.9921875,
|
||||
'text': 'hello',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -2.734375,
|
||||
'id': 2262,
|
||||
'logprob': -0.7705078,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
'text': '():',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -2.828125,
|
||||
'id': 284,
|
||||
'logprob': -0.2590332,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 1984,
|
||||
'logprob': -3.28125,
|
||||
'id': 1459,
|
||||
'logprob': -0.39379883,
|
||||
'special': False,
|
||||
'text': ' url',
|
||||
'text': ' print',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.796875,
|
||||
'id': 440,
|
||||
'logprob': -0.61376953,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
'text': '("',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -1.0390625,
|
||||
'id': 8279,
|
||||
'logprob': -0.47338867,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
'text': 'Hello',
|
||||
}),
|
||||
dict({
|
||||
'id': 1411,
|
||||
'logprob': -2.078125,
|
||||
'id': 10896,
|
||||
'logprob': -1.5068359,
|
||||
'special': False,
|
||||
'text': ' method',
|
||||
'text': ' World',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.49609375,
|
||||
'id': 657,
|
||||
'logprob': -0.80810547,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
'text': '")',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -1.0546875,
|
||||
'id': 203,
|
||||
'logprob': -0.7397461,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 3361,
|
||||
'logprob': -1.71875,
|
||||
'id': 203,
|
||||
'logprob': -0.35229492,
|
||||
'special': False,
|
||||
'text': ' body',
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 27,
|
||||
'logprob': -0.69921875,
|
||||
'id': 589,
|
||||
'logprob': -1.0371094,
|
||||
'special': False,
|
||||
'text': ')',
|
||||
'text': 'def',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': ', String url, String method, String body)',
|
||||
'generated_text': '''
|
||||
():
|
||||
print("Hello World")
|
||||
|
||||
def
|
||||
''',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_starcoder_default_params
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||
'generated_tokens': 12,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 589,
|
||||
'logprob': None,
|
||||
'text': 'def',
|
||||
}),
|
||||
dict({
|
||||
'id': 1459,
|
||||
'logprob': -5.6289062,
|
||||
'text': ' print',
|
||||
}),
|
||||
dict({
|
||||
'id': 81,
|
||||
'logprob': -1.6005859,
|
||||
'text': '_',
|
||||
}),
|
||||
dict({
|
||||
'id': 7656,
|
||||
'logprob': -5.9921875,
|
||||
'text': 'hello',
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 2262,
|
||||
'logprob': -0.7451172,
|
||||
'special': False,
|
||||
'text': '():',
|
||||
}),
|
||||
dict({
|
||||
'id': 284,
|
||||
'logprob': -0.21325684,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 5741,
|
||||
'logprob': -5.734375,
|
||||
'special': False,
|
||||
'text': ' logging',
|
||||
}),
|
||||
dict({
|
||||
'id': 32,
|
||||
'logprob': 0.0,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
}),
|
||||
dict({
|
||||
'id': 1338,
|
||||
'logprob': -0.3232422,
|
||||
'special': False,
|
||||
'text': 'info',
|
||||
}),
|
||||
dict({
|
||||
'id': 463,
|
||||
'logprob': -1.0380859,
|
||||
'special': False,
|
||||
'text': "('",
|
||||
}),
|
||||
dict({
|
||||
'id': 8279,
|
||||
'logprob': -0.8378906,
|
||||
'special': False,
|
||||
'text': 'Hello',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -1.9501953,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 10896,
|
||||
'logprob': -1.3476562,
|
||||
'special': False,
|
||||
'text': ' World',
|
||||
}),
|
||||
dict({
|
||||
'id': 683,
|
||||
'logprob': -1.796875,
|
||||
'special': False,
|
||||
'text': "')",
|
||||
}),
|
||||
dict({
|
||||
'id': 203,
|
||||
'logprob': -0.9873047,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 0,
|
||||
'logprob': -0.7495117,
|
||||
'special': True,
|
||||
'text': '<|endoftext|>',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '''
|
||||
():
|
||||
logging.info('Hello, World')
|
||||
<|endoftext|>
|
||||
''',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_starcoder_load
|
||||
list([
|
||||
Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), Token(id=1984, text=' url', logprob=-3.28125, special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
'''
|
||||
():
|
||||
print("Hello World")
|
||||
|
||||
def
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World")
|
||||
|
||||
def
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World")
|
||||
|
||||
def
|
||||
''',
|
||||
'''
|
||||
():
|
||||
print("Hello World")
|
||||
|
||||
def
|
||||
''',
|
||||
])
|
||||
# ---
|
||||
|
139
integration-tests/models/__snapshots__/test_mt0_base.ambr
Normal file
139
integration-tests/models/__snapshots__/test_mt0_base.ambr
Normal file
@ -0,0 +1,139 @@
|
||||
# serializer version: 1
|
||||
# name: test_mt0_base
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||
'generated_tokens': 5,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 0,
|
||||
'logprob': None,
|
||||
'text': '<pad>',
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 926,
|
||||
'logprob': -4.3554688,
|
||||
'special': False,
|
||||
'text': 'To',
|
||||
}),
|
||||
dict({
|
||||
'id': 18295,
|
||||
'logprob': -7.7734375,
|
||||
'special': False,
|
||||
'text': ' sell',
|
||||
}),
|
||||
dict({
|
||||
'id': 7868,
|
||||
'logprob': -3.9257812,
|
||||
'special': False,
|
||||
'text': ' things',
|
||||
}),
|
||||
dict({
|
||||
'id': 260,
|
||||
'logprob': -2.4179688,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
}),
|
||||
dict({
|
||||
'id': 1,
|
||||
'logprob': 0.0,
|
||||
'special': True,
|
||||
'text': '</s>',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': 'To sell things.',
|
||||
})
|
||||
# ---
|
||||
# name: test_mt0_base_all_params
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 0,
|
||||
'logprob': None,
|
||||
'text': '<pad>',
|
||||
}),
|
||||
]),
|
||||
'seed': 0,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 16017,
|
||||
'logprob': -1.3505859,
|
||||
'special': False,
|
||||
'text': 'blue',
|
||||
}),
|
||||
dict({
|
||||
'id': 20495,
|
||||
'logprob': -0.50439453,
|
||||
'special': False,
|
||||
'text': ' sky',
|
||||
}),
|
||||
dict({
|
||||
'id': 259,
|
||||
'logprob': -1.2011719,
|
||||
'special': False,
|
||||
'text': ' ',
|
||||
}),
|
||||
dict({
|
||||
'id': 15484,
|
||||
'logprob': -2.8378906,
|
||||
'special': False,
|
||||
'text': 'appear',
|
||||
}),
|
||||
dict({
|
||||
'id': 345,
|
||||
'logprob': -0.87597656,
|
||||
'special': False,
|
||||
'text': 'ed',
|
||||
}),
|
||||
dict({
|
||||
'id': 288,
|
||||
'logprob': -1.8447266,
|
||||
'special': False,
|
||||
'text': ' to',
|
||||
}),
|
||||
dict({
|
||||
'id': 35622,
|
||||
'logprob': -7.1445312,
|
||||
'special': False,
|
||||
'text': ' cloud',
|
||||
}),
|
||||
dict({
|
||||
'id': 263,
|
||||
'logprob': -1.2929688,
|
||||
'special': False,
|
||||
'text': 's',
|
||||
}),
|
||||
dict({
|
||||
'id': 14701,
|
||||
'logprob': -3.0761719,
|
||||
'special': False,
|
||||
'text': ' above',
|
||||
}),
|
||||
dict({
|
||||
'id': 751,
|
||||
'logprob': -4.4375,
|
||||
'special': False,
|
||||
'text': ' all',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': 'Why is the sky blue?blue sky appeared to clouds above all',
|
||||
})
|
||||
# ---
|
||||
# name: test_mt0_base_load
|
||||
list([
|
||||
'Because it is blue',
|
||||
'Because it is blue',
|
||||
'Because it is blue',
|
||||
'Because it is blue',
|
||||
])
|
||||
# ---
|
@ -13,7 +13,35 @@ def bloom_560(launcher):
|
||||
async def test_bloom_560m(bloom_560, snapshot):
|
||||
await health_check(bloom_560, 60)
|
||||
|
||||
response = await bloom_560.generate("Test request", max_new_tokens=10)
|
||||
response = await bloom_560.generate(
|
||||
"Pour déguster un ortolan, il faut tout d'abord",
|
||||
max_new_tokens=10,
|
||||
top_p=0.9,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bloom_560m_all_params(bloom_560, snapshot):
|
||||
await health_check(bloom_560, 60)
|
||||
|
||||
response = await bloom_560.generate(
|
||||
"Pour déguster un ortolan, il faut tout d'abord",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
@ -23,7 +51,12 @@ async def test_bloom_560m(bloom_560, snapshot):
|
||||
async def test_bloom_560m_load(bloom_560, generate_load, snapshot):
|
||||
await health_check(bloom_560, 60)
|
||||
|
||||
responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4)
|
||||
responses = await generate_load(
|
||||
bloom_560,
|
||||
"Pour déguster un ortolan, il faut tout d'abord",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
|
@ -13,7 +13,12 @@ def bloom_560m_sharded(launcher):
|
||||
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
|
||||
await health_check(bloom_560m_sharded, 60)
|
||||
|
||||
response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)
|
||||
response = await bloom_560m_sharded.generate(
|
||||
"Pour déguster un ortolan, il faut tout d'abord",
|
||||
max_new_tokens=10,
|
||||
top_p=0.9,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
@ -24,7 +29,10 @@ async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapsh
|
||||
await health_check(bloom_560m_sharded, 60)
|
||||
|
||||
responses = await generate_load(
|
||||
bloom_560m_sharded, "Test request", max_new_tokens=10, n=4
|
||||
bloom_560m_sharded,
|
||||
"Pour déguster un ortolan, il faut tout d'abord",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
53
integration-tests/models/test_flash_llama.py
Normal file
53
integration-tests/models/test_flash_llama.py
Normal file
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
|
||||
from utils import health_check
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama(launcher):
|
||||
with launcher("huggingface/llama-7b", num_shard=2) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama(flash_llama, snapshot):
|
||||
await health_check(flash_llama, 120)
|
||||
|
||||
response = await flash_llama.generate("Test request", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_all_params(flash_llama, snapshot):
|
||||
await health_check(flash_llama, 120)
|
||||
|
||||
response = await flash_llama.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_load(flash_llama, generate_load, snapshot):
|
||||
await health_check(flash_llama, 120)
|
||||
|
||||
responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
assert responses == snapshot
|
38
integration-tests/models/test_flash_neox.py
Normal file
38
integration-tests/models/test_flash_neox.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
|
||||
from utils import health_check
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_neox(launcher):
|
||||
with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_neox(flash_neox, snapshot):
|
||||
await health_check(flash_neox, 240)
|
||||
|
||||
response = await flash_neox.generate(
|
||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_neox_load(flash_neox, generate_load, snapshot):
|
||||
await health_check(flash_neox, 240)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_neox,
|
||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
assert responses == snapshot
|
@ -13,7 +13,7 @@ def flash_santacoder(launcher):
|
||||
async def test_flash_santacoder(flash_santacoder, snapshot):
|
||||
await health_check(flash_santacoder, 60)
|
||||
|
||||
response = await flash_santacoder.generate("Test request", max_new_tokens=10)
|
||||
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
@ -24,7 +24,7 @@ async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot):
|
||||
await health_check(flash_santacoder, 60)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_santacoder, "Test request", max_new_tokens=10, n=4
|
||||
flash_santacoder, "def print_hello", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
@ -5,26 +5,38 @@ from utils import health_check
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder(launcher):
|
||||
with launcher("bigcode/large-model", num_shard=2) as client:
|
||||
with launcher("bigcode/starcoder", num_shard=2) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder(flash_starcoder, snapshot):
|
||||
await health_check(flash_starcoder, 60)
|
||||
await health_check(flash_starcoder, 240)
|
||||
|
||||
response = await flash_starcoder.generate("Test request", max_new_tokens=10)
|
||||
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder_default_params(flash_starcoder, snapshot):
|
||||
await health_check(flash_starcoder, 240)
|
||||
|
||||
response = await flash_starcoder.generate(
|
||||
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 12
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
|
||||
await health_check(flash_starcoder, 60)
|
||||
await health_check(flash_starcoder, 240)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_starcoder, "Test request", max_new_tokens=10, n=4
|
||||
flash_starcoder, "def print_hello", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
63
integration-tests/models/test_mt0_base.py
Normal file
63
integration-tests/models/test_mt0_base.py
Normal file
@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from utils import health_check
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mt0_base(launcher):
|
||||
with launcher("bigscience/mt0-base") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mt0_base(mt0_base, snapshot):
|
||||
await health_check(mt0_base, 60)
|
||||
|
||||
response = await mt0_base.generate(
|
||||
"Why is the sky blue?",
|
||||
max_new_tokens=10,
|
||||
top_p=0.9,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 5
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mt0_base_all_params(mt0_base, snapshot):
|
||||
await health_check(mt0_base, 60)
|
||||
|
||||
response = await mt0_base.generate(
|
||||
"Why is the sky blue?",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mt0_base_load(mt0_base, generate_load, snapshot):
|
||||
await health_check(mt0_base, 60)
|
||||
|
||||
responses = await generate_load(
|
||||
mt0_base,
|
||||
"Why is the sky blue?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
assert responses == snapshot
|
@ -12,4 +12,4 @@ async def health_check(client: AsyncClient, timeout: int = 60):
|
||||
return
|
||||
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
|
||||
time.sleep(1)
|
||||
raise e
|
||||
raise RuntimeError("Health check failed")
|
||||
|
@ -1,142 +0,0 @@
|
||||
{
|
||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 10264,
|
||||
"text": "Test",
|
||||
"logprob": null
|
||||
},
|
||||
{
|
||||
"id": 8821,
|
||||
"text": " request",
|
||||
"logprob": -11.894989
|
||||
}
|
||||
],
|
||||
"tokens": [
|
||||
{
|
||||
"id": 17,
|
||||
"text": ".",
|
||||
"logprob": -1.8267672,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1587,
|
||||
"text": "get",
|
||||
"logprob": -2.4674969,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"text": "(",
|
||||
"logprob": -1.906001,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"text": "\"",
|
||||
"logprob": -1.2279545,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4899,
|
||||
"text": "action",
|
||||
"logprob": -4.170299,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"text": "\"",
|
||||
"logprob": -0.32478866,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"text": ")",
|
||||
"logprob": -1.0773665,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"text": ";",
|
||||
"logprob": -0.27640742,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 837,
|
||||
"text": "\n ",
|
||||
"logprob": -1.6970354,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1320,
|
||||
"text": " if",
|
||||
"logprob": -1.4495516,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 375,
|
||||
"text": " (",
|
||||
"logprob": -0.23609057,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4899,
|
||||
"text": "action",
|
||||
"logprob": -1.1916996,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 3535,
|
||||
"text": " ==",
|
||||
"logprob": -0.8918753,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5109,
|
||||
"text": " null",
|
||||
"logprob": -0.3933342,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"text": ")",
|
||||
"logprob": -0.43212673,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 731,
|
||||
"text": " {",
|
||||
"logprob": -0.17702064,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1260,
|
||||
"text": "\n ",
|
||||
"logprob": -0.07027565,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 10519,
|
||||
"text": " throw",
|
||||
"logprob": -1.3915029,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2084,
|
||||
"text": " new",
|
||||
"logprob": -0.04201372,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 150858,
|
||||
"text": " RuntimeException",
|
||||
"logprob": -1.7329919,
|
||||
"special": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -1,172 +0,0 @@
|
||||
use float_eq::assert_float_eq;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
use subprocess::{Popen, PopenConfig, Redirection};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Token {
|
||||
id: u32,
|
||||
text: String,
|
||||
logprob: Option<f32>,
|
||||
special: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Details {
|
||||
finish_reason: String,
|
||||
generated_tokens: u32,
|
||||
tokens: Vec<Token>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GeneratedText {
|
||||
generated_text: String,
|
||||
details: Details,
|
||||
}
|
||||
|
||||
fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
|
||||
let argv = vec![
|
||||
"text-generation-launcher".to_string(),
|
||||
"--model-id".to_string(),
|
||||
model_id.clone(),
|
||||
"--num-shard".to_string(),
|
||||
num_shard.to_string(),
|
||||
"--port".to_string(),
|
||||
port.to_string(),
|
||||
"--master-port".to_string(),
|
||||
master_port.to_string(),
|
||||
"--shard-uds-path".to_string(),
|
||||
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
|
||||
];
|
||||
|
||||
let mut launcher = Popen::create(
|
||||
&argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Merge,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.expect("Could not start launcher");
|
||||
|
||||
// Redirect STDOUT and STDERR to the console
|
||||
// (STDERR is merged into STDOUT)
|
||||
let launcher_stdout = launcher.stdout.take().unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
let stdout = BufReader::new(launcher_stdout);
|
||||
for line in stdout.lines() {
|
||||
println!("{}", line.unwrap());
|
||||
}
|
||||
});
|
||||
|
||||
for _ in 0..60 {
|
||||
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
|
||||
if health.is_ok() {
|
||||
return launcher;
|
||||
}
|
||||
sleep(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
launcher.terminate().unwrap();
|
||||
launcher.wait().unwrap();
|
||||
panic!("failed to launch {}", model_id)
|
||||
}
|
||||
|
||||
fn test_model(
|
||||
model_id: String,
|
||||
num_shard: usize,
|
||||
port: usize,
|
||||
master_port: usize,
|
||||
) -> GeneratedText {
|
||||
let mut launcher = start_launcher(model_id, num_shard, port, master_port);
|
||||
|
||||
let data = r#"
|
||||
{
|
||||
"inputs": "Test request",
|
||||
"parameters": {
|
||||
"details": true
|
||||
}
|
||||
}"#;
|
||||
let req: Value = serde_json::from_str(data).unwrap();
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let res = client
|
||||
.post(format!("http://localhost:{}/generate", port))
|
||||
.json(&req)
|
||||
.send();
|
||||
|
||||
launcher.terminate().unwrap();
|
||||
launcher.wait().unwrap();
|
||||
|
||||
let result: GeneratedText = res.unwrap().json().unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
fn read_json(name: &str) -> GeneratedText {
|
||||
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
d.push("tests/");
|
||||
d.push(name);
|
||||
|
||||
let file = File::open(d).unwrap();
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
||||
assert_eq!(result.generated_text, expected.generated_text);
|
||||
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
|
||||
assert_eq!(
|
||||
result.details.generated_tokens,
|
||||
expected.details.generated_tokens
|
||||
);
|
||||
|
||||
for (token, expected_token) in result
|
||||
.details
|
||||
.tokens
|
||||
.into_iter()
|
||||
.zip(expected.details.tokens.into_iter())
|
||||
{
|
||||
assert_eq!(token.id, expected_token.id);
|
||||
assert_eq!(token.text, expected_token.text);
|
||||
assert_eq!(token.special, expected_token.special);
|
||||
if let Some(logprob) = token.logprob {
|
||||
let expected_logprob = expected_token.logprob.unwrap();
|
||||
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
||||
} else {
|
||||
assert_eq!(token.logprob, expected_token.logprob);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bloom_560m() {
|
||||
let expected = read_json("bloom_560m.json");
|
||||
|
||||
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
|
||||
compare_results(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bloom_560m_distributed() {
|
||||
let expected = read_json("bloom_560m.json");
|
||||
|
||||
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
|
||||
compare_results(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mt0_base() {
|
||||
let expected = read_json("mt0_base.json");
|
||||
|
||||
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
|
||||
compare_results(result, expected);
|
||||
}
|
@ -1,137 +0,0 @@
|
||||
{
|
||||
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"text": "<pad>",
|
||||
"logprob": null
|
||||
}
|
||||
],
|
||||
"tokens": [
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -1.3656927,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 215100,
|
||||
"text": "\"\"\"",
|
||||
"logprob": -2.6551573,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 46138,
|
||||
"text": "Test",
|
||||
"logprob": -1.8059857,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": " the",
|
||||
"logprob": -1.2102449,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": " ",
|
||||
"logprob": -1.6057279,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -3.6060903,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"text": " of",
|
||||
"logprob": -0.5270343,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": " the",
|
||||
"logprob": -0.62522805,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": " ",
|
||||
"logprob": -1.4069618,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -2.621994,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"text": " of",
|
||||
"logprob": -1.3172221,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": " the",
|
||||
"logprob": -0.3501925,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": " ",
|
||||
"logprob": -0.7219573,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -1.0494149,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 260,
|
||||
"text": ".",
|
||||
"logprob": -1.0803378,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": " ",
|
||||
"logprob": -0.32933083,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 215100,
|
||||
"text": "\"\"\"",
|
||||
"logprob": -0.11268901,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"text": " test",
|
||||
"logprob": -1.5846587,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 290,
|
||||
"text": "_",
|
||||
"logprob": -0.49796978,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4125,
|
||||
"text": "test",
|
||||
"logprob": -2.0026445,
|
||||
"special": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -30,7 +30,6 @@ from typing import Optional
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
|
@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
|
@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashNeoX(FlashCausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
super(FlashNeoX, self).__init__(
|
||||
FlashGPTNeoXForCausalLM, model_id, revision, quantize
|
||||
)
|
||||
@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM):
|
||||
|
||||
class FlashNeoXSharded(FlashNeoX):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
@ -27,7 +27,12 @@ tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashSantacoder(FlashCausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
@ -170,7 +175,10 @@ class FlashSantacoder(FlashCausalLM):
|
||||
|
||||
class FlashSantacoderSharded(FlashSantacoder):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
@ -48,7 +48,10 @@ class OPT(CausalLM):
|
||||
|
||||
class OPTSharded(OPT):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
Loading…
Reference in New Issue
Block a user