vLLM Source Code Reading - Overall Execution Flow Overview (Part 1)

Note: based on vLLM v0.7.3, commit id: ed6e9075d31e32c8548b480a47d1ffb77da1f54c (HEAD, tag: v0.7.3)

PagedAttention

  • Motivation: solve the low KV cache utilization caused by existing systems storing each request's KV cache in contiguous memory.

  • Why KV cache utilization is low (see the PagedAttention paper):

    • The total length of a request (prompt + output) is unknown in advance: pre-allocating too large a slot wastes memory and produces internal fragmentation, while slots that are too small cannot be handed to other requests and become external fragmentation.
    • No memory sharing: decoding algorithms such as beam search produce multiple outputs for a single request, and existing systems cannot let those outputs share the KV cache of the same prompt.
  • Principle: split the KV cache into fixed-size blocks and let a per-sequence block table map logical blocks to non-contiguous physical blocks, analogous to paged virtual memory in an operating system (see the sketch below).

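A minimal, self-contained sketch of the block-table idea (this is illustrative code, not vLLM's implementation; BLOCK_SIZE and the class names are made up for the example):

from typing import List

BLOCK_SIZE = 16  # number of token slots per physical block (illustrative)

class BlockAllocator:
    """Hands out fixed-size physical block ids from a free list."""
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))

    def allocate(self) -> int:
        return self.free_blocks.pop()

class BlockTable:
    """Maps a sequence's logical blocks to non-contiguous physical blocks."""
    def __init__(self, allocator: BlockAllocator):
        self.allocator = allocator
        self.physical_blocks: List[int] = []
        self.num_tokens = 0

    def append_token(self) -> None:
        # A new physical block is grabbed only when the last one is full,
        # so at most one partially-used block is wasted per sequence.
        if self.num_tokens % BLOCK_SIZE == 0:
            self.physical_blocks.append(self.allocator.allocate())
        self.num_tokens += 1

allocator = BlockAllocator(num_blocks=8)
table = BlockTable(allocator)
for _ in range(20):           # 20 tokens -> 2 physical blocks
    table.append_token()
print(table.physical_blocks)  # e.g. [7, 6]: logically adjacent, physically arbitrary
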
vLLM Inference

Example inference program:

from vllm import LLM, SamplingParams


prompts = [
    ("A robot may not injure a human being",
     SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
    ("To be or not to be,",
     SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
    ("What is the meaning of life?",
     SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
    # SamplingParams(n=2,
    #                best_of=5,
    #                temperature=0.8,
    #                top_p=0.95,
    #                frequency_penalty=0.1)),
]
prompts, sampling_paras = zip(*prompts)

# Using Qwen2.5 as an example
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct")

outputs = llm.generate(prompts, sampling_paras)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Initialization Flow

The LLM Interface Class

Initialization of the LLM class:

# entrypoints/llm.py
class LLM:
    def __init__(...):
        engine_args = EngineArgs(
            ...
        )
        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        # ---------------------------------------
        # Get the engine class: if the environment variable `VLLM_USE_V1` is set,
        # the V1 LLMEngine is used; otherwise the default LLMEngine is used.
        # ---------------------------------------
        self.engine_class = self.get_engine_class()
        # ---------------------------------------
        # Call LLMEngine::from_engine_args to initialize the engine
        # ---------------------------------------
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()
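
For reference, get_engine_class does roughly the following (paraphrased from entrypoints/llm.py, not the verbatim source):

# entrypoints/llm.py, inside class LLM (paraphrased)
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
    # If VLLM_USE_V1 is set, lazily import and return the V1 engine;
    # otherwise fall back to the classic (V0) LLMEngine.
    if envs.VLLM_USE_V1:
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine
    return LLMEngine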

Constructing the LLMEngine:

# engine/llm_engine.py
class LLMEngine:

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # ---------------------------------------
        # EngineArgs.create_engine_config creates and initializes the individual
        # configs, e.g. ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
        # LoRAConfig, and finally combines them into a VllmConfig.
        #
        # Initializing VllmConfig calls `current_platform.check_and_update_config(self)`,
        # which sets worker_cls='vllm.worker.worker.Worker'.
        # ---------------------------------------
        engine_config = engine_args.create_engine_config(usage_context)

        # ---------------------------------------
        # The executor backend is chosen in ParallelConfig.__post_init__.
        # With world_size=1 the default backend is "uni", which maps to
        # UniProcExecutor (in executor/uniproc_executor.py).
        # ---------------------------------------
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine
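
For comparison, a minimal usage sketch that builds an engine directly from EngineArgs, which is what LLM.__init__ does internally (the model name here is just an example):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.usage.usage_lib import UsageContext

# Build the config objects and the engine without going through the LLM wrapper.
engine_args = EngineArgs(model="Qwen/Qwen2.5-3B-Instruct")
engine = LLMEngine.from_engine_args(engine_args,
                                    usage_context=UsageContext.LLM_CLASS)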

The LLMEngine Class

LLMEngine initialization:

# engine/llm_engine.py
class LLMEngine:

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        # ---------------------------------------
        # Initialize the executor,
        # e.g. UniProcExecutor (in executor/uniproc_executor.py)
        # ---------------------------------------
        self.model_executor = executor_class(vllm_config=vllm_config, )

        if self.model_config.runner_type != "pooling":
            # ---------------------------------------
            # Initialize the KV cache.
            # Calls Worker::determine_num_available_blocks:
            # 1. Run one forward pass of the model to profile memory and find out
            #    how much GPU memory is left for the KV cache.
            # 2. Compute cache_block_size.
            # 3. available KV cache memory / cache_block_size = num_gpu_blocks
            #    CPU swap_space_bytes / cache_block_size = num_cpu_blocks
            # 4. With num_gpu_blocks and num_cpu_blocks, call the Worker's
            #    _init_cache_engine method.
            #
            # This initializes CacheEngine, which manages the KV cache.
            #
            # gpu_cache: List[torch.Tensor], one tensor per layer; with the
            # flash_attn backend each has shape
            # (2, num_blocks, block_size, num_kv_heads, head_size)
            # ---------------------------------------
            self._initialize_kv_caches()

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        # ---------------------------------------
        # vllm/core/scheduler.py
        # Initialize the Scheduler(s)
        # ---------------------------------------
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

The LLMEngine initialization flow involves initializing the following core classes/components.

  1. executor and Worker initialization

The executor initialization sequence diagram is as follows:

executor initialization:

# executor/uniproc_executor.py  (UniProcExecutor)
def _init_executor(self) -> None:
    """Initialize the worker and load the model.
    """
    self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                           rpc_rank=0)
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    local_rank = 0
    # set local rank as the device index if specified
    device_info = self.vllm_config.device_config.device.__str__().split(
        ":")
    if len(device_info) > 1:
        local_rank = int(device_info[1])
    rank = 0
    kwargs = dict(
        vllm_config=self.vllm_config,
        local_rank=local_rank,
        rank=rank,
        distributed_init_method=distributed_init_method,
        is_driver_worker=(not self.parallel_config)
        or (rank % self.parallel_config.tensor_parallel_size == 0),
    )
    # ---------------------------------------
    # The method named in collective_rpc is invoked on `self.driver_worker`
    # (the WorkerWrapperBase); if the wrapper does not define it, the call is
    # forwarded to the wrapped `worker` instance.
    # ---------------------------------------
    self.collective_rpc("init_worker", args=([kwargs], ))
    self.collective_rpc("init_device")
    self.collective_rpc("load_model")

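The dispatch described in that comment can be pictured with the following sketch. The class names ending in "Sketch" and DummyWorker are illustrative, not vLLM's actual classes; only the forwarding pattern is the point:

# Minimal sketch of how collective_rpc("load_model") reaches the worker
# in the single-process case.
class DummyWorker:
    def __init__(self, rank: int = 0):
        self.rank = rank

    def init_device(self):
        print(f"worker {self.rank}: init device")

    def load_model(self):
        print(f"worker {self.rank}: load model")


class WorkerWrapperSketch:
    """Wraps the real worker; unknown attributes fall through to it."""

    def __init__(self):
        self.worker = None

    def init_worker(self, all_kwargs):
        # In vLLM, the worker class is resolved from vllm_config.worker_cls.
        self.worker = DummyWorker(**all_kwargs[0])

    def __getattr__(self, name):
        # init_device, load_model, ... are not defined on the wrapper,
        # so they are forwarded to the wrapped worker.
        return getattr(self.worker, name)


class UniProcExecutorSketch:
    def __init__(self):
        self.driver_worker = WorkerWrapperSketch()

    def collective_rpc(self, method, args=(), kwargs=None):
        # Single process: call the method on the driver worker directly and
        # return a one-element list (one entry per "worker").
        func = getattr(self.driver_worker, method)
        return [func(*args, **(kwargs or {}))]


executor = UniProcExecutorSketch()
executor.collective_rpc("init_worker", args=([{"rank": 0}], ))
executor.collective_rpc("init_device")
executor.collective_rpc("load_model")
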
Worker initialization:

# vllm/worker/worker.py
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        ...  # (elided: `model_config` / `speculative_args` setup)

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_config.runner_type == "pooling":
            ModelRunnerClass = PoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        # ---------------------------------------
        # Initialize the model runner
        # ---------------------------------------
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as pooling models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

    def init_device(self) -> None:
        """
        Initialize the distributed environment:
        1. torch.distributed.init_process_group, set up the process group;
        2. initialize model parallelism, i.e. the TP and PP distributed groups;
        3. ensure_kv_transfer_initialized.
        """
        ...

    def load_model(self):
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            # ---------------------------------------
            # The worker's `load_model` actually calls
            # `self.model_runner.load_model()`, which loads the model weights
            # onto the device.
            # ---------------------------------------
            self.model_runner.load_model()
  2. KV cache initialization

In LLMEngine initialization, the _initialize_kv_caches function performs the KV cache initialization, as follows.

2.1 First compute the number of available GPU and CPU blocks

# vllm/worker/worker.py
class Worker(LocalOrDistributedWorkerBase):
    ...
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
        with memory_profiling(
                self.baseline_snapshot,
                weights_memory=self.model_runner.model_memory_usage) as result:
            self.model_runner.profile_run()
        # ---------------------------------------
        # memory available to this vLLM instance
        #   = total GPU memory * gpu_memory_utilization
        # ---------------------------------------
        memory_for_current_instance = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization
        # ---------------------------------------
        # memory available for the KV cache
        #   = memory available to this instance - non-KV-cache memory already in use
        # The non-KV-cache memory consists of three parts:
        # 1. model weights;
        # 2. memory reserved for peak activation tensors;
        # 3. non-torch components such as NCCL and buffers for some attention backends.
        # ---------------------------------------
        available_kv_cache_memory = (memory_for_current_instance -
                                     result.non_kv_cache_memory)
        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

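To make the arithmetic concrete, here is a back-of-the-envelope example with made-up numbers (a hypothetical 24 GiB GPU, 16 GiB of measured non-KV-cache memory, the default gpu_memory_utilization of 0.9 and swap_space of 4 GiB, and the 917504-byte cache_block_size derived for Qwen2.5-7B-Instruct in the next code block):

GiB = 1024 ** 3

total_gpu_memory = 24 * GiB          # hypothetical 24 GiB GPU
gpu_memory_utilization = 0.90        # vLLM default
non_kv_cache_memory = 16 * GiB       # hypothetical: weights + peak activations + NCCL etc.
cache_block_size = 917504            # bytes per block, see the Qwen2.5-7B example below
swap_space_bytes = 4 * GiB           # default CPU swap space is 4 GiB

memory_for_current_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - non_kv_cache_memory

num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(swap_space_bytes // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)  # 6553 GPU blocks and 4681 CPU blocks
                                       # (6553 blocks * 16 slots ~= 104k token slots)
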
Computing cache_block_size:

# vllm/worker/cache_engine.py
class CacheEngine:

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """
        Compute the KV cache size in bytes of one block (block_size token slots).
        For a single token,
            key:   (1, num_heads, 1, head_size)
            value: (1, num_heads, 1, head_size)

        The final size in bytes for a block of block_size tokens
        (dtype_size is e.g. 2 for torch.float16):
            K: num_attention_layers * block_size * num_heads * head_size * dtype_size
            V: num_attention_layers * block_size * num_heads * head_size * dtype_size
        """
        # ---------------------------------------
        # For Qwen2.5-7B-Instruct:
        # head_size = hidden_size / num_attention_heads = 3584 / 28 = 128
        # ---------------------------------------
        head_size = model_config.get_head_size()
        # ---------------------------------------
        # For Qwen2.5-7B-Instruct:
        # max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) = 4 // 1 = 4
        # ---------------------------------------
        num_heads = model_config.get_num_kv_heads(parallel_config)
        # ---------------------------------------
        # For Qwen2.5-7B-Instruct: num_attention_layers = 28
        # ---------------------------------------
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        key_cache_entry = num_heads * head_size
        if CacheEngine._align_cache(model_config):
            key_cache_entry = align_to_256bytes(key_cache_entry,
                                                model_config.dtype)
        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
        # ---------------------------------------
        # cache_config.block_size is the number of slots (tokens) per block
        # ---------------------------------------
        total = num_attention_layers * cache_config.block_size * \
            (key_cache_entry + value_cache_entry)

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total

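Plugging in the Qwen2.5-7B-Instruct numbers from the comments above, together with the default GPU block_size of 16, bfloat16 weights and tensor_parallel_size = 1, the per-block size works out as follows (a hand-computed sanity check, not output from vLLM itself):

num_attention_layers = 28
block_size = 16            # default number of slots per block on GPUs
num_kv_heads = 4           # 4 KV heads with tensor_parallel_size = 1
head_size = 128
dtype_size = 2             # bytes per element for bfloat16/float16

key_cache_entry = num_kv_heads * head_size    # 512 elements per token per layer
value_cache_entry = key_cache_entry           # Qwen2.5 does not use MLA
cache_block_size = (num_attention_layers * block_size *
                    (key_cache_entry + value_cache_entry) * dtype_size)
print(cache_block_size)    # 917504 bytes = 896 KiB per block (i.e. per 16 tokens)
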
2.2 Initialize the KV cache

# vllm/worker/worker.py
class Worker(LocalOrDistributedWorkerBase):

    def _init_cache_engine(self):
        assert self.cache_config.num_gpu_blocks is not None
        # ---------------------------------------
        # Initialize cache_engine, which in turn initializes:
        #   - attn_backend
        #   - gpu_cache, obtained via CacheEngine._allocate_kv_cache
        #   - cpu_cache
        # ---------------------------------------
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        # ---------------------------------------
        # The gpu_cache member of each CacheEngine instance
        # ---------------------------------------
        self.gpu_cache: List[List[torch.Tensor]] = [
            self.cache_engine[ve].gpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache)

CacheEngine pre-allocates the KV cache:

# worker/cache_engine.py
class CacheEngine:

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device."""
        # ---------------------------------------
        # With the default flash_attn backend,
        # FlashAttentionBackend (vllm/attention/backends/flash_attn.py) returns
        # the per-layer shape (2, num_blocks, block_size, num_kv_heads, head_size);
        # the loop below builds a list of num_layers such tensors.
        # ---------------------------------------
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)

        kv_cache: List[torch.Tensor] = []
        alloc_shape = kv_cache_shape

        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            layer_kv_cache = torch.zeros(alloc_shape,
                                         dtype=self.dtype,
                                         pin_memory=pin_memory,
                                         device=device)

            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
            # when entry_shape is higher than 1D
            kv_cache.append(layer_kv_cache.view(kv_cache_shape))
        # ---------------------------------------
        # The KV cache eventually allocated has
        # size: [(2, num_blocks, block_size, num_kv_heads, head_size)] * num_layers
        # ---------------------------------------
        return kv_cache
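
Tying this back to the earlier numbers: for Qwen2.5-7B-Instruct and the hypothetical 6553 GPU blocks above, the allocated tensors add up to exactly num_gpu_blocks * cache_block_size (again a hand-computed sanity check, not vLLM code):

import math

num_layers, num_blocks, block_size, num_kv_heads, head_size = 28, 6553, 16, 4, 128
dtype_size = 2  # bytes, bfloat16

per_layer_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
per_layer_bytes = math.prod(per_layer_shape) * dtype_size
total_bytes = num_layers * per_layer_bytes

print(per_layer_shape)                     # (2, 6553, 16, 4, 128)
print(total_bytes)                         # 6012403712 bytes ~= 5.6 GiB
print(total_bytes == num_blocks * 917504)  # True: num_blocks * cache_block_size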
  3. Scheduler initialization

A BlockSpaceManager is initialized to manage the actual allocation of KV cache blocks.

# vllm/core/scheduler.py
class Scheduler:

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        pipeline_parallel_size: int = 1,
        output_proc_callback: Optional[Callable] = None,
    ) -> None:
        ...
        # ---------------------------------------
        # BlockSpaceManagerImpl defaults to
        # vllm.core.block_manager.SelfAttnBlockSpaceManager
        # (in vllm/core/block_manager.py)
        # ---------------------------------------
        self.block_manager = BlockSpaceManagerImpl(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching,
        )
        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        # Contain decode requests.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
        # This is used to evict the finished requests from the Mamba cache.
        self._finished_requests_ids: List[str] = list()

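How the scheduler uses the block manager when it promotes a waiting seq_group can be pictured as below. This is a heavily simplified paraphrase: can_allocate/allocate/get_block_table are real BlockSpaceManager methods, but the surrounding control flow here is illustrative only:

from vllm.core.interfaces import AllocStatus

# Sketch (simplified, not the actual scheduler code): scheduling a waiting
# seq_group expressed as block-manager calls.
def schedule_prefill_sketch(scheduler, seq_group):
    # 1. Ask the block manager whether enough free GPU blocks exist for the prompt.
    if scheduler.block_manager.can_allocate(seq_group) != AllocStatus.OK:
        return False  # stays in the waiting queue (or is deferred)

    # 2. Reserve physical blocks for every sequence in the group and record the
    #    logical-block -> physical-block mapping (the block table).
    scheduler.block_manager.allocate(seq_group)

    # 3. The block tables are later packed into SequenceGroupMetadata and shipped
    #    to the worker, where PagedAttention uses them to address the KV cache.
    block_tables = {
        seq.seq_id: scheduler.block_manager.get_block_table(seq)
        for seq in seq_group.get_seqs()
    }

    scheduler.waiting.remove(seq_group)
    scheduler.running.append(seq_group)
    return True
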
Overall Logic of the LLM generate Interface

The logic flow of generate:

  1. Loop over each (prompt, sampling_param) pair and add them to the llm_engine's request pool as a request.
    a. input_processor performs some processing on the input prompt;
    b. the processed prompt (token ids) is wrapped into a Sequence, and the seq is further wrapped into a SequenceGroup;
    c. the resulting seq_group is appended to the scheduler's waiting queue for later processing.

The key code involved in this flow:

# entrypoints/llm.py
class LLM:
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        request_id = str(next(self.request_counter))
        # ---------------------------------------
        # Add the prompt and sampling_param as a request to the
        # `llm_engine` request pool
        # ---------------------------------------
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

# engine/llm_engine.py
class LLMEngine:
    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        ...
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
        ...
        # ---------------------------------------
        # Wrap the prompt token ids and the corresponding seq_id into a
        # `Sequence` object
        # ---------------------------------------
        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)
        ...
        if isinstance(params, SamplingParams):
            # ---------------------------------------
            # Create a seq_group from the seq.
            # A seq_group groups seqs that share the same prompt, i.e. one prompt
            # producing multiple seqs, e.g. when the sampling strategy is beam
            # search, or when sampling_param has n > 1.
            # ---------------------------------------
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        ...
        costs = [
            scheduler.get_num_unfinished_seq_groups()
            for scheduler in self.scheduler
        ]
        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
        # ---------------------------------------
        # Append the seq_group to the waiting queue.
        # ---------------------------------------
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

The class relationship diagram of Sequence and SequenceGroup is as follows:
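
Since the diagram is not reproduced here, the relationship can be summarized with this minimal sketch (field names simplified from vllm/sequence.py; treat it as an illustration, not the exact class definitions):

# Illustration only: simplified view of the data structures in vllm/sequence.py.
class SequenceDataSketch:
    def __init__(self, prompt_token_ids):
        self.prompt_token_ids = prompt_token_ids   # tokens of the prompt
        self.output_token_ids = []                 # tokens generated so far

class SequenceSketch:
    """One candidate completion: prompt + its own generated tokens + status."""
    def __init__(self, seq_id, prompt_token_ids, block_size):
        self.seq_id = seq_id
        self.data = SequenceDataSketch(prompt_token_ids)
        self.block_size = block_size               # used to compute logical blocks
        self.status = "WAITING"                    # WAITING / RUNNING / SWAPPED / FINISHED

class SequenceGroupSketch:
    """All sequences spawned from one request (same prompt, same sampling params)."""
    def __init__(self, request_id, seqs, sampling_params):
        self.request_id = request_id
        self.seqs = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params     # n > 1 or beam search => multiple seqs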

  2. Then self._run_engine() is called to perform the actual scheduling and inference.
# entrypoints/llm.py
class LLM:
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        # ---------------------------------------
        # Ask the scheduler whether any of the `waiting`, `running`, `swapped`
        # queues is non-empty; if so, there are unfinished requests and the
        # while loop continues.
        # ---------------------------------------
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
  3. llm_engine calls step(): the scheduler is responsible for scheduling, and the model_executor performs the actual inference.
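
The core of one step() iteration can be summarized as below. This is a heavily abridged paraphrase of LLMEngine.step (engine/llm_engine.py), keeping only the main calls; process_outputs is a placeholder for the output-processing stage, and details such as async output processing, multi-step decoding and pipeline parallelism are omitted:

from vllm.sequence import ExecuteModelRequest

def step_sketch(engine):
    """Paraphrased: what LLMEngine.step() does for a single virtual engine."""
    # 1. Scheduling: choose which seq_groups run this iteration, allocate or
    #    append KV cache blocks, and plan swap-in/swap-out/copy operations.
    (seq_group_metadata_list, scheduler_outputs,
     allow_async_output_proc) = engine.scheduler[0].schedule()

    # 2. Model execution: the executor forwards the batch to the worker(s),
    #    which run the model with PagedAttention over the block tables.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy)
    sampler_outputs = engine.model_executor.execute_model(execute_model_req)

    # 3. Output processing (engine._process_model_outputs): append the sampled
    #    tokens to each Sequence, check stop conditions, detokenize, free
    #    finished seq_groups, and build the RequestOutput objects step() returns.
    request_outputs = process_outputs(engine, sampler_outputs)  # placeholder helper
    return request_outputs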

References

  • vLLM paper. arXiv
  • vLLM First SF Meetup Slides (Public). slides