Go 实现本地 Ollama 模型基准测试工具

Go 实现本地 Ollama 模型基准测试工具

在本地部署 LLM 时,评估模型的推理速度(TPS)和延迟是关键环节。本文介绍一个基于 Go 语言编写的 Ollama 模型基准测试工具,支持并发测试、配置化管理及自动化报告。
每秒输出token数是重要的衡量指标

功能特性

  1. 并发测试:支持同时测试多个模型,提高效率。
  2. 配置驱动:通过 JSON 配置文件管理待测模型列表和参数。
  3. 智能兼容:自动识别 Qwen 系列模型并关闭 think 模式,避免思维链影响速度测试。
  4. 版本检测:自动获取 Ollama 服务端版本,便于环境归档。
  5. 详细报告:控制台输出简要报表,完整结果保存为 JSON 文件。

核心代码

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"
)

type Config struct {
	URL     string   `json:"url"`
	Prompt  string   `json:"prompt"`
	Models  []string `json:"models"`
	Timeout int      `json:"timeout_seconds"`
	Output  string   `json:"output_file"`
}

type Request struct {
	Model   string         `json:"model"`
	Prompt  string         `json:"prompt"`
	Stream  bool           `json:"stream"`
	Options map[string]any `json:"options,omitempty"`
}

type Response struct {
	EvalCount    int    `json:"eval_count"`
	EvalDuration int64  `json:"eval_duration"`
	Model        string `json:"model"`
}

type VersionInfo struct {
	Version string `json:"version"`
}

type Result struct {
	Model         string  `json:"model"`
	Success       bool    `json:"success"`
	Error         string  `json:"error,omitempty"`
	EvalCount     int     `json:"eval_count"`
	EvalDuration  int64   `json:"eval_duration_ns"`
	TotalDuration float64 `json:"total_duration_sec"`
	TokensPerSec  float64 `json:"tokens_per_sec"`
	Timestamp     string  `json:"timestamp"`
}

func loadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var cfg Config
	return &cfg, json.Unmarshal(data, &cfg)
}

func getOllamaVersion(baseURL string) string {
	url := strings.TrimSuffix(baseURL, "/api/generate") + "/api/version"
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return "unknown"
	}
	defer resp.Body.Close()

	var v VersionInfo
	if err := json.NewDecoder(resp.Body).Decode(&v); err != nil {
		return "unknown"
	}
	return v.Version
}

func testModel(url, model, prompt string, timeout time.Duration) Result {
	res := Result{
		Model:     model,
		Timestamp: time.Now().Format(time.RFC3339),
	}

	options := map[string]any{}
	if strings.Contains(strings.ToLower(model), "qwen") {
		options["think"] = false
	}

	payload := Request{
		Model:   model,
		Prompt:  prompt,
		Stream:  false,
		Options: options,
	}

	data, err := json.Marshal(payload)
	if err != nil {
		res.Success = false
		res.Error = err.Error()
		return res
	}

	start := time.Now()
	client := &http.Client{Timeout: timeout}
	resp, err := client.Post(url, "application/json", bytes.NewBuffer(data))
	if err != nil {
		res.Success = false
		res.Error = err.Error()
		return res
	}
	defer resp.Body.Close()

	var apiResp Response
	if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
		res.Success = false
		res.Error = err.Error()
		return res
	}

	res.Success = true
	res.EvalCount = apiResp.EvalCount
	res.EvalDuration = apiResp.EvalDuration
	res.TotalDuration = time.Since(start).Seconds()
	if apiResp.EvalDuration > 0 {
		res.TokensPerSec = float64(apiResp.EvalCount) / (float64(apiResp.EvalDuration) / 1e9)
	}
	return res
}

func saveResults(results []Result, filename, ollamaVersion string) error {
	output := map[string]any{
		"ollama_version": ollamaVersion,
		"timestamp":      time.Now().Format(time.RFC3339),
		"results":        results,
		"summary":        generateSummary(results),
	}
	data, err := json.MarshalIndent(output, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(filename, data, 0644)
}

func generateSummary(results []Result) map[string]any {
	success := 0
	var totalTPS float64
	var fastest, slowest string
	var maxTPS, minTPS float64 = -1, -1

	for _, r := range results {
		if r.Success {
			success++
			totalTPS += r.TokensPerSec
			if r.TokensPerSec > maxTPS {
				maxTPS = r.TokensPerSec
				fastest = r.Model
			}
			if minTPS < 0 || r.TokensPerSec < minTPS {
				minTPS = r.TokensPerSec
				slowest = r.Model
			}
		}
	}

	avgTPS := 0.0
	if success > 0 {
		avgTPS = totalTPS / float64(success)
	}

	return map[string]any{
		"total_models":   len(results),
		"success_count":  success,
		"failed_count":   len(results) - success,
		"avg_tokens_sec": fmt.Sprintf("%.2f", avgTPS),
		"fastest_model":  fastest,
		"slowest_model":  slowest,
	}
}

func printReport(results []Result, ollamaVersion string) {
	fmt.Println("\n=== 基准测试报告 ===")
	fmt.Printf("Ollama 版本:%s\n", ollamaVersion)
	fmt.Printf("测试时间:%s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Printf("%-25s %-8s %-10s %-12s %-10s\n", "Model", "Success", "Tokens", "TPS", "Latency")
	fmt.Println(strings.Repeat("-", 80))
	for _, r := range results {
		status := "✓"
		if !r.Success {
			status := "✗"
			fmt.Printf("%-25s %-8s %-10d %-12.2f %-10.2fs\n",
				r.Model, status, r.EvalCount, r.TokensPerSec, r.TotalDuration)
			fmt.Printf("  └─ Error: %s\n", r.Error)
		} else {
			fmt.Printf("%-25s %-8s %-10d %-12.2f %-10.2fs\n",
				r.Model, status, r.EvalCount, r.TokensPerSec, r.TotalDuration)
		}
	}
	fmt.Println()
}

func main() {
	cfgPath := "config.json"
	if len(os.Args) > 1 {
		cfgPath = os.Args[1]
	}

	cfg, err := loadConfig(cfgPath)
	if err != nil {
		fmt.Printf("加载配置失败:%v\n", err)
		return
	}

	timeout := time.Duration(cfg.Timeout) * time.Second
	if timeout == 0 {
		timeout = 300 * time.Second
	}

	ollamaVersion := getOllamaVersion(cfg.URL)
	fmt.Printf("Ollama 版本:%s\n", ollamaVersion)

	var wg sync.WaitGroup
	var mu sync.Mutex
	results := make([]Result, 0, len(cfg.Models))

	for _, model := range cfg.Models {
		wg.Add(1)
		go func(m string) {
			defer wg.Done()
			fmt.Printf("测试中:%s\n", m)
			r := testModel(cfg.URL, m, cfg.Prompt, timeout)
			mu.Lock()
			results = append(results, r)
			mu.Unlock()
		}(model)
	}
	wg.Wait()

	printReport(results, ollamaVersion)

	if cfg.Output != "" {
		if err := saveResults(results, cfg.Output, ollamaVersion); err != nil {
			fmt.Printf("保存结果失败:%v\n", err)
		} else {
			fmt.Printf("结果已保存至:%s\n", cfg.Output)
		}
	}
}

配置说明

创建 config.json 文件,定义测试参数:

{
  "url": "http://localhost:11434/api/generate",
  "prompt": "你好,请介绍一下你自己",
  "models": [
    "qwen3:0.6b",
    "qwen3.5:0.8b",
    "llama3:8b"
  ],
  "timeout_seconds": 300,
  "output_file": "benchmark_results.json"
}

使用指南

  1. 初始化:确保 Ollama 服务已启动且模型已 pull。
  2. 编译运行
    go build -o ollama-bench
    ./ollama-bench
    
    或指定配置文件:
    ./ollama-bench config.json
    
  3. 查看结果:控制台输出简要报表,详细数据保存至 benchmark_results.json

测试结果示例

控制台输出:

Ollama 版本:0.18.2

=== 基准测试报告 ===
Ollama 版本:0.18.2
测试时间:2026-03-23 14:30:00
Model                     Success  Tokens     TPS          Latency   
--------------------------------------------------------------------------------
qwen3:0.6b                ✓        45         32.50        1.38s     
qwen3.5:0.8b              ✓        52         28.15        1.85s     
llama3:8b                 ✓        48         41.20        1.16s     

JSON 报告包含版本信息、每个模型的详细指标及汇总统计,便于后续对比分析。

总结

该工具利用 Go 的并发特性加速测试流程,通过配置化设计适应不同场景。针对 Qwen 系列模型的特殊参数处理,确保了速度测试的准确性。代码结构清晰,易于扩展更多评估维度。

posted @ 2026-03-23 11:46  jiftle  阅读(14)  评论(0)    收藏  举报