From d69e4d2d1e8318a41d03f271f59d7fc82d7e8343 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 15 May 2023 17:13:33 +0200 Subject: [PATCH] add tests --- .github/workflows/build.yaml | 47 +++- Makefile | 11 +- README.md | 2 + integration-tests/conftest.py | 13 +- .../models/__snapshots__/test_bloom_560m.ambr | 247 ++++++++++++++---- .../test_bloom_560m_sharded.ambr | 147 +++++++---- .../__snapshots__/test_flash_llama.ambr | 195 ++++++++++++++ .../models/__snapshots__/test_flash_neox.ambr | 174 ++++++++++++ .../__snapshots__/test_flash_santacoder.ambr | 153 +++++++---- .../__snapshots__/test_flash_starcoder.ambr | 240 ++++++++++++++--- .../models/__snapshots__/test_mt0_base.ambr | 139 ++++++++++ integration-tests/models/test_bloom_560m.py | 37 ++- .../models/test_bloom_560m_sharded.py | 12 +- integration-tests/models/test_flash_llama.py | 53 ++++ integration-tests/models/test_flash_neox.py | 38 +++ .../models/test_flash_santacoder.py | 4 +- .../models/test_flash_starcoder.py | 22 +- integration-tests/models/test_mt0_base.py | 63 +++++ integration-tests/models/utils.py | 2 +- launcher/tests/bloom_560m.json | 142 ---------- launcher/tests/integration_tests.rs | 172 ------------ launcher/tests/mt0_base.json | 137 ---------- .../custom_modeling/flash_llama_modeling.py | 1 - .../models/flash_llama.py | 7 +- .../models/flash_neox.py | 12 +- .../models/flash_santacoder.py | 12 +- server/text_generation_server/models/opt.py | 5 +- 27 files changed, 1414 insertions(+), 673 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama.ambr create mode 100644 integration-tests/models/__snapshots__/test_flash_neox.ambr create mode 100644 integration-tests/models/__snapshots__/test_mt0_base.ambr create mode 100644 integration-tests/models/test_flash_llama.py create mode 100644 integration-tests/models/test_flash_neox.py create mode 100644 integration-tests/models/test_mt0_base.py delete mode 100644 launcher/tests/bloom_560m.json delete mode 100644 launcher/tests/integration_tests.rs delete mode 100644 launcher/tests/mt0_base.json diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0a99fb52..953b72d6 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,6 +70,8 @@ jobs: # with sigstore/fulcio when running outside of PRs. id-token: write security-events: write + outputs: + image: steps: - name: Checkout repository uses: actions/checkout@v3 @@ -108,7 +110,19 @@ jobs: username: ${{ secrets.AZURE_DOCKER_USERNAME }} password: ${{ secrets.AZURE_DOCKER_PASSWORD }} registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io + # If pull request - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name == 'pull_request' }} + id: meta + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + tags: | + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} + # If main, release or tag + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name != 'pull_request' }} id: meta uses: docker/metadata-action@v4.3.0 with: @@ -129,7 +143,7 @@ jobs: with: context: . 
file: Dockerfile - push: ${{ github.event_name != 'pull_request' }} + push: true platforms: 'linux/amd64' build-args: | GIT_SHA=${{ env.GITHUB_SHA }} @@ -172,11 +186,42 @@ jobs: with: sarif_file: 'trivy-results.sarif' + integration-tests: + needs: + - start-runner + - build-and-push-image # Wait for the docker image to be built + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + env: + DOCKER_IMAGE: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.9 + - name: Tailscale + uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Prepare disks + run: | + sudo mkfs -t ext4 /dev/nvme1n1 + sudo mkdir /data + sudo mount /dev/nvme1n1 /data + sudo chown -R $USER:$USER /data + - name: Install + run: | + make install-integration-tests + - name: Run tests + run: | + make integration-tests + stop-runner: name: Stop self-hosted EC2 runner needs: - start-runner - build-and-push-image + - integration-tests runs-on: ubuntu-latest env: AWS_REGION: us-east-1 diff --git a/Makefile b/Makefile index 032a49de..0d4a2f73 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ install-server: cd server && make install +install-integration-tests: + cd integration-tests && pip install -r requirements.txt + install-router: cd router && cargo install --path . @@ -18,9 +21,15 @@ server-dev: router-dev: cd router && cargo run -- --port 8080 -integration-tests: install-router install-launcher +rust-tests: install-router install-launcher cargo test +integration-tests: install-integration-tests + pytest -s -vv integration-tests + +update-integration-tests: install-integration-tests + pytest -s -vv --snapshot-update integration-tests + python-server-tests: HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests diff --git a/README.md b/README.md index 756d7e35..918ea5e2 100644 --- a/README.md +++ b/README.md @@ -253,5 +253,7 @@ make python-client-tests # or both server and client tests make python-tests # rust cargo tests +make rust-tests +# integration tests make integration-tests ``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index b4d35697..68528a2f 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -13,6 +13,7 @@ from text_generation.types import Response DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) +DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") @pytest.fixture(scope="module") @@ -26,7 +27,7 @@ def event_loop(): def launcher(event_loop): @contextlib.contextmanager def local_launcher( - model_id: str, num_shard: Optional[int] = None, quantize: bool = False + model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None ): port = 9999 master_port = 19999 @@ -66,7 +67,7 @@ def launcher(event_loop): @contextlib.contextmanager def docker_launcher( - model_id: str, num_shard: Optional[int] = None, quantize: bool = False + model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None ): port = 9999 @@ -94,6 +95,10 @@ def launcher(event_loop): if HUGGING_FACE_HUB_TOKEN is not None: env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + volumes = [] + if DOCKER_VOLUME: + volumes = [f"{DOCKER_VOLUME}:/data"] + container = client.containers.run( DOCKER_IMAGE, 
command=args, @@ -104,7 +109,7 @@ def launcher(event_loop): device_requests=[ docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) ], - volumes=["/data:/data"], + volumes=volumes, ports={"80/tcp": port}, ) @@ -130,6 +135,6 @@ def generate_load(): ] results = await asyncio.gather(*futures) - return results + return [r.generated_text for r in results] return generate_load_inner diff --git a/integration-tests/models/__snapshots__/test_bloom_560m.ambr b/integration-tests/models/__snapshots__/test_bloom_560m.ambr index d81b225e..9aa212e8 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m.ambr +++ b/integration-tests/models/__snapshots__/test_bloom_560m.ambr @@ -7,88 +7,233 @@ 'generated_tokens': 10, 'prefill': list([ dict({ - 'id': 10264, + 'id': 17934, 'logprob': None, - 'text': 'Test', + 'text': 'Pour', }), dict({ - 'id': 8821, - 'logprob': -11.3125, - 'text': ' request', - }), - ]), - 'seed': None, - 'tokens': list([ - dict({ - 'id': 11, - 'logprob': -2.859375, - 'special': False, - 'text': '(', + 'id': 49833, + 'logprob': -10.5625, + 'text': ' dég', }), dict({ - 'id': 5, - 'logprob': -2.34375, - 'special': False, - 'text': '"', + 'id': 21543, + 'logprob': -0.14770508, + 'text': 'uster', }), dict({ - 'id': 1587, - 'logprob': -3.25, - 'special': False, - 'text': 'get', + 'id': 447, + 'logprob': -1.9287109, + 'text': ' un', }), dict({ - 'id': 5, - 'logprob': -1.828125, - 'special': False, - 'text': '"', + 'id': 46341, + 'logprob': -15.4609375, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'logprob': -7.5585938, + 'text': 'olan', }), dict({ 'id': 15, - 'logprob': -0.35546875, - 'special': False, + 'logprob': -1.4003906, 'text': ',', }), dict({ - 'id': 567, - 'logprob': -2.4375, - 'special': False, - 'text': ' "', + 'id': 1669, + 'logprob': -1.5673828, + 'text': ' il', }), dict({ - 'id': 17, - 'logprob': -4.40625, - 'special': False, - 'text': '.', + 'id': 11580, + 'logprob': -0.94628906, + 'text': ' faut', }), dict({ - 'id': 5, - 'logprob': -2.46875, - 'special': False, - 'text': '"', + 'id': 3913, + 'logprob': -3.703125, + 'text': ' tout', }), dict({ - 'id': 12, - 'logprob': -1.6015625, + 'id': 39261, + 'logprob': -1.5732422, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 578, + 'logprob': -1.6591797, 'special': False, - 'text': ')', + 'text': ' le', }), dict({ - 'id': 30, - 'logprob': -0.14550781, + 'id': 5608, + 'logprob': -2.4492188, 'special': False, - 'text': ';', + 'text': ' faire', + }), + dict({ + 'id': 159570, + 'logprob': -6.6835938, + 'special': False, + 'text': ' réch', + }), + dict({ + 'id': 810, + 'logprob': 0.0, + 'special': False, + 'text': 'au', + }), + dict({ + 'id': 12736, + 'logprob': 0.0, + 'special': False, + 'text': 'ffer', + }), + dict({ + 'id': 1742, + 'logprob': -2.5175781, + 'special': False, + 'text': ' au', + }), + dict({ + 'id': 6105, + 'logprob': -2.0078125, + 'special': False, + 'text': ' bain', + }), + dict({ + 'id': 88254, + 'logprob': -0.12695312, + 'special': False, + 'text': '-mar', + }), + dict({ + 'id': 641, + 'logprob': 0.0, + 'special': False, + 'text': 'ie', + }), + dict({ + 'id': 2940, + 'logprob': -3.5175781, + 'special': False, + 'text': ' avec', }), ]), }), - 'generated_text': '("get", ".");', + 'generated_text': ' le faire réchauffer au bain-marie avec', + }) +# --- +# name: test_bloom_560m_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 15, + 'logprob': None, + 'text': 
',', + }), + dict({ + 'id': 1669, + 'logprob': -5.4414062, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'logprob': -2.3378906, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'logprob': -4.3554688, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'logprob': -2.9238281, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 408, + 'logprob': -1.9267578, + 'special': False, + 'text': ' que', + }), + dict({ + 'id': 20288, + 'logprob': -2.9257812, + 'special': False, + 'text': " l'on", + }), + dict({ + 'id': 22255, + 'logprob': -2.8964844, + 'special': False, + 'text': ' trouve', + }), + dict({ + 'id': 1622, + 'logprob': -1.1083984, + 'special': False, + 'text': ' une', + }), + dict({ + 'id': 187079, + 'logprob': -7.796875, + 'special': False, + 'text': ' posture', + }), + dict({ + 'id': 501, + 'logprob': -5.390625, + 'special': False, + 'text': ' par', + }), + dict({ + 'id': 8741, + 'logprob': -0.34936523, + 'special': False, + 'text': ' rapport', + }), + dict({ + 'id': 693, + 'logprob': 0.0, + 'special': False, + 'text': ' à', + }), + dict({ + 'id': 366, + 'logprob': -2.3378906, + 'special': False, + 'text': ' la', + }), + dict({ + 'id': 36503, + 'logprob': -3.6640625, + 'special': False, + 'text': ' pratique', + }), + ]), + }), + 'generated_text': "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique", }) # --- # name: test_bloom_560m_load list([ - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), 
Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)), + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", ]) # --- diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr index 7cf26255..1c842ddc 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr @@ -7,88 +7,133 @@ 'generated_tokens': 10, 'prefill': list([ dict({ - 'id': 10264, + 'id': 17934, 'logprob': None, - 'text': 'Test', + 'text': 'Pour', }), dict({ - 'id': 8821, - 'logprob': -11.375, - 'text': ' request', - }), - ]), - 'seed': None, - 'tokens': list([ - dict({ - 'id': 11, - 'logprob': -2.734375, - 'special': False, - 'text': '(', + 'id': 49833, + 'logprob': -10.5390625, + 'text': ' dég', }), dict({ - 'id': 5, - 'logprob': -1.9765625, - 'special': False, - 'text': '"', + 'id': 21543, + 'logprob': -0.14758301, + 'text': 'uster', }), dict({ - 'id': 1587, - 'logprob': -3.140625, - 'special': False, - 'text': 'get', + 'id': 447, + 'logprob': -1.9296875, + 'text': ' un', }), dict({ - 'id': 5, - 'logprob': -3.515625, - 'special': False, - 'text': '"', + 'id': 46341, + 'logprob': -15.4453125, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'logprob': -7.59375, + 'text': 'olan', }), dict({ 'id': 15, - 'logprob': -0.37304688, - 'special': False, + 'logprob': -1.3994141, 'text': ',', }), dict({ - 'id': 567, - 'logprob': -2.6875, - 'special': False, - 'text': ' "', + 'id': 1669, + 'logprob': -1.578125, + 'text': ' il', }), dict({ - 'id': 17, - 'logprob': -4.65625, - 'special': False, - 'text': '.', + 'id': 11580, + 'logprob': -0.9453125, + 'text': ' faut', }), dict({ - 'id': 5, - 'logprob': -2.28125, - 'special': False, - 'text': '"', + 'id': 3913, + 'logprob': -3.7011719, + 'text': ' tout', }), dict({ - 'id': 12, - 'logprob': -1.6640625, + 'id': 39261, + 'logprob': -1.5732422, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 578, + 'logprob': -1.6474609, 'special': False, - 'text': ')', + 'text': ' le', }), dict({ - 'id': 30, - 'logprob': -0.14355469, + 'id': 5608, + 
'logprob': -2.5097656, 'special': False, - 'text': ';', + 'text': ' faire', + }), + dict({ + 'id': 159570, + 'logprob': -6.65625, + 'special': False, + 'text': ' réch', + }), + dict({ + 'id': 810, + 'logprob': 0.0, + 'special': False, + 'text': 'au', + }), + dict({ + 'id': 12736, + 'logprob': 0.0, + 'special': False, + 'text': 'ffer', + }), + dict({ + 'id': 1742, + 'logprob': -2.5859375, + 'special': False, + 'text': ' au', + }), + dict({ + 'id': 6105, + 'logprob': -2.03125, + 'special': False, + 'text': ' bain', + }), + dict({ + 'id': 88254, + 'logprob': -0.12695312, + 'special': False, + 'text': '-mar', + }), + dict({ + 'id': 641, + 'logprob': 0.0, + 'special': False, + 'text': 'ie', + }), + dict({ + 'id': 2940, + 'logprob': -3.5175781, + 'special': False, + 'text': ' avec', }), ]), }), - 'generated_text': '("get", ".");', + 'generated_text': ' le faire réchauffer au bain-marie avec', }) # --- # name: test_bloom_560m_sharded_load list([ - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), - Response(generated_text='("get", ".");', details=Details(finish_reason=, 
generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", + " le faire cuire dans de l'eau bouillante sal", ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_llama.ambr b/integration-tests/models/__snapshots__/test_flash_llama.ambr new file mode 100644 index 00000000..2fde1b01 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama.ambr @@ -0,0 +1,195 @@ +# serializer version: 1 +# name: test_flash_llama + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'logprob': None, + 'text': '', + }), + dict({ + 'id': 4321, + 'logprob': -8.6875, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'logprob': -11.5546875, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'logprob': -1.5380859, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'logprob': -2.5917969, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'logprob': -2.2773438, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'logprob': -0.034362793, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'logprob': -0.96533203, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'logprob': -0.36669922, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'logprob': -0.013122559, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'logprob': -3.1503906, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'logprob': -0.43652344, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'logprob': -1.9404297, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }) +# --- +# name: test_flash_llama_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'logprob': None, + 'text': '', + }), + dict({ + 'id': 4321, + 'logprob': -8.6875, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'logprob': -11.5546875, + 'text': 'request', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 5229, + 'logprob': -3.3085938, + 'special': False, + 'text': ' failed', + }), + dict({ + 'id': 363, + 'logprob': -3.984375, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 5641, + 'logprob': -6.53125, + 'special': False, + 'text': ' IP', + }), + dict({ + 'id': 16428, + 'logprob': -3.1835938, + 'special': False, + 'text': ' Address', + }), + dict({ + 'id': 29901, + 'logprob': -1.2324219, + 'special': False, + 'text': ':', + }), + dict({ + 'id': 525, + 
'logprob': -2.6855469, + 'special': False, + 'text': " '", + }), + dict({ + 'id': 8516, + 'logprob': -7.1601562, + 'special': False, + 'text': 'None', + }), + dict({ + 'id': 4286, + 'logprob': -2.4433594, + 'special': False, + 'text': "'.", + }), + dict({ + 'id': 13, + 'logprob': -0.06530762, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 294, + 'logprob': -7.953125, + 'special': False, + 'text': 'as', + }), + ]), + }), + 'generated_text': ''' + Test requestfailed for IP Address: 'None'. + as + ''', + }) +# --- +# name: test_flash_llama_load + list([ + 'for /api/v1/projects/1', + 'for /api/v1/projects/1', + 'for /api/v1/projects/1', + 'for /api/v1/projects/1', + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_neox.ambr b/integration-tests/models/__snapshots__/test_flash_neox.ambr new file mode 100644 index 00000000..671a1c0c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox.ambr @@ -0,0 +1,174 @@ +# serializer version: 1 +# name: test_flash_neox + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'logprob': None, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'logprob': -8.03125, + 'text': 'What', + }), + dict({ + 'id': 310, + 'logprob': -5.421875, + 'text': ' is', + }), + dict({ + 'id': 247, + 'logprob': -2.1601562, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'logprob': -5.4609375, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'logprob': -0.005657196, + 'text': 'e', + }), + dict({ + 'id': 13, + 'logprob': -7.28125, + 'text': ',', + }), + dict({ + 'id': 285, + 'logprob': -0.2980957, + 'text': ' and', + }), + dict({ + 'id': 752, + 'logprob': -2.1679688, + 'text': ' what', + }), + dict({ + 'id': 434, + 'logprob': -5.6210938, + 'text': "'s", + }), + dict({ + 'id': 253, + 'logprob': -0.81103516, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'logprob': -6.6640625, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'logprob': -2.265625, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'logprob': -11.5078125, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'logprob': -2.1582031, + 'text': ' word', + }), + dict({ + 'id': 32, + 'logprob': -0.008720398, + 'text': '?', + }), + dict({ + 'id': 0, + 'logprob': -2.4726562, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'logprob': -18.265625, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'logprob': -0.63183594, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'logprob': -0.5390625, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'logprob': -0.045684814, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'logprob': -0.002090454, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'logprob': -1.3589859e-05, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'logprob': -0.0009455681, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'logprob': -0.088012695, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'logprob': -0.12585449, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'logprob': -0.017196655, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'logprob': -0.49731445, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }) +# --- +# name: test_flash_neox_load + list([ + 'The word "meme" 
was first used in', + 'The word "meme" was first used in', + 'The word "meme" was first used in', + 'The word "meme" was first used in', + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr index bb5dfbba..a6a8e599 100644 --- a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr +++ b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr @@ -7,95 +7,132 @@ 'generated_tokens': 10, 'prefill': list([ dict({ - 'id': 804, + 'id': 563, 'logprob': None, - 'text': 'Test', + 'text': 'def', + }), + dict({ + 'id': 942, + 'logprob': -5.1367188, + 'text': ' print', + }), + dict({ + 'id': 62, + 'logprob': -0.24450684, + 'text': '_', + }), + dict({ + 'id': 7196, + 'logprob': -6.9609375, + 'text': 'hello', }), ]), 'seed': None, 'tokens': list([ dict({ - 'id': 25, - 'logprob': -2.828125, + 'id': 1241, + 'logprob': -0.9863281, 'special': False, - 'text': ':', + 'text': '():', }), dict({ - 'id': 287, - 'logprob': -1.5703125, - 'special': False, - 'text': ' "', - }), - dict({ - 'id': 385, - 'logprob': -0.03955078, - 'special': False, - 'text': ' +', - }), - dict({ - 'id': 1028, - 'logprob': -1.453125, - 'special': False, - 'text': ' request', - }), - dict({ - 'id': 13, - 'logprob': -0.796875, - 'special': False, - 'text': '.', - }), - dict({ - 'id': 1832, - 'logprob': -1.4296875, - 'special': False, - 'text': 'toString', - }), - dict({ - 'id': 782, - 'logprob': -0.17871094, - 'special': False, - 'text': '());', - }), - dict({ - 'id': 259, - 'logprob': -1.359375, + 'id': 258, + 'logprob': -0.21447754, 'special': False, 'text': ''' - + ''', }), dict({ - 'id': 294, - 'logprob': -1.6484375, + 'id': 942, + 'logprob': -0.43701172, 'special': False, - 'text': ' }', + 'text': ' print', }), dict({ - 'id': 437, - 'logprob': -1.0625, + 'id': 372, + 'logprob': -0.5361328, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'logprob': -0.44555664, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'logprob': -1.2412109, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'logprob': -0.7583008, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'logprob': -0.76171875, 'special': False, 'text': ''' - - + ''', }), + dict({ + 'id': 185, + 'logprob': -0.20837402, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'logprob': -1.2470703, + 'special': False, + 'text': 'print', + }), ]), }), 'generated_text': ''' - : " + request.toString()); - } + (): + print("Hello World!") - + print ''', }) # --- # name: test_flash_santacoder_load list([ - Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)), - Response(generated_text=':\n def __init__(self, 
name,', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)), - Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)), - Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)), + ''' + (): + print("Hello World!") + + print + ''', + ''' + (): + print("Hello World!") + + print + ''', + ''' + (): + print("Hello World!") + + print + ''', + ''' + (): + print("Hello World!") + + print + ''', ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr index 2ecdbf6c..65e14581 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr +++ b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr @@ -7,83 +7,249 @@ 'generated_tokens': 10, 'prefill': list([ dict({ - 'id': 1006, + 'id': 589, 'logprob': None, - 'text': 'Test', + 'text': 'def', + }), + dict({ + 'id': 1459, + 'logprob': -5.6289062, + 'text': ' print', + }), + dict({ + 'id': 81, + 'logprob': -1.6005859, + 'text': '_', + }), + dict({ + 'id': 7656, + 'logprob': -5.9921875, + 'text': 'hello', }), ]), 'seed': None, 'tokens': list([ dict({ - 'id': 30, - 'logprob': -2.734375, + 'id': 2262, + 'logprob': -0.7705078, 'special': False, - 'text': ',', + 'text': '():', }), dict({ - 
'id': 892, - 'logprob': -2.828125, + 'id': 284, + 'logprob': -0.2590332, 'special': False, - 'text': ' String', + 'text': ''' + + + ''', }), dict({ - 'id': 1984, - 'logprob': -3.28125, + 'id': 1459, + 'logprob': -0.39379883, 'special': False, - 'text': ' url', + 'text': ' print', }), dict({ - 'id': 30, - 'logprob': -0.796875, + 'id': 440, + 'logprob': -0.61376953, 'special': False, - 'text': ',', + 'text': '("', }), dict({ - 'id': 892, - 'logprob': -1.0390625, + 'id': 8279, + 'logprob': -0.47338867, 'special': False, - 'text': ' String', + 'text': 'Hello', }), dict({ - 'id': 1411, - 'logprob': -2.078125, + 'id': 10896, + 'logprob': -1.5068359, 'special': False, - 'text': ' method', + 'text': ' World', }), dict({ - 'id': 30, - 'logprob': -0.49609375, + 'id': 657, + 'logprob': -0.80810547, 'special': False, - 'text': ',', + 'text': '")', }), dict({ - 'id': 892, - 'logprob': -1.0546875, + 'id': 203, + 'logprob': -0.7397461, 'special': False, - 'text': ' String', + 'text': ''' + + + ''', }), dict({ - 'id': 3361, - 'logprob': -1.71875, + 'id': 203, + 'logprob': -0.35229492, 'special': False, - 'text': ' body', + 'text': ''' + + + ''', }), dict({ - 'id': 27, - 'logprob': -0.69921875, + 'id': 589, + 'logprob': -1.0371094, 'special': False, - 'text': ')', + 'text': 'def', }), ]), }), - 'generated_text': ', String url, String method, String body)', + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }) +# --- +# name: test_flash_starcoder_default_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 12, + 'prefill': list([ + dict({ + 'id': 589, + 'logprob': None, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'logprob': -5.6289062, + 'text': ' print', + }), + dict({ + 'id': 81, + 'logprob': -1.6005859, + 'text': '_', + }), + dict({ + 'id': 7656, + 'logprob': -5.9921875, + 'text': 'hello', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 2262, + 'logprob': -0.7451172, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'logprob': -0.21325684, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 5741, + 'logprob': -5.734375, + 'special': False, + 'text': ' logging', + }), + dict({ + 'id': 32, + 'logprob': 0.0, + 'special': False, + 'text': '.', + }), + dict({ + 'id': 1338, + 'logprob': -0.3232422, + 'special': False, + 'text': 'info', + }), + dict({ + 'id': 463, + 'logprob': -1.0380859, + 'special': False, + 'text': "('", + }), + dict({ + 'id': 8279, + 'logprob': -0.8378906, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 30, + 'logprob': -1.9501953, + 'special': False, + 'text': ',', + }), + dict({ + 'id': 10896, + 'logprob': -1.3476562, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 683, + 'logprob': -1.796875, + 'special': False, + 'text': "')", + }), + dict({ + 'id': 203, + 'logprob': -0.9873047, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 0, + 'logprob': -0.7495117, + 'special': True, + 'text': '<|endoftext|>', + }), + ]), + }), + 'generated_text': ''' + (): + logging.info('Hello, World') + <|endoftext|> + ''', }) # --- # name: test_flash_starcoder_load list([ - Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), Token(id=1984, text=' url', logprob=-3.28125, 
special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)), - Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), - Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), - Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), + ''' + (): + print("Hello World") + + def + ''', + ''' + (): + print("Hello World") + + def + ''', + ''' + (): + print("Hello World") + + def + ''', + ''' + (): + print("Hello World") + + def + ''', ]) # --- diff --git a/integration-tests/models/__snapshots__/test_mt0_base.ambr b/integration-tests/models/__snapshots__/test_mt0_base.ambr new file mode 100644 index 00000000..dc974891 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mt0_base.ambr @@ -0,0 +1,139 @@ +# serializer version: 1 +# name: test_mt0_base + dict({ + 'details': dict({ + 
'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 5, + 'prefill': list([ + dict({ + 'id': 0, + 'logprob': None, + 'text': '', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 926, + 'logprob': -4.3554688, + 'special': False, + 'text': 'To', + }), + dict({ + 'id': 18295, + 'logprob': -7.7734375, + 'special': False, + 'text': ' sell', + }), + dict({ + 'id': 7868, + 'logprob': -3.9257812, + 'special': False, + 'text': ' things', + }), + dict({ + 'id': 260, + 'logprob': -2.4179688, + 'special': False, + 'text': '.', + }), + dict({ + 'id': 1, + 'logprob': 0.0, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'To sell things.', + }) +# --- +# name: test_mt0_base_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 0, + 'logprob': None, + 'text': '', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 16017, + 'logprob': -1.3505859, + 'special': False, + 'text': 'blue', + }), + dict({ + 'id': 20495, + 'logprob': -0.50439453, + 'special': False, + 'text': ' sky', + }), + dict({ + 'id': 259, + 'logprob': -1.2011719, + 'special': False, + 'text': ' ', + }), + dict({ + 'id': 15484, + 'logprob': -2.8378906, + 'special': False, + 'text': 'appear', + }), + dict({ + 'id': 345, + 'logprob': -0.87597656, + 'special': False, + 'text': 'ed', + }), + dict({ + 'id': 288, + 'logprob': -1.8447266, + 'special': False, + 'text': ' to', + }), + dict({ + 'id': 35622, + 'logprob': -7.1445312, + 'special': False, + 'text': ' cloud', + }), + dict({ + 'id': 263, + 'logprob': -1.2929688, + 'special': False, + 'text': 's', + }), + dict({ + 'id': 14701, + 'logprob': -3.0761719, + 'special': False, + 'text': ' above', + }), + dict({ + 'id': 751, + 'logprob': -4.4375, + 'special': False, + 'text': ' all', + }), + ]), + }), + 'generated_text': 'Why is the sky blue?blue sky appeared to clouds above all', + }) +# --- +# name: test_mt0_base_load + list([ + 'Because it is blue', + 'Because it is blue', + 'Because it is blue', + 'Because it is blue', + ]) +# --- diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index f29cdb2d..39850cad 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -13,7 +13,35 @@ def bloom_560(launcher): async def test_bloom_560m(bloom_560, snapshot): await health_check(bloom_560, 60) - response = await bloom_560.generate("Test request", max_new_tokens=10) + response = await bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_bloom_560m_all_params(bloom_560, snapshot): + await health_check(bloom_560, 60) + + response = await bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) assert response.details.generated_tokens == 10 assert response == snapshot @@ -23,7 +51,12 @@ async def test_bloom_560m(bloom_560, snapshot): async def test_bloom_560m_load(bloom_560, generate_load, snapshot): await health_check(bloom_560, 60) - responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4) + responses = await 
generate_load( + bloom_560, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, + ) assert len(responses) == 4 diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index e3d1cbe0..89d95a23 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -13,7 +13,12 @@ def bloom_560m_sharded(launcher): async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): await health_check(bloom_560m_sharded, 60) - response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10) + response = await bloom_560m_sharded.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + seed=0, + ) assert response.details.generated_tokens == 10 assert response == snapshot @@ -24,7 +29,10 @@ async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapsh await health_check(bloom_560m_sharded, 60) responses = await generate_load( - bloom_560m_sharded, "Test request", max_new_tokens=10, n=4 + bloom_560m_sharded, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, ) assert len(responses) == 4 diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py new file mode 100644 index 00000000..899a26bf --- /dev/null +++ b/integration-tests/models/test_flash_llama.py @@ -0,0 +1,53 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_llama(launcher): + with launcher("huggingface/llama-7b", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_llama(flash_llama, snapshot): + await health_check(flash_llama, 120) + + response = await flash_llama.generate("Test request", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_flash_llama_all_params(flash_llama, snapshot): + await health_check(flash_llama, 120) + + response = await flash_llama.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_flash_llama_load(flash_llama, generate_load, snapshot): + await health_check(flash_llama, 120) + + responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + + assert responses == snapshot diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py new file mode 100644 index 00000000..42d8182a --- /dev/null +++ b/integration-tests/models/test_flash_neox.py @@ -0,0 +1,38 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_neox(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_neox(flash_neox, snapshot): + await health_check(flash_neox, 240) + + response = await flash_neox.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + ) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def 
test_flash_neox_load(flash_neox, generate_load, snapshot): + await health_check(flash_neox, 240) + + responses = await generate_load( + flash_neox, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert responses == snapshot diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index ea357b64..8ee44839 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -13,7 +13,7 @@ def flash_santacoder(launcher): async def test_flash_santacoder(flash_santacoder, snapshot): await health_check(flash_santacoder, 60) - response = await flash_santacoder.generate("Test request", max_new_tokens=10) + response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) assert response.details.generated_tokens == 10 assert response == snapshot @@ -24,7 +24,7 @@ async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot): await health_check(flash_santacoder, 60) responses = await generate_load( - flash_santacoder, "Test request", max_new_tokens=10, n=4 + flash_santacoder, "def print_hello", max_new_tokens=10, n=4 ) assert len(responses) == 4 diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 406caa6c..f5d2a47a 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -5,26 +5,38 @@ from utils import health_check @pytest.fixture(scope="module") def flash_starcoder(launcher): - with launcher("bigcode/large-model", num_shard=2) as client: + with launcher("bigcode/starcoder", num_shard=2) as client: yield client @pytest.mark.asyncio async def test_flash_starcoder(flash_starcoder, snapshot): - await health_check(flash_starcoder, 60) + await health_check(flash_starcoder, 240) - response = await flash_starcoder.generate("Test request", max_new_tokens=10) + response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) assert response.details.generated_tokens == 10 assert response == snapshot +@pytest.mark.asyncio +async def test_flash_starcoder_default_params(flash_starcoder, snapshot): + await health_check(flash_starcoder, 240) + + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0 + ) + + assert response.details.generated_tokens == 12 + assert response == snapshot + + @pytest.mark.asyncio async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot): - await health_check(flash_starcoder, 60) + await health_check(flash_starcoder, 240) responses = await generate_load( - flash_starcoder, "Test request", max_new_tokens=10, n=4 + flash_starcoder, "def print_hello", max_new_tokens=10, n=4 ) assert len(responses) == 4 diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py new file mode 100644 index 00000000..70ac470a --- /dev/null +++ b/integration-tests/models/test_mt0_base.py @@ -0,0 +1,63 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def mt0_base(launcher): + with launcher("bigscience/mt0-base") as client: + yield client + + +@pytest.mark.asyncio +async def test_mt0_base(mt0_base, snapshot): + await health_check(mt0_base, 60) + + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + top_p=0.9, + 
seed=0, + ) + + assert response.details.generated_tokens == 5 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_mt0_base_all_params(mt0_base, snapshot): + await health_check(mt0_base, 60) + + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_mt0_base_load(mt0_base, generate_load, snapshot): + await health_check(mt0_base, 60) + + responses = await generate_load( + mt0_base, + "Why is the sky blue?", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert responses == snapshot diff --git a/integration-tests/models/utils.py b/integration-tests/models/utils.py index 7ef04d7f..c47e4871 100644 --- a/integration-tests/models/utils.py +++ b/integration-tests/models/utils.py @@ -12,4 +12,4 @@ async def health_check(client: AsyncClient, timeout: int = 60): return except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: time.sleep(1) - raise e + raise RuntimeError("Health check failed") diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json deleted file mode 100644 index 96f89f6b..00000000 --- a/launcher/tests/bloom_560m.json +++ /dev/null @@ -1,142 +0,0 @@ -{ - "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException", - "details": { - "finish_reason": "length", - "generated_tokens": 20, - "seed": null, - "prefill": [ - { - "id": 10264, - "text": "Test", - "logprob": null - }, - { - "id": 8821, - "text": " request", - "logprob": -11.894989 - } - ], - "tokens": [ - { - "id": 17, - "text": ".", - "logprob": -1.8267672, - "special": false - }, - { - "id": 1587, - "text": "get", - "logprob": -2.4674969, - "special": false - }, - { - "id": 11, - "text": "(", - "logprob": -1.906001, - "special": false - }, - { - "id": 5, - "text": "\"", - "logprob": -1.2279545, - "special": false - }, - { - "id": 4899, - "text": "action", - "logprob": -4.170299, - "special": false - }, - { - "id": 5, - "text": "\"", - "logprob": -0.32478866, - "special": false - }, - { - "id": 12, - "text": ")", - "logprob": -1.0773665, - "special": false - }, - { - "id": 30, - "text": ";", - "logprob": -0.27640742, - "special": false - }, - { - "id": 837, - "text": "\n ", - "logprob": -1.6970354, - "special": false - }, - { - "id": 1320, - "text": " if", - "logprob": -1.4495516, - "special": false - }, - { - "id": 375, - "text": " (", - "logprob": -0.23609057, - "special": false - }, - { - "id": 4899, - "text": "action", - "logprob": -1.1916996, - "special": false - }, - { - "id": 3535, - "text": " ==", - "logprob": -0.8918753, - "special": false - }, - { - "id": 5109, - "text": " null", - "logprob": -0.3933342, - "special": false - }, - { - "id": 12, - "text": ")", - "logprob": -0.43212673, - "special": false - }, - { - "id": 731, - "text": " {", - "logprob": -0.17702064, - "special": false - }, - { - "id": 1260, - "text": "\n ", - "logprob": -0.07027565, - "special": false - }, - { - "id": 10519, - "text": " throw", - "logprob": -1.3915029, - "special": false - }, - { - "id": 2084, - "text": " new", - "logprob": -0.04201372, - "special": false - }, - { - "id": 150858, - "text": " RuntimeException", - "logprob": -1.7329919, - "special": false - } - ] - } -} \ No newline at end of file diff 
--git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs deleted file mode 100644 index 0d2b6c74..00000000 --- a/launcher/tests/integration_tests.rs +++ /dev/null @@ -1,172 +0,0 @@ -use float_eq::assert_float_eq; -use serde::Deserialize; -use serde_json::Value; -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::PathBuf; -use std::thread; -use std::thread::sleep; -use std::time::Duration; -use subprocess::{Popen, PopenConfig, Redirection}; - -#[derive(Deserialize)] -pub struct Token { - id: u32, - text: String, - logprob: Option, - special: bool, -} - -#[derive(Deserialize)] -struct Details { - finish_reason: String, - generated_tokens: u32, - tokens: Vec, -} - -#[derive(Deserialize)] -struct GeneratedText { - generated_text: String, - details: Details, -} - -fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen { - let argv = vec![ - "text-generation-launcher".to_string(), - "--model-id".to_string(), - model_id.clone(), - "--num-shard".to_string(), - num_shard.to_string(), - "--port".to_string(), - port.to_string(), - "--master-port".to_string(), - master_port.to_string(), - "--shard-uds-path".to_string(), - format!("/tmp/test-{}-{}-{}", num_shard, port, master_port), - ]; - - let mut launcher = Popen::create( - &argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Merge, - ..Default::default() - }, - ) - .expect("Could not start launcher"); - - // Redirect STDOUT and STDERR to the console - // (STDERR is merged into STDOUT) - let launcher_stdout = launcher.stdout.take().unwrap(); - - thread::spawn(move || { - let stdout = BufReader::new(launcher_stdout); - for line in stdout.lines() { - println!("{}", line.unwrap()); - } - }); - - for _ in 0..60 { - let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); - if health.is_ok() { - return launcher; - } - sleep(Duration::from_secs(2)); - } - - launcher.terminate().unwrap(); - launcher.wait().unwrap(); - panic!("failed to launch {}", model_id) -} - -fn test_model( - model_id: String, - num_shard: usize, - port: usize, - master_port: usize, -) -> GeneratedText { - let mut launcher = start_launcher(model_id, num_shard, port, master_port); - - let data = r#" - { - "inputs": "Test request", - "parameters": { - "details": true - } - }"#; - let req: Value = serde_json::from_str(data).unwrap(); - - let client = reqwest::blocking::Client::new(); - let res = client - .post(format!("http://localhost:{}/generate", port)) - .json(&req) - .send(); - - launcher.terminate().unwrap(); - launcher.wait().unwrap(); - - let result: GeneratedText = res.unwrap().json().unwrap(); - result -} - -fn read_json(name: &str) -> GeneratedText { - let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push("tests/"); - d.push(name); - - let file = File::open(d).unwrap(); - let reader = BufReader::new(file); - - let result: GeneratedText = serde_json::from_reader(reader).unwrap(); - result -} - -fn compare_results(result: GeneratedText, expected: GeneratedText) { - assert_eq!(result.generated_text, expected.generated_text); - assert_eq!(result.details.finish_reason, expected.details.finish_reason); - assert_eq!( - result.details.generated_tokens, - expected.details.generated_tokens - ); - - for (token, expected_token) in result - .details - .tokens - .into_iter() - .zip(expected.details.tokens.into_iter()) - { - assert_eq!(token.id, expected_token.id); - assert_eq!(token.text, expected_token.text); - assert_eq!(token.special, 
-        if let Some(logprob) = token.logprob {
-            let expected_logprob = expected_token.logprob.unwrap();
-            assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
-        } else {
-            assert_eq!(token.logprob, expected_token.logprob);
-        }
-    }
-}
-
-#[test]
-fn test_bloom_560m() {
-    let expected = read_json("bloom_560m.json");
-
-    let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
-    compare_results(result, expected);
-}
-
-#[test]
-fn test_bloom_560m_distributed() {
-    let expected = read_json("bloom_560m.json");
-
-    let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
-    compare_results(result, expected);
-}
-
-#[test]
-fn test_mt0_base() {
-    let expected = read_json("mt0_base.json");
-
-    let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
-    compare_results(result, expected);
-}
diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json
deleted file mode 100644
index f5be63f9..00000000
--- a/launcher/tests/mt0_base.json
+++ /dev/null
@@ -1,137 +0,0 @@
-{
-  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
-  "details": {
-    "finish_reason": "length",
-    "generated_tokens": 20,
-    "seed": null,
-    "prefill": [
-      {
-        "id": 0,
-        "text": "",
-        "logprob": null
-      }
-    ],
-    "tokens": [
-      {
-        "id": 259,
-        "text": "",
-        "logprob": -1.3656927,
-        "special": false
-      },
-      {
-        "id": 215100,
-        "text": "\"\"\"",
-        "logprob": -2.6551573,
-        "special": false
-      },
-      {
-        "id": 46138,
-        "text": "Test",
-        "logprob": -1.8059857,
-        "special": false
-      },
-      {
-        "id": 287,
-        "text": " the",
-        "logprob": -1.2102449,
-        "special": false
-      },
-      {
-        "id": 259,
-        "text": " ",
-        "logprob": -1.6057279,
-        "special": false
-      },
-      {
-        "id": 49076,
-        "text": "contents",
-        "logprob": -3.6060903,
-        "special": false
-      },
-      {
-        "id": 304,
-        "text": " of",
-        "logprob": -0.5270343,
-        "special": false
-      },
-      {
-        "id": 287,
-        "text": " the",
-        "logprob": -0.62522805,
-        "special": false
-      },
-      {
-        "id": 259,
-        "text": " ",
-        "logprob": -1.4069618,
-        "special": false
-      },
-      {
-        "id": 49076,
-        "text": "contents",
-        "logprob": -2.621994,
-        "special": false
-      },
-      {
-        "id": 304,
-        "text": " of",
-        "logprob": -1.3172221,
-        "special": false
-      },
-      {
-        "id": 287,
-        "text": " the",
-        "logprob": -0.3501925,
-        "special": false
-      },
-      {
-        "id": 259,
-        "text": " ",
-        "logprob": -0.7219573,
-        "special": false
-      },
-      {
-        "id": 49076,
-        "text": "contents",
-        "logprob": -1.0494149,
-        "special": false
-      },
-      {
-        "id": 260,
-        "text": ".",
-        "logprob": -1.0803378,
-        "special": false
-      },
-      {
-        "id": 259,
-        "text": " ",
-        "logprob": -0.32933083,
-        "special": false
-      },
-      {
-        "id": 215100,
-        "text": "\"\"\"",
-        "logprob": -0.11268901,
-        "special": false
-      },
-      {
-        "id": 2978,
-        "text": " test",
-        "logprob": -1.5846587,
-        "special": false
-      },
-      {
-        "id": 290,
-        "text": "_",
-        "logprob": -0.49796978,
-        "special": false
-      },
-      {
-        "id": 4125,
-        "text": "test",
-        "logprob": -2.0026445,
-        "special": false
-      }
-    ]
-  }
-}
\ No newline at end of file
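Both deleted fixtures pinned exact token ids and texts and compared logprobs with an absolute tolerance of 0.001 via compare_results above; the .ambr snapshots under integration-tests/models/__snapshots__/ now carry those expectations. As a hedged illustration only, an equivalent per-token check in Python (the helper name and dict layout are hypothetical, mirroring the deleted Rust assertions):

    import math
    from typing import Optional, TypedDict


    class TokenDict(TypedDict):
        id: int
        text: str
        logprob: Optional[float]
        special: bool


    def token_matches(token: TokenDict, expected: TokenDict, tol: float = 1e-3) -> bool:
        # Exact match on id/text/special; absolute tolerance on logprob,
        # which is null/None for the first prefill token.
        if (token["id"], token["text"], token["special"]) != (
            expected["id"],
            expected["text"],
            expected["special"],
        ):
            return False
        if token["logprob"] is None or expected["logprob"] is None:
            return token["logprob"] == expected["logprob"]
        return math.isclose(token["logprob"], expected["logprob"], abs_tol=tol)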
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 11f3766e..481fe8a6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -30,7 +30,6 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
 
-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index aa0b4483..156fed76 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__)
 
 
 class FlashLlama(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index fc741f55..7ae06036 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__)
 
 
 class FlashNeoX(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(FlashNeoX, self).__init__(
             FlashGPTNeoXForCausalLM, model_id, revision, quantize
         )
@@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM):
 
 class FlashNeoXSharded(FlashNeoX):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index f810bb0b..2b37bd0f 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -27,7 +27,12 @@ tracer = trace.get_tracer(__name__)
 
 
 class FlashSantacoder(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16
@@ -170,7 +175,10 @@ class FlashSantacoder(FlashCausalLM):
 
 class FlashSantacoderSharded(FlashSantacoder):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index c83c3351..87b64a45 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -48,7 +48,10 @@ class OPT(CausalLM):
 
 class OPTSharded(OPT):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
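The common thread in these server-side hunks is that quantize changes from a boolean flag to an optional string, so call sites name a quantization backend rather than toggling one on. A minimal sketch of the new convention (illustrative only: the Model class is a stand-in for the shared base class, and "bitsandbytes" is assumed to be the supported backend name, which this patch does not show):

    from typing import Optional


    class Model:
        def __init__(
            self,
            model_id: str,
            revision: Optional[str] = None,
            quantize: Optional[str] = None,
        ):
            # None disables quantization; a string selects the method,
            # where the old API only offered quantize=True/False.
            if quantize is not None and quantize != "bitsandbytes":  # assumed backend name
                raise ValueError(f"Unsupported quantization backend: {quantize}")
            self.model_id = model_id
            self.revision = revision
            self.quantize = quantize


    # Before this patch: Model("bigscience/bloom-560m", quantize=True)
    # After this patch:  Model("bigscience/bloom-560m", quantize="bitsandbytes")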