On top of the original SGLang scheduling event loop, we add non-blocking sender and receiver operations. On the prefill server, each request passes through three queues:
1. Bootstrap Queue
   a. Initialize a sender for each request
   b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
   c. Poll the senders to check their bootstrap state
   d. Once bootstrap is complete, move the request to the Waiting Queue
2. Waiting Queue
   a. Use PrefillAdder to pop requests from the queue into a prefill batch
   b. Run the forward pass
   c. Add the requests to the Inflight Queue, where their KV cache starts transferring
3. Inflight Queue
   a. Poll (non-blocking) the sender of each request
   b. Once the transfer has finished, return the request
```python
while True:
    recv_reqs = self.recv_requests()
    self.bootstrap_queue.extend(recv_reqs)
    # Requests whose bootstrap has finished become schedulable.
    self.waiting_queue.extend(self.bootstrap_queue.pop_bootstrapped())
    batch = self.get_next_prefill_batch()
    if batch:
        result = self.run_batch(batch)
        # 1. batch.reqs start transferring
        # 2. add the transferring reqs into the inflight queue
        self.process_batch_result(batch, result)
    # Poll in-flight transfers without blocking, on every iteration.
    self.process_inflight_queue()
```
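The prefill-side steps above can be sketched as a small, self-contained simulation. `FakeSender`, `Req`, `BootstrapQueue`, and `process_inflight_queue` are hypothetical stand-ins rather than SGLang's actual classes, and the sketch assumes a sender reports `WaitingForInput` once its bootstrap (handshake and preallocation) is complete.

```python
from collections import deque
from enum import IntEnum


class KVPoll(IntEnum):
    # Same state values as the KVPoll interface described in this document.
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class FakeSender:
    """Hypothetical KVSender stand-in that replays a scripted list of states."""

    def __init__(self, states):
        self._states = deque(states)

    def poll(self) -> KVPoll:
        # Each poll advances one scripted state, then sticks on the last one.
        if len(self._states) > 1:
            return self._states.popleft()
        return self._states[0]


class Req:
    def __init__(self, rid, sender):
        self.rid = rid
        self.sender = sender


class BootstrapQueue:
    """Holds requests until their sender has finished bootstrapping."""

    def __init__(self):
        self.queue = []

    def extend(self, reqs):
        self.queue.extend(reqs)

    def pop_bootstrapped(self):
        # One non-blocking poll per request; assume WaitingForInput means the
        # handshake and preallocation are done and KV indices can be sent.
        ready, pending = [], []
        for req in self.queue:
            state = req.sender.poll()
            if state == KVPoll.WaitingForInput:
                ready.append(req)
            elif state != KVPoll.Failed:  # failed requests are dropped here
                pending.append(req)
        self.queue = pending
        return ready


def process_inflight_queue(inflight):
    """Split in-flight requests into (finished, still_transferring)."""
    finished, still = [], []
    for req in inflight:
        state = req.sender.poll()
        if state in (KVPoll.Success, KVPoll.Failed):
            finished.append(req)
        else:
            still.append(req)
    return finished, still


bq = BootstrapQueue()
bq.extend([
    Req("a", FakeSender([KVPoll.Bootstrapping, KVPoll.WaitingForInput])),
    Req("b", FakeSender([KVPoll.WaitingForInput])),
])
print([r.rid for r in bq.pop_bootstrapped()])  # ['b']: "a" is still bootstrapping
print([r.rid for r in bq.pop_bootstrapped()])  # ['a']
```

Because every call is a single poll, a request that is still bootstrapping simply stays in the queue until a later loop iteration; the scheduler never waits on it.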
On the decode server, each request moves through the following queues before joining the running batch:
1. PreallocQueue:
   a. Initialize a receiver for each request
   b. The request first completes the handshake, then preallocates KV cache once enough KV memory is available
   c. Move the request to the TransferQueue
2. TransferQueue:
   a. Poll the receiver to check the transfer state
   b. Once the transfer has finished, move the request to the WaitingQueue
3. WaitingQueue:
   a. Use the requests in the queue to construct a PrebuiltExtendBatch
   b. Skip the prefill forward pass (the KV cache was already transferred from the prefill server) and only populate the batch metadata
4. RunningBatch:
   a. Merge the resolved PrebuiltExtendBatch into the running batch to run decoding
```python
while True:
    recv_reqs = self.recv_requests()
    self.prealloc_queue.extend(recv_reqs)
    # Handshake done and KV preallocated: wait for the transfer.
    self.transfer_queue.extend(self.prealloc_queue.pop_prealloc())
    # Transfer finished: ready to be built into a PrebuiltExtendBatch.
    self.waiting_queue.extend(self.transfer_queue.pop_transfer())
    batch = self.get_next_decode_batch()
    if batch:
        result = self.run_batch(batch)
        self.process_batch_result(batch, result)
```
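The decode-side queues can be simulated in the same spirit. In this hypothetical sketch, `FakeReceiver`, `Req`, and `DecodeQueues` are illustrative stand-ins; the KV pool is just a list of free slot ids, the receiver is assumed to report `WaitingForInput` as soon as its handshake is done (immediately, here), and `init()` takes a plain list rather than the `np.int32` array used by the real interface.

```python
from enum import IntEnum


class KVPoll(IntEnum):
    # Same state values as the KVPoll interface described in this document.
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class FakeReceiver:
    """Hypothetical KVReceiver stand-in: handshake done immediately,
    transfer finishes after `polls_to_finish` polls once init() is called."""

    def __init__(self, polls_to_finish):
        self.kv_indices = None
        self._remaining = polls_to_finish

    def init(self, kv_indices):
        self.kv_indices = kv_indices

    def poll(self) -> KVPoll:
        if self.kv_indices is None:
            return KVPoll.WaitingForInput
        if self._remaining > 0:
            self._remaining -= 1
            return KVPoll.Transferring
        return KVPoll.Success


class Req:
    def __init__(self, rid, num_tokens, receiver):
        self.rid, self.num_tokens, self.receiver = rid, num_tokens, receiver


class DecodeQueues:
    def __init__(self, num_free_slots):
        self.free_slots = list(range(num_free_slots))  # hypothetical KV pool
        self.prealloc_queue, self.transfer_queue = [], []

    def pop_prealloc(self):
        # Preallocate in arrival order; stop at the first request that has not
        # finished its handshake or does not fit in the remaining KV slots.
        moved = []
        while self.prealloc_queue:
            req = self.prealloc_queue[0]
            if req.receiver.poll() != KVPoll.WaitingForInput:
                break
            if len(self.free_slots) < req.num_tokens:
                break
            req.receiver.init([self.free_slots.pop() for _ in range(req.num_tokens)])
            moved.append(self.prealloc_queue.pop(0))
        return moved

    def pop_transfer(self):
        # One non-blocking poll per request; finished ones leave the queue.
        done, pending = [], []
        for req in self.transfer_queue:
            (done if req.receiver.poll() == KVPoll.Success else pending).append(req)
        self.transfer_queue = pending
        return done


dq = DecodeQueues(num_free_slots=4)
dq.prealloc_queue += [Req("a", 3, FakeReceiver(1)), Req("b", 2, FakeReceiver(0))]
dq.transfer_queue += dq.pop_prealloc()
print([r.rid for r in dq.transfer_queue], len(dq.free_slots))  # ['a'] 1
print([r.rid for r in dq.pop_transfer()])  # [] (still transferring)
print([r.rid for r in dq.pop_transfer()])  # ['a']
```

Request `b` waits in the PreallocQueue because only one slot is left; a real scheduler would also release slots when requests finish, which this sketch omits.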
At a high level, we use KVSender (on the prefill server) and KVReceiver (on the decode server) to manage sending and receiving the KV cache. The actual transfer happens in a background thread, and we expose a Python interface to communicate with that thread. All operations are non-blocking, so the scheduler can poll the state of each transfer at any time.
```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


class KVArgs:
    engine_rank: int
    num_receivers: int
    kv_data_ptrs: list[int]
    kv_data_lens: list[int]
    kv_item_lens: list[int]
    ib_device: str
    max_inflight: int
    max_sge: int
    verbose: bool


class KVManager:
    def __init__(self, args: KVArgs): ...


class KVPoll:
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class KVSender:
    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int, dest_ranks: list[int]
    ): ...
    def init(self, num_kv_indices: int): ...
    def send(self, kv_indices: npt.NDArray[np.int32]): ...
    # Non-blocking: returns the current state of the transfer.
    def poll(self) -> KVPoll: ...
    def failure_exception(self) -> None: ...


class KVReceiver:
    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int): ...
    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: int | None): ...
    # Non-blocking: returns the current state of the transfer.
    def poll(self) -> KVPoll: ...
    def failure_exception(self) -> None: ...


class KVBootstrapServer:
    def __init__(self, port: int): ...
    def stop(self): ...
```
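To illustrate the background-thread design, the following hypothetical `ThreadedSender` hands the transfer to a Python thread in `send()` and has `poll()` only read a lock-protected state, so neither call blocks the scheduler. It assumes bootstrap has already completed and uses a plain list copy in place of the real transfer; it is a sketch of the pattern, not the KVSender implementation.

```python
import threading
import time
from enum import IntEnum


class KVPoll(IntEnum):
    # Same state values as the KVPoll interface above.
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class ThreadedSender:
    """Hypothetical sender: send() starts a background thread, poll() only reads state."""

    def __init__(self, src, dst):
        self._src, self._dst = src, dst
        self._lock = threading.Lock()
        self._state = KVPoll.WaitingForInput  # assume bootstrap already finished

    def _transfer(self, kv_indices):
        # Runs in the background thread; a list copy stands in for the real transfer.
        try:
            for i in kv_indices:
                self._dst[i] = self._src[i]
                time.sleep(0.001)
            final = KVPoll.Success
        except Exception:
            final = KVPoll.Failed
        with self._lock:
            self._state = final

    def send(self, kv_indices):
        with self._lock:
            self._state = KVPoll.Transferring
        # Returns immediately; the copy happens on the daemon thread.
        threading.Thread(target=self._transfer, args=(list(kv_indices),), daemon=True).start()

    def poll(self) -> KVPoll:
        # Non-blocking: just report the current state.
        with self._lock:
            return self._state


dst = [0, 0, 0, 0]
sender = ThreadedSender(src=[10, 20, 30, 40], dst=dst)
sender.send([1, 3])
while sender.poll() == KVPoll.Transferring:
    time.sleep(0.001)  # a scheduler would run other batches between polls instead
print(sender.poll().name, dst)  # Success [0, 20, 0, 40]
```

Keeping the lock around only the state read and write means `poll()` returns almost immediately even while the transfer thread is busy, which is what lets the event loops above call it on every iteration.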