text-generation-inference/integration-tests/models/test_neox_sharded.py
Nicolas Patry abd58ff82c
feat(server): Rework model loading (#344)
# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:

- Remove the need for `no_init_weights`
- Remove the ad hoc `bnb_linear`, `load_weights`, and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of loading the weights from
multiple files into the appropriate tensors (potentially sharded)
- TP layers are now "shells": they contain the code that knows what kind of
sharding is needed, plus a possible `all_reduce`. They do not inherit from
linear; instead they contain some kind of Linear
- the contained linear can be either FastLinear, BnbLinear, or (next) a GPTQ
Linear.
- All modeling code is explicitly written for sharding; the process group is
just a no-op for non-sharded code (removes a lot of test cases)
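
The layout above can be illustrated with a minimal, framework-free sketch. This is not the code from this PR: the names `ToyWeights`, `ToyLinear`, `RowParallelShell`, `fake_all_reduce`, and `sharded_forward` are hypothetical, plain Python lists stand in for tensors, dictionaries stand in for weight files, and all ranks are simulated in a single process.

```python
from typing import Dict, List

Matrix = List[List[float]]


class ToyWeights:
    """Hypothetical stand-in for a `Weights` loader: finds a tensor across
    several "files" and can return only the slice owned by this rank."""

    def __init__(self, files: List[Dict[str, Matrix]], rank: int, world_size: int):
        self.files = files
        self.rank = rank
        self.world_size = world_size

    def get_tensor(self, name: str) -> Matrix:
        for f in self.files:
            if name in f:
                return f[name]
        raise KeyError(name)

    def get_sharded_cols(self, name: str) -> Matrix:
        # Split the input dimension (columns) evenly across ranks.
        full = self.get_tensor(name)
        n_cols = len(full[0])
        assert n_cols % self.world_size == 0, "columns must divide evenly"
        block = n_cols // self.world_size
        start = self.rank * block
        return [row[start:start + block] for row in full]


class ToyLinear:
    """Inner linear layer computing y = W x (no bias)."""

    def __init__(self, weight: Matrix):
        self.weight = weight

    def __call__(self, x: List[float]) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]


def fake_all_reduce(partials: List[List[float]]) -> List[float]:
    """Element-wise sum of per-rank outputs, as a sum all-reduce would give."""
    return [sum(vals) for vals in zip(*partials)]


class RowParallelShell:
    """Shell layer: it does not inherit from a linear, it contains one and
    knows which slice of the input its shard consumes."""

    def __init__(self, weights: ToyWeights, name: str):
        self.linear = ToyLinear(weights.get_sharded_cols(name))
        self.rank = weights.rank
        self.block = len(self.linear.weight[0])

    def partial_forward(self, x: List[float]) -> List[float]:
        start = self.rank * self.block
        return self.linear(x[start:start + self.block])


def sharded_forward(files, name, x, world_size):
    # Simulate every rank in one process, then reduce the partial results.
    partials = []
    for rank in range(world_size):
        shell = RowParallelShell(ToyWeights(files, rank, world_size), name)
        partials.append(shell.partial_forward(x))
    return fake_all_reduce(partials)


# Two "files"; the tensor we want lives in the second one.
files = [
    {"embed.weight": [[1.0]]},
    {"dense.weight": [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]},
]
x = [1.0, 1.0, 2.0, 2.0]
full = ToyLinear(ToyWeights(files, 0, 1).get_tensor("dense.weight"))(x)

assert sharded_forward(files, "dense.weight", x, world_size=2) == full
assert sharded_forward(files, "dense.weight", x, world_size=1) == full
```

With `world_size=1` the same code path runs and the reduction is trivial, which mirrors the point that non-sharded execution can treat the process group as a no-op instead of needing a separate implementation.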

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
2023-06-08 14:51:52 +02:00

import pytest


@pytest.fixture(scope="module")
def neox_sharded_handle(launcher):
    with launcher(
        "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def neox_sharded(neox_sharded_handle):
    await neox_sharded_handle.health(300)
    return neox_sharded_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
    response = await neox_sharded.generate(
        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
    responses = await generate_load(
        neox_sharded,
        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
        max_new_tokens=10,
        n=4,
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot