vllm 整体架构
sample : 将脑子中的想法转变成真的语言
#examples/offline_inference/simple_profiling.py
#LLM 类用于加载和执行模型,SamplingParams 用于设置采样参数
from vllm import LLM, SamplingParams
# enable torch profiler, can also be set on cmd line
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
#temperature:控制生成文本的随机性。较高的温度值(如 1.0)使生成的文本更加随机,而较低的温度值(如 0.1)则使生成的文本更加确定性
#top_p=0.95 表示只考虑概率前 95% 的词汇来生成下一个词,从而避免生成不太可能的词。
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
if __name__ == "__main__":
# Create an LLM.
llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=1, device="cpu")
llm.start_profile()
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
llm.stop_profile()
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Add a buffer to wait for profiler in the background process
# (in case MP is on) to finish writing profiling output.
time.sleep(10)
llm类 在 vllm/entrypoints/llm.py 中
ps: 所有不在vllm/v1这个文件夹里的code,叫做v0,现在vllm正在v0->v1迁移
模块 :
- Entrypoint (LLM, api server)
- engin (主要是llm engine.py)
- Scheduler(vllm/core/scheduler.py)
- kv cache manager (kv cache 在工业界还是先行者阶段,真正落地并不多)
- paged attention(多个请求并发执行时,每个请求并需要完整的kv cache。比如我们正在生成第五个token,那么我们只需要前四个token的kv。我们选择一个page大小,将kv cache按照page大小分成相同大小的page,为每个请求按需分配)
- Evictor(当新的请求到达,且当前显存中的 KV cache 空间不足时,触发 Evictor,回收一些旧请求,和其对应的kv cache页)
- prefix cache(优化具有相同前缀的多个请求。最好能手写,找工作很有帮助)
- kv cache optimization
- deepseek(MLA)
- worker(vllm/worker/worker.py)
- model executor(想懂模型研读vllm/model_executor/models/llama.py)
- modeling(指的是vllm/model_executor/models/models下的那些文件,有很多贡献机会。llama.py中重点关于forward函数)
- attention backend(vllm/attention/backends/flash_attn.py)