Design:

  • Dynamic Connection: For each request, a connection is established between a prefill server and a decode server instead of being fixed in advance. This makes it easy to scale the prefill and decode server pools up or down as needed.
  • Non-blocking Transfer: Send and receive operations are non-blocking and run in a background thread. This ensures that the original scheduler event loop continues to operate uninterrupted while data transfer occurs in the background.
  • Heterogeneous Parallelism: The design supports key-value (KV) transfers between prefill and decode servers that use different tensor parallelism (TP) sizes, enabling specialized optimizations on each side.
  • RDMA-Based Transfer: We use RDMA queue pairs to establish connections and RDMA scatter-gather elements (SGEs) to transfer non-contiguous memory chunks efficiently.
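The RDMA bullet relies on scatter-gather lists to move KV slots that are not adjacent in memory. The following is a minimal sketch, assuming each KV buffer is a flat array of fixed-size items starting at a base address (in the spirit of `kv_data_ptrs` and `kv_item_lens` in `KVArgs`). It merges runs of adjacent slots into (address, length) segments and splits the segment list into chunks of at most `max_sge` entries, since a single RDMA work request accepts a bounded number of SGEs. `Segment`, `kv_indices_to_segments`, and `group_for_sge` are hypothetical names; posting the work requests, memory-region keys, and matching segments to remote addresses are not shown.

```python
from dataclasses import dataclass


@dataclass
class Segment:
    addr: int    # start address of a contiguous byte range
    length: int  # size of the range in bytes


def kv_indices_to_segments(base_ptr: int, item_len: int, kv_indices: list[int]) -> list[Segment]:
    """Coalesce KV slots that are adjacent in memory into contiguous segments.

    The input order is preserved; only slots that are neighbours in that order are merged.
    """
    segments: list[Segment] = []
    for idx in kv_indices:
        addr = base_ptr + idx * item_len
        if segments and segments[-1].addr + segments[-1].length == addr:
            segments[-1].length += item_len  # slot directly follows the previous range
        else:
            segments.append(Segment(addr, item_len))
    return segments


def group_for_sge(segments: list[Segment], max_sge: int) -> list[list[Segment]]:
    """Split segments into chunks of at most max_sge entries (one scatter-gather list each)."""
    return [segments[i:i + max_sge] for i in range(0, len(segments), max_sge)]
```

For slots `[0, 1, 2, 5, 6, 9]` with `item_len=16` at base address 4096, this yields three segments, (4096, 48), (4176, 32), and (4240, 16), which `max_sge=2` splits into two scatter-gather lists.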

Load Balancing:

  • Least Loaded Policy: The load balancer is connected to prefill and decode pools of arbitrary sizes. When a request comes in, it selects the least-loaded prefill server and the least-loaded decode server to form a pair.
  • Autoscaling: We monitor the GPU utilization of prefill and decode workers and add or remove workers based on real-time usage.
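A minimal sketch of both policies follows. It assumes load is a per-worker count of assigned requests and that autoscaling compares average GPU utilization against thresholds of 0.8 and 0.3; the text specifies neither the load metric nor the thresholds, and `Worker`, `select_pair`, and `scale_decision` are illustrative names.

```python
from dataclasses import dataclass


@dataclass
class Worker:
    name: str
    load: int = 0  # assumed metric: number of requests currently assigned


def select_pair(prefill_pool: list[Worker], decode_pool: list[Worker]) -> tuple[Worker, Worker]:
    """Pick the least-loaded prefill worker and the least-loaded decode worker."""
    if not prefill_pool or not decode_pool:
        raise RuntimeError("both pools need at least one worker")
    prefill = min(prefill_pool, key=lambda w: w.load)
    decode = min(decode_pool, key=lambda w: w.load)
    prefill.load += 1  # count the new request against both choices
    decode.load += 1
    return prefill, decode


def scale_decision(gpu_utils: list[float], high: float = 0.8, low: float = 0.3) -> int:
    """Return +1 to add a worker, -1 to remove one, 0 to keep the pool as is.

    gpu_utils holds one utilization value in [0, 1] per worker and must be non-empty.
    """
    avg = sum(gpu_utils) / len(gpu_utils)
    if avg > high:
        return 1
    if avg < low and len(gpu_utils) > 1:  # never shrink a pool below one worker
        return -1
    return 0
```

Because each side is chosen independently, the two pools can have different, arbitrary sizes, and `scale_decision` can be applied to each pool separately.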

Event Loop:

On top of the original SGLang scheduling event loop, we add non-blocking sender and receiver operations.

Prefill Server

1. Bootstrap Queue

    a. Initialize a sender for each request

    b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished

    c. Poll senders to check bootstrap state

    d. Once bootstrap is complete, move the request to the Waiting Queue

2. Waiting Queue

    a. Use PrefillAdder to pop requests

    b. Run the forward pass

    c. Add the request to the Inflight Queue

3. Inflight Queue

    a. Poll (non-blocking) the sender of the request

    b. Once the transfer has finished, return the request

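```python
# Prefill scheduler event loop (runs forever)
while True:
    recv_reqs = self.recv_requests()
    # New requests wait in the bootstrap queue until handshake and preallocation finish.
    self.bootstrap_queue.extend(recv_reqs)
    self.waiting_queue.extend(self.bootstrap_queue.pop_bootstrapped())

    batch = self.get_next_prefill_batch()
    if batch:
        result = self.run_batch(batch)
        # 1. batch.reqs start transferring
        # 2. add the transferring reqs into the inflight queue
        self.process_batch_result(batch, result)

    # Non-blocking: return requests whose KV transfer has finished.
    self.process_inflight_queue()
```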
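To make the queue transitions concrete, here is a self-contained simulation of the three prefill queues. `MockSender` is a hypothetical stand-in for `KVSender` that finishes bootstrap and transfer after a fixed number of polls, and the "forward pass" is a no-op that admits every waiting request (the real `PrefillAdder` pops requests subject to a budget). Only the `KVPoll` values are taken from the transfer interface; `PrefillQueues` and `step` are illustrative names.

```python
from enum import IntEnum


class KVPoll(IntEnum):
    # Same values as KVPoll in the transfer interface.
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class MockSender:
    """Hypothetical KVSender stand-in: bootstrap and transfer each last a fixed number of polls."""

    def __init__(self, bootstrap_polls: int = 1, transfer_polls: int = 1):
        self.bootstrap_left = bootstrap_polls
        self.transfer_left = transfer_polls
        self.sent = False

    def send(self, kv_indices) -> None:
        self.sent = True  # a real sender would hand kv_indices to its transfer thread

    def poll(self) -> KVPoll:
        if self.bootstrap_left > 0:
            self.bootstrap_left -= 1
            return KVPoll.Bootstrapping
        if not self.sent:
            return KVPoll.WaitingForInput
        if self.transfer_left > 0:
            self.transfer_left -= 1
            return KVPoll.Transferring
        return KVPoll.Success


class PrefillQueues:
    def __init__(self):
        self.bootstrap_queue = []  # (rid, sender) pairs still bootstrapping
        self.waiting_queue = []    # bootstrapped, waiting for a forward pass
        self.inflight_queue = []   # forward done, KV transfer in progress
        self.finished, self.failed = [], []

    def step(self, new_reqs=()):
        self.bootstrap_queue.extend(new_reqs)

        # Bootstrap queue -> waiting queue once the sender is ready for input.
        pending = []
        for rid, sender in self.bootstrap_queue:
            state = sender.poll()
            if state == KVPoll.WaitingForInput:
                self.waiting_queue.append((rid, sender))
            elif state == KVPoll.Failed:
                self.failed.append(rid)
            else:
                pending.append((rid, sender))
        self.bootstrap_queue = pending

        # Simplification: "run forward" on every waiting request, then start the send.
        for rid, sender in self.waiting_queue:
            sender.send(kv_indices=[])
            self.inflight_queue.append((rid, sender))
        self.waiting_queue = []

        # Inflight queue: poll without blocking and return finished requests.
        pending = []
        for rid, sender in self.inflight_queue:
            state = sender.poll()
            if state == KVPoll.Success:
                self.finished.append(rid)
            elif state == KVPoll.Failed:
                self.failed.append(rid)
            else:
                pending.append((rid, sender))
        self.inflight_queue = pending
```

With one bootstrap poll and one transfer poll, a request spends one step in the bootstrap queue, one step in flight, and is returned on the third step.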

Decode Server

1. PreallocQueue:

    a. Initialize a receiver for each request

    b. The request first completes the handshake, then preallocates KV cache once enough KV cache space is available.

    c. Move the request to the TransferQueue.

2. TransferQueue:

    a. Poll the receiver to check the transfer state

    b. If the transfer has finished, move the request to the WaitingQueue

3. WaitingQueue:

    a. Use the requests in the queue to construct a PrebuiltExtendBatch

    b. Skip the prefill forward pass and only populate the metadata

4. RunningBatch:

    a. Merge the resolved PrebuiltExtendBatch into the running batch for decoding

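```python
# Decode scheduler event loop (runs forever)
while True:
    recv_reqs = self.recv_requests()
    # New requests handshake and wait for KV preallocation.
    self.prealloc_queue.extend(recv_reqs)
    # Requests with preallocated KV start receiving.
    self.transfer_queue.extend(self.prealloc_queue.pop_prealloc())
    # Requests whose KV transfer finished become ready for decoding.
    self.waiting_queue.extend(self.transfer_queue.pop_transfer())

    batch = self.get_next_decode_batch()
    if batch:
        result = self.run_batch(batch)
        self.process_batch_result(batch, result)
```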
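The decode-side queues can be simulated the same way. In this sketch, `MockReceiver` is a hypothetical stand-in for `KVReceiver` (one poll to handshake, one more to finish the transfer after `init`), KV cache space is modeled as a list of free slot indices, and slots are never released after decoding. `DecodeQueues` and `step` are illustrative names; only the `KVPoll` values come from the transfer interface.

```python
from enum import IntEnum


class KVPoll(IntEnum):
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class MockReceiver:
    """Hypothetical KVReceiver stand-in: one poll to handshake, one more to finish the transfer."""

    def __init__(self):
        self.handshake_done = False
        self.kv_indices = None
        self.transfer_left = 1

    def init(self, kv_indices, aux_index=None) -> None:
        self.kv_indices = kv_indices  # where the incoming KV cache will be written

    def poll(self) -> KVPoll:
        if not self.handshake_done:
            self.handshake_done = True
            return KVPoll.Bootstrapping
        if self.kv_indices is None:
            return KVPoll.WaitingForInput
        if self.transfer_left > 0:
            self.transfer_left -= 1
            return KVPoll.Transferring
        return KVPoll.Success


class DecodeQueues:
    def __init__(self, num_kv_slots: int):
        self.free_slots = list(range(num_kv_slots))
        self.prealloc_queue = []  # (rid, num_tokens, receiver)
        self.transfer_queue = []
        self.waiting_queue = []

    def step(self, new_reqs=()):
        self.prealloc_queue.extend(new_reqs)

        # PreallocQueue: after the handshake, reserve KV slots if enough are free.
        pending = []
        for rid, n, receiver in self.prealloc_queue:
            if receiver.poll() == KVPoll.WaitingForInput and len(self.free_slots) >= n:
                receiver.init([self.free_slots.pop() for _ in range(n)])
                self.transfer_queue.append((rid, n, receiver))
            else:
                pending.append((rid, n, receiver))
        self.prealloc_queue = pending

        # TransferQueue: move finished transfers to the waiting queue.
        pending = []
        for rid, n, receiver in self.transfer_queue:
            state = receiver.poll()
            if state == KVPoll.Success:
                self.waiting_queue.append((rid, n, receiver))
            elif state == KVPoll.Failed:
                self.free_slots.extend(receiver.kv_indices)  # release the reservation
            else:
                pending.append((rid, n, receiver))
        self.transfer_queue = pending
```

With four free slots and two requests of three tokens each, the first request is preallocated and transferred while the second stays in the PreallocQueue until space is freed.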

Transfer Interface:

At a high level, KVSender and KVReceiver manage sending and receiving. The actual transfer happens in a background thread, and a Python interface communicates with that thread. All operations are non-blocking, and the state of a transfer can be polled at any time.

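```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


class KVArgs:
    engine_rank: int
    num_receivers: int
    kv_data_ptrs: list[int]  # base address of each KV buffer
    kv_data_lens: list[int]  # total length of each KV buffer
    kv_item_lens: list[int]  # length of one item in each KV buffer
    ib_device: str           # InfiniBand device to use
    max_inflight: int
    max_sge: int             # maximum scatter-gather elements per RDMA request
    verbose: bool


class KVManager:
    def __init__(self, args: KVArgs): ...


class KVPoll:
    # Transfer states reported by KVSender.poll() and KVReceiver.poll()
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class KVSender:
    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int, dest_ranks: list[int]
    ): ...

    # Declare how many KV indices will be sent.
    def init(self, num_kv_indices: int): ...

    # Start sending the KV cache at kv_indices; returns without waiting.
    def send(self, kv_indices: npt.NDArray[np.int32]): ...

    # Report the current state without blocking.
    def poll(self) -> KVPoll: ...

    # Surface the error after poll() returns KVPoll.Failed.
    def failure_exception(self) -> None: ...


class KVReceiver:
    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int): ...

    # Register the preallocated KV indices that incoming data is written to.
    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: int | None): ...

    def poll(self) -> KVPoll: ...

    def failure_exception(self) -> None: ...


class KVBootstrapServer:
    def __init__(self, port: int): ...

    def stop(self): ...
```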
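The interface hides a background transfer thread behind non-blocking calls. The sketch below mimics that pattern in-process, assuming a plain memory copy between two `bytearray`s in place of RDMA and assuming bootstrap has already completed, so the sender starts in `WaitingForInput`. `MockTransferEngine`, `MockSender`, and `wait_until_done` are hypothetical names; only the `KVPoll` states come from the interface.

```python
import queue
import threading
import time
from enum import IntEnum


class KVPoll(IntEnum):
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class MockTransferEngine:
    """Hypothetical stand-in for the transfer thread: copies fixed-size items between bytearrays."""

    def __init__(self, src: bytearray, dst: bytearray, item_len: int):
        self.src, self.dst, self.item_len = src, dst, item_len
        self.tasks = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self) -> None:
        while True:
            src_indices, dst_indices, sender = self.tasks.get()  # blocks only this thread
            try:
                for s, d in zip(src_indices, dst_indices):
                    a, b = s * self.item_len, d * self.item_len
                    if a + self.item_len > len(self.src) or b + self.item_len > len(self.dst):
                        raise IndexError("KV index out of range")
                    self.dst[b:b + self.item_len] = self.src[a:a + self.item_len]
                sender._set_state(KVPoll.Success)
            except Exception:
                sender._set_state(KVPoll.Failed)


class MockSender:
    def __init__(self, engine: MockTransferEngine, dst_indices: list[int]):
        self.engine = engine
        self.dst_indices = dst_indices  # assumed to come from the receiver during bootstrap
        self._lock = threading.Lock()
        self._state = KVPoll.WaitingForInput  # assumes bootstrap already finished

    def _set_state(self, state: KVPoll) -> None:
        with self._lock:
            self._state = state

    def send(self, kv_indices: list[int]) -> None:
        # Enqueue the work and return immediately; the worker thread does the copy.
        self._set_state(KVPoll.Transferring)
        self.engine.tasks.put((list(kv_indices), self.dst_indices, self))

    def poll(self) -> KVPoll:
        with self._lock:
            return self._state


def wait_until_done(sender: MockSender, timeout: float = 5.0) -> KVPoll:
    # A scheduler would run other work between polls instead of sleeping.
    deadline = time.monotonic() + timeout
    while (state := sender.poll()) not in (KVPoll.Success, KVPoll.Failed):
        if time.monotonic() > deadline:
            break
        time.sleep(0.001)
    return state
```

For example, with `src = bytearray(b"AAAABBBBCCCC")`, `dst = bytearray(12)`, `item_len=4`, and `dst_indices=[2, 0]`, calling `send([0, 1])` and then `wait_until_done(sender)` returns `KVPoll.Success` and leaves `dst` equal to `b"BBBB\x00\x00\x00\x00AAAA"`. The lock keeps `poll()` consistent with state changes made by the worker thread.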