mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add integration tests
This commit is contained in:
parent
3e517bfc9d
commit
8f28011e1e
@ -205,7 +205,10 @@ def event_loop():
|
|||||||
def launcher(event_loop):
|
def launcher(event_loop):
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def local_launcher(
|
def local_launcher(
|
||||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
model_id: str,
|
||||||
|
num_shard: Optional[int] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
master_port = random.randint(10_000, 20_000)
|
master_port = random.randint(10_000, 20_000)
|
||||||
@ -230,6 +233,9 @@ def launcher(event_loop):
|
|||||||
args.extend(["--num-shard", str(num_shard)])
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
if quantize:
|
if quantize:
|
||||||
args.append("--quantize")
|
args.append("--quantize")
|
||||||
|
args.append("bitsandbytes")
|
||||||
|
if trust_remote_code:
|
||||||
|
args.append("--trust-remote-code")
|
||||||
|
|
||||||
env = os.environ
|
env = os.environ
|
||||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||||
@ -250,7 +256,10 @@ def launcher(event_loop):
|
|||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def docker_launcher(
|
def docker_launcher(
|
||||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
model_id: str,
|
||||||
|
num_shard: Optional[int] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
@ -260,6 +269,9 @@ def launcher(event_loop):
|
|||||||
args.extend(["--num-shard", str(num_shard)])
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
if quantize:
|
if quantize:
|
||||||
args.append("--quantize")
|
args.append("--quantize")
|
||||||
|
args.append("bitsandbytes")
|
||||||
|
if trust_remote_code:
|
||||||
|
args.append("--trust-remote-code")
|
||||||
|
|
||||||
client = docker.from_env()
|
client = docker.from_env()
|
||||||
|
|
||||||
|
@ -0,0 +1,378 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 50,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "G"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -5.96875,
|
||||||
|
"text": "ir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1622,
|
||||||
|
"logprob": -5.6132812,
|
||||||
|
"text": "af"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -6.5039062,
|
||||||
|
"text": "at"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1480,
|
||||||
|
"logprob": -8.078125,
|
||||||
|
"text": "ron"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 304,
|
||||||
|
"logprob": -2.3261719,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23866,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": " obsessed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 335,
|
||||||
|
"logprob": -0.048339844,
|
||||||
|
"text": " with"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26680,
|
||||||
|
"logprob": -4.0,
|
||||||
|
"text": " gir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1903,
|
||||||
|
"logprob": -0.07556152,
|
||||||
|
"text": "aff"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 255,
|
||||||
|
"logprob": -0.0067749023,
|
||||||
|
"text": "es"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23,
|
||||||
|
"logprob": -1.546875,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 248,
|
||||||
|
"logprob": -4.3320312,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 758,
|
||||||
|
"logprob": -3.734375,
|
||||||
|
"text": " most"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21735,
|
||||||
|
"logprob": -5.109375,
|
||||||
|
"text": " glorious"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5985,
|
||||||
|
"logprob": -2.09375,
|
||||||
|
"text": " animal"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 313,
|
||||||
|
"logprob": -1.1835938,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 248,
|
||||||
|
"logprob": -0.77685547,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1936,
|
||||||
|
"logprob": -2.3828125,
|
||||||
|
"text": " face"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 275,
|
||||||
|
"logprob": -0.004432678,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 414,
|
||||||
|
"logprob": -1.9677734,
|
||||||
|
"text": " this"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6490,
|
||||||
|
"logprob": -2.046875,
|
||||||
|
"text": " Earth"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -0.28198242,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 401,
|
||||||
|
"logprob": -7.9179688,
|
||||||
|
"text": " G"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6013,
|
||||||
|
"logprob": -2.2753906,
|
||||||
|
"text": "ira"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 694,
|
||||||
|
"logprob": -0.6230469,
|
||||||
|
"text": "ft"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1480,
|
||||||
|
"logprob": -0.20874023,
|
||||||
|
"text": "ron"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9369,
|
||||||
|
"logprob": -4.5507812,
|
||||||
|
"text": " believes"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 455,
|
||||||
|
"logprob": -4.5664062,
|
||||||
|
"text": " all"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 599,
|
||||||
|
"logprob": -2.7402344,
|
||||||
|
"text": " other"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5632,
|
||||||
|
"logprob": -0.21948242,
|
||||||
|
"text": " animals"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 362,
|
||||||
|
"logprob": -0.7675781,
|
||||||
|
"text": " are"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23981,
|
||||||
|
"logprob": -5.0,
|
||||||
|
"text": " irrelevant"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 635,
|
||||||
|
"logprob": -4.234375,
|
||||||
|
"text": " when"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4354,
|
||||||
|
"logprob": -0.5131836,
|
||||||
|
"text": " compared"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 271,
|
||||||
|
"logprob": -0.103637695,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 248,
|
||||||
|
"logprob": -0.58447266,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21735,
|
||||||
|
"logprob": -3.6835938,
|
||||||
|
"text": " glorious"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 64398,
|
||||||
|
"logprob": -1.8173828,
|
||||||
|
"text": " majesty"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 275,
|
||||||
|
"logprob": -0.23510742,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 248,
|
||||||
|
"logprob": -0.35473633,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26680,
|
||||||
|
"logprob": -0.24633789,
|
||||||
|
"text": " gir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23226,
|
||||||
|
"logprob": -0.02960205,
|
||||||
|
"text": "affe"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -0.17333984,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 193,
|
||||||
|
"logprob": -1.3935547,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23626,
|
||||||
|
"logprob": -10.0625,
|
||||||
|
"text": "Daniel"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37,
|
||||||
|
"logprob": -4.59375,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23090,
|
||||||
|
"logprob": -6.9375,
|
||||||
|
"text": " Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23,
|
||||||
|
"logprob": -0.99365234,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29033,
|
||||||
|
"logprob": -2.2324219,
|
||||||
|
"text": " Gir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1622,
|
||||||
|
"logprob": -0.10809326,
|
||||||
|
"text": "af"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -0.042663574,
|
||||||
|
"text": "at"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1480,
|
||||||
|
"logprob": -0.0024776459,
|
||||||
|
"text": "ron"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12,
|
||||||
|
"logprob": -1.4277344,
|
||||||
|
"text": "!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 193,
|
||||||
|
"logprob": -1.1015625,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 50,
|
||||||
|
"logprob": -0.05709839,
|
||||||
|
"text": "G"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -0.13208008,
|
||||||
|
"text": "ir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1622,
|
||||||
|
"logprob": -0.0071487427,
|
||||||
|
"text": "af"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -0.008468628,
|
||||||
|
"text": "at"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1480,
|
||||||
|
"logprob": -0.00068998337,
|
||||||
|
"text": "ron"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37,
|
||||||
|
"logprob": -0.0074691772,
|
||||||
|
"text": ":"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 23090,
|
||||||
|
"logprob": -1.8251953,
|
||||||
|
"special": false,
|
||||||
|
"text": " Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23,
|
||||||
|
"logprob": -0.3173828,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8156,
|
||||||
|
"logprob": -0.23803711,
|
||||||
|
"special": false,
|
||||||
|
"text": " Daniel"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12,
|
||||||
|
"logprob": -0.56933594,
|
||||||
|
"special": false,
|
||||||
|
"text": "!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 193,
|
||||||
|
"logprob": -0.61279297,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 23626,
|
||||||
|
"logprob": -0.41967773,
|
||||||
|
"special": false,
|
||||||
|
"text": "Daniel"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37,
|
||||||
|
"logprob": -0.0023403168,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1634,
|
||||||
|
"logprob": -2.0605469,
|
||||||
|
"special": false,
|
||||||
|
"text": " What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.5292969,
|
||||||
|
"special": false,
|
||||||
|
"text": "'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 94,
|
||||||
|
"logprob": -0.007904053,
|
||||||
|
"special": false,
|
||||||
|
"text": "s"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": " Hello, Daniel!\nDaniel: What's"
|
||||||
|
}
|
@ -0,0 +1,98 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "ir"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1622,
|
||||||
|
"logprob": -7.8125,
|
||||||
|
"text": "af"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -4.5,
|
||||||
|
"text": "at"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1480,
|
||||||
|
"logprob": -10.875,
|
||||||
|
"text": "ron"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37,
|
||||||
|
"logprob": -3.6875,
|
||||||
|
"text": ":"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 836,
|
||||||
|
"logprob": -1.265625,
|
||||||
|
"special": false,
|
||||||
|
"text": " i"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -0.11621094,
|
||||||
|
"special": false,
|
||||||
|
"text": "'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 88,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "m"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1241,
|
||||||
|
"logprob": -0.953125,
|
||||||
|
"special": false,
|
||||||
|
"text": " using"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 248,
|
||||||
|
"logprob": -2.5,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 204,
|
||||||
|
"logprob": -0.62890625,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2485,
|
||||||
|
"logprob": -0.54296875,
|
||||||
|
"special": false,
|
||||||
|
"text": "32"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3882,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "bit"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2684,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " version"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i'm using the 32-bit version"
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
63
integration-tests/models/test_flash_falcon.py
Normal file
63
integration-tests/models/test_flash_falcon.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_falcon_handle(launcher):
|
||||||
|
with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_falcon(flash_falcon_handle):
|
||||||
|
await flash_falcon_handle.health(120)
|
||||||
|
return flash_falcon_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_falcon(flash_falcon, response_snapshot):
|
||||||
|
response = await flash_falcon.generate(
|
||||||
|
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||||
|
max_new_tokens=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
|
||||||
|
response = await flash_falcon.generate(
|
||||||
|
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_falcon,
|
||||||
|
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
@ -37,7 +37,7 @@ class FlashRW(FlashCausalLM):
|
|||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("RW is only available on GPU")
|
raise NotImplementedError("RW is only available on GPU")
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ class FlashRW(FlashCausalLM):
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We do not use from_pretrained as we modified the model internal module layout
|
# We do not use from_pretrained as it is too slow
|
||||||
try:
|
try:
|
||||||
filenames = weight_files(model_id, revision, ".bin")
|
filenames = weight_files(model_id, revision, ".bin")
|
||||||
# Local files not found
|
# Local files not found
|
||||||
@ -124,7 +124,7 @@ class FlashRWSharded(FlashRW):
|
|||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
raise NotImplementedError("FlashRW is only available on GPU")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user