
Other Tests

Note

Unless otherwise specified, all continuous batching tests run with max_model_len=256.
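
In the tests below this corresponds roughly to constructing an engine as follows. This is an illustrative sketch based on the environment variables and arguments used in this section; the model name and max_num_seqs shown here are placeholders, and the actual engine setup lives in shared test helpers such as generate_spyre_vllm_output.

import os

from vllm import LLM

# Enable continuous batching in the Spyre plugin and select a backend
# (the same environment variables are set by the tests below).
os.environ["VLLM_SPYRE_USE_CB"] = "1"
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "eager"

# max_model_len=256 is the default context length for these tests;
# the model and max_num_seqs are illustrative placeholders.
llm = LLM(model="ibm-granite/granite-3.3-8b-instruct",
          max_model_len=256,
          max_num_seqs=4)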

tests.e2e.test_spyre_cb

Verification of continuous batching

Run python -m pytest tests/e2e/test_spyre_cb.py.
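
Individual hardware targets can be selected with pytest's standard marker filter; for example, the CPU-capable (eager backend) variants parametrized below carry the cpu marker:

python -m pytest tests/e2e/test_spyre_cb.py -m cpu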

test_api_cb_generates_correct_max_tokens

test_api_cb_generates_correct_max_tokens(remote_openai_server: RemoteOpenAIServer, model: ModelInfo, backend: str, cb: bool, max_model_len: int, max_num_seqs: int)

Verify that the API generates the correct number of tokens with CB enabled

Source code in tests/e2e/test_spyre_cb.py
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_generates_correct_max_tokens(
    remote_openai_server: RemoteOpenAIServer,
    model: ModelInfo,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """Verify API generates the correct numbers of tokens with CB enabled"""

    client = remote_openai_server.get_client()
    max_tokens = 10

    response = client.completions.create(model=model.name,
                                         prompt=get_chicken_soup_prompts(1),
                                         max_tokens=max_tokens,
                                         temperature=0)

    assert response.usage.completion_tokens == max_tokens

test_api_cb_rejects_oversized_request

test_api_cb_rejects_oversized_request(remote_openai_server: RemoteOpenAIServer, model: ModelInfo, backend: str, cb: bool, max_model_len: int, max_num_seqs: int)

Verify that the API rejects requests that exceed max_model_len with CB enabled

Source code in tests/e2e/test_spyre_cb.py
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_rejects_oversized_request(
    remote_openai_server: RemoteOpenAIServer,
    model: ModelInfo,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """Verify API rejects request that exceed max_model_len with CB enabled"""

    client = remote_openai_server.get_client()
    overflow_prompt = " ".join(["hi"] * max_model_len)
    max_tokens = 10

    with pytest.raises(BadRequestError, match="maximum context length is"):
        client.completions.create(
            model=model.name,
            prompt=overflow_prompt,
            max_tokens=max_tokens,
        )

test_cb_max_tokens

test_cb_max_tokens(model: ModelInfo, backend: str, max_model_len: int, max_num_seqs: int, monkeypatch: MonkeyPatch, use_llm_cache)

Test that requests longer than max_model_len are correctly rejected when continuous batching is enabled

Source code in tests/e2e/test_spyre_cb.py
@pytest.mark.cb
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_cb_max_tokens(model: ModelInfo, backend: str, max_model_len: int,
                       max_num_seqs: int, monkeypatch: pytest.MonkeyPatch,
                       use_llm_cache):
    """Test that continuous batches of requests that
    are longer than the `max_model_len` are correctly rejected"""
    max_tokens = 20

    overflow_prompt = " ".join(["a"] * max_model_len)

    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                          temperature=0,
                                          ignore_eos=True,
                                          logprobs=0)

    with pytest.raises(ValueError, match="max model context length"):
        generate_spyre_vllm_output(model=model,
                                   prompts=overflow_prompt,
                                   max_model_len=max_model_len,
                                   sampling_params=vllm_sampling_params,
                                   tensor_parallel_size=1,
                                   backend=backend,
                                   max_num_seqs=max_num_seqs,
                                   use_cb=True,
                                   monkeypatch=monkeypatch)

test_long_context_batches

test_long_context_batches(model: ModelInfo, backend: str, tp_size: int, monkeypatch: MonkeyPatch)

Tests continuous batching with various batch sizes and prompt lengths.

Source code in tests/e2e/test_spyre_cb.py
@pytest.mark.compiler_support_32k
@pytest.mark.cb
@pytest.mark.parametrize(
    "backend", [pytest.param("sendnn", marks=pytest.mark.spyre, id="sendnn")])
@pytest.mark.parametrize(
    "tp_size",
    [
        pytest.param(4, marks=pytest.mark.multi),
    ],
    ids=lambda val: f"TP({val})",
)
def test_long_context_batches(
    model: ModelInfo,
    backend: str,
    tp_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Tests continuous batching with various batch sizes and prompt lengths."""

    skip_unsupported_tp_size(tp_size, backend)

    monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
    monkeypatch.setenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1")

    max_model_len = 32768
    max_num_seqs = 32
    max_tokens = 10

    # (batch_size, prompt_length) pairs
    batch_token_pairs = [
        (32, 512),
        (16, 1500),
        (8, 3000),
        (4, 5000),
        (2, 9000),
        (1, 17000),
    ]

    vllm_model = LLM(model=model.name,
                     tokenizer=model.name,
                     max_model_len=max_model_len,
                     max_num_seqs=max_num_seqs,
                     tensor_parallel_size=tp_size,
                     revision=model.revision)

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=0,
        ignore_eos=True,
        logprobs=0,
    )

    for batch_size, token_len in batch_token_pairs:
        prompt = create_seq_prompt(model.name, token_length=token_len)
        prompts = [prompt] * batch_size

        vllm_outputs = vllm_model.generate(prompts, sampling_params)

        results = []
        for req_output in vllm_outputs:
            result = extract_output(req_output)
            results.append(result)

    check_output_against_hf(
        model=model,
        backend=backend,
        max_new_tokens=max_tokens,
        vllm_results=results,
        prompts=prompts,
    )

    force_engine_shutdown(vllm_model)

test_swap_decode_programs_for_cb

test_swap_decode_programs_for_cb(tp_size: int, monkeypatch: MonkeyPatch) -> None

Validate the runtime's ability to swap between different compiled decode programs for varying batch sizes and TKV.

The test case consists of 32 small input prompts with specifically chosen max_new_tokens values to trigger different decode programs at runtime.

The test case structure is as follows:

  • 16 prompts with max_new_tokens @ 1k
  • 8 prompts with max_new_tokens @ 2k
  • 4 prompts with max_new_tokens @ 4k
  • 2 prompts with max_new_tokens @ 8k
  • 1 prompt with max_new_tokens @ 16k
  • 1 prompt with max_new_tokens @ 32k
Source code in tests/e2e/test_spyre_cb.py
@pytest.mark.compiler_support_32k
@pytest.mark.spyre
@pytest.mark.cb
@pytest.mark.parametrize(
    "tp_size",
    [
        pytest.param(4, marks=pytest.mark.multi),
    ],
    ids=lambda val: f"TP({val})",
)
def test_swap_decode_programs_for_cb(
    tp_size: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    '''

    Validate the runtime's ability to swap between different compiled decode 
    programs for varying batch sizes and TKV.

    The test case consists of 32 small input prompts with specifically chosen 
    max_new_tokens values to trigger different decode programs at runtime.

    The test case structure is as follows:

    - 16 prompts with max_new_tokens @ 1k
    -  8 prompts with max_new_tokens @ 2k
    -  4 prompts with max_new_tokens @ 4k
    -  2 prompts with max_new_tokens @ 8k
    -  1 prompt  with max_new_tokens @ 16k
    -  1 prompt  with max_new_tokens @ 32k

    '''

    model = 'ibm-granite/granite-3.3-8b-instruct'
    backend = 'sendnn'
    max_num_seqs = 32

    max_model_len = 32 * 1024  # 32K

    skip_unsupported_tp_size(tp_size, backend)
    prompts = get_chicken_soup_prompts(max_num_seqs)

    create_sampling_params = lambda max_new_tokens: SamplingParams(
        # The prompt will pad to 64 tokens, therefore to match
        # max_model_len/max_new_tokens, we need to decrease by the prompt
        # length
        max_tokens=max_new_tokens - 64,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=True)

    p1k = 1 * 1024
    p2k = 2 * 1024
    p4k = 4 * 1024
    p8k = 8 * 1024
    p16k = 16 * 1024
    p32k = 32 * 1024

    sampling_params_1k = [create_sampling_params(p1k) for _ in range(16)]
    sampling_params_2k = [create_sampling_params(p2k) for _ in range(8)]
    sampling_params_4k = [create_sampling_params(p4k) for _ in range(4)]
    sampling_params_8k = [create_sampling_params(p8k) for _ in range(2)]
    sampling_params_16k = [create_sampling_params(p16k) for _ in range(1)]
    sampling_params_32k = [create_sampling_params(p32k) for _ in range(1)]

    sampling_params = sampling_params_1k + sampling_params_2k + \
        sampling_params_4k + sampling_params_8k + sampling_params_16k + \
            sampling_params_32k

    # Read the cache and check beforehand if the cache was written with the
    # expected prompt. We use the filepath of this script to resolve
    # the cache filepaths
    script_directory = Path(__file__).parent.absolute() / 'cache'
    with open(script_directory / 'prompts_8k_bs2.pickle', 'rb') as f:
        cache_result_8k_bs2: list[dict[str, Any]] = pickle.loads(f.read())

    assert cache_result_8k_bs2[0]['prompt'] == prompts[28]
    assert cache_result_8k_bs2[1]['prompt'] == prompts[29]

    with open(script_directory / 'prompts_16k_bs1.pickle', 'rb') as f:
        cache_result_16k_bs1: list[dict[str, Any]] = pickle.loads(f.read())

    assert cache_result_16k_bs1[0]['prompt'] == prompts[30]

    # Generate results from vLLM
    vllm_results = generate_spyre_vllm_output(model=model,
                                              prompts=prompts,
                                              sampling_params=sampling_params,
                                              tensor_parallel_size=tp_size,
                                              backend=backend,
                                              max_num_seqs=max_num_seqs,
                                              monkeypatch=monkeypatch,
                                              max_model_len=max_model_len,
                                              use_cb=True)

    # TODO: dummy validation, currently the outputs do not match with
    # HF cache.

    assert vllm_results is not None

tests.e2e.test_spyre_async_llm

test_abort async

test_abort(model: ModelInfo, backend: str, cb: int, warmup_shapes: DecodeWarmupShapes, output_kind: RequestOutputKind, monkeypatch: MonkeyPatch)

Test handling of cancelled requests

Source code in tests/e2e/test_spyre_async_llm.py
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_abort(model: ModelInfo, backend: str, cb: int,
                     warmup_shapes: DecodeWarmupShapes,
                     output_kind: RequestOutputKind,
                     monkeypatch: pytest.MonkeyPatch):
    """Test handling of cancelled requests"""
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
        if cb == 1:
            m.setenv("VLLM_SPYRE_USE_CB", "1")
        else:
            warmup_prompt_length = [t[0] for t in warmup_shapes]
            warmup_new_tokens = [t[1] for t in warmup_shapes]
            warmup_batch_size = [t[2] for t in warmup_shapes]

            m.setenv('VLLM_SPYRE_WARMUP_PROMPT_LENS',
                     ','.join(str(val) for val in warmup_prompt_length))
            m.setenv('VLLM_SPYRE_WARMUP_NEW_TOKENS',
                     ','.join(str(val) for val in warmup_new_tokens))
            m.setenv('VLLM_SPYRE_WARMUP_BATCH_SIZES',
                     ','.join(str(val) for val in warmup_batch_size))

        # Async LLM API is a little different between v0 and V1
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(model=model.name,
                            tokenizer=model.name,
                            max_model_len=256,
                            max_num_seqs=4,
                            revision=model.revision))
        has_unfinished_requests = \
            engine.output_processor.has_unfinished_requests
        after.callback(engine.shutdown)

        # Test structure here mirrors upstream vLLM test_abort:
        # https://github.com/vllm-project/vllm/blob/e6aab5de2999187c6cf0206f2d63ab6d7a0b6964/tests/v1/engine/test_async_llm.py#L160
        NUM_REQUESTS = 15
        NUM_EXPECTED_TOKENS = 5
        REQUEST_IDS_TO_ABORT = range(1, NUM_REQUESTS, 3)
        PARALLEL_SAMPLE_REQ_IDS = range(1, NUM_REQUESTS, 5)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests
        tasks: list[asyncio.Task] = []
        prompt = get_chicken_soup_prompts(1)[0]
        for idx, request_id in enumerate(request_ids):
            max_tokens = NUM_EXPECTED_TOKENS
            n = 2 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             max_tokens, n)))

        # Simulate cancellation from API server client disconnect
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm that requests actually cancelled and that the other requests
        # are not impacted
        for idx, task in enumerate(tasks):
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                num_generated_tokens, request_id = await task
                n = 2 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted
        assert not has_unfinished_requests()

        # Confirm that the server is still up and functioning
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not has_unfinished_requests()
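
The generate coroutine used above is a helper defined elsewhere in the test suite and is not shown in this excerpt. The following is a minimal sketch of what such a helper does, assuming vLLM's AsyncLLM.generate async-generator API and the output_kind field of SamplingParams; names and defaults here are illustrative, not the suite's actual implementation.

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM


async def generate(engine: AsyncLLM, request_id: str, prompt: str,
                   output_kind: RequestOutputKind, max_tokens: int,
                   n: int = 1) -> tuple[int, str]:
    # Count the tokens streamed back for one request: DELTA outputs are
    # incremental, while FINAL_ONLY returns the full completion at the end.
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     n=n)
    num_tokens = 0
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params):
        generated = sum(len(o.token_ids) for o in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            num_tokens += generated
        else:
            num_tokens = generated
    return num_tokens, request_id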

tests.e2e.test_spyre_max_new_tokens

Verification of vLLM output by comparing it with HF output

Run python -m pytest tests/e2e/test_spyre_max_new_tokens.py.

test_output

test_output(model: ModelInfo, stop_last: bool, max_model_len: int, max_num_seqs: int, warmup_shapes: DecodeWarmupShapes, backend: str, cb: int, monkeypatch: MonkeyPatch, use_llm_cache) -> None

Checks that the max_tokens parameter of SamplingParams works correctly. For each batch, one prompt has max_tokens set to 1 and the others don't. This checks that the correct request has only a single output token, while the others are not affected.

Source code in tests/e2e/test_spyre_max_new_tokens.py
@pytest.mark.parametrize("stop_last", [True, False])
def test_output(model: ModelInfo, stop_last: bool, max_model_len: int,
                max_num_seqs: int, warmup_shapes: DecodeWarmupShapes,
                backend: str, cb: int, monkeypatch: pytest.MonkeyPatch,
                use_llm_cache) -> None:
    '''
    Checks that the `max_tokens` parameter of `SamplingParams` works correctly.
    For each batch, one prompt has max_tokens set to 1 and the others don't.
    This checks that the correct request has only a single output token, while
    the others are not affected.
    '''

    prompts = get_chicken_soup_prompts(4)

    max_new_tokens_long = 6
    max_new_tokens_early_stop = 1

    vllm_sampling_params_normal = SamplingParams(
        max_tokens=max_new_tokens_long,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=False)

    vllm_sampling_params_early_stop = SamplingParams(
        max_tokens=max_new_tokens_early_stop,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=False)

    vllm_sampling_params = [vllm_sampling_params_normal] * 3
    hf_max_new_tokens = [max_new_tokens_long] * 3

    # stop last or first sequence in batch early
    if stop_last:
        vllm_sampling_params = vllm_sampling_params + [
            vllm_sampling_params_early_stop
        ]
        hf_max_new_tokens = hf_max_new_tokens + [max_new_tokens_early_stop]
    else:
        vllm_sampling_params = [vllm_sampling_params_early_stop
                                ] + vllm_sampling_params
        hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens

    kwargs = ({
        "max_num_seqs": max_num_seqs,
        "use_cb": True,
    } if cb == 1 else {
        "warmup_shapes": warmup_shapes
    })

    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        sampling_params=vllm_sampling_params,
        tensor_parallel_size=1,
        backend=backend,
        monkeypatch=monkeypatch,
        max_model_len=max_model_len,
        **kwargs)

    check_output_against_hf(model, backend, hf_max_new_tokens, vllm_results,
                            prompts)

tests.e2e.test_spyre_online