用 Go 开发一个简单的数据库 MCP 服务
以下代码参考了别人的实现。本来想做一个自然语言查询，发现太难了……
package main
import (
"context"
"encoding/csv"
"flag"
"fmt"
"log"
"strings"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Statement kinds passed as the "expect" argument of HandleQuery, DoQuery and
// HandleExec. A non-empty value makes those functions call HandleExplain,
// which (only when --with-explain-check is set) compares it with the
// select_type column of the EXPLAIN output. StatementTypeNoExplainCheck (the
// empty string) skips that call entirely.
const (
StatementTypeNoExplainCheck = ""
StatementTypeSelect = "SELECT"
StatementTypeInsert = "INSERT"
StatementTypeUpdate = "UPDATE"
StatementTypeDelete = "DELETE"
)
// Package-level configuration filled in from command-line flags by main, plus
// the shared database handle opened lazily by GetDB.
var (
Host string // -host: MySQL hostname, used only when -dsn is empty
User string // -user: MySQL username, used only when -dsn is empty
Pass string // -pass: MySQL password, used only when -dsn is empty
Port int // -port: MySQL port, used only when -dsn is empty
Db string // -db: database name placed in the generated DSN
DSN string // -dsn: full DSN; main builds one from the fields above when empty
ReadOnly bool // -read-only: when true, main skips registering the modifying tools
WithExplainCheck bool // -with-explain-check: enables the EXPLAIN check in HandleExplain
DB *sqlx.DB // shared handle, set by the first successful GetDB call
)
// ExplainResult is one row of EXPLAIN output, mapped column-by-column via the
// `db` tags for sqlx.StructScan. Only SelectType is inspected (by
// HandleExplain); the other fields presumably exist so every EXPLAIN column
// has a scan destination. The pointer types presumably allow NULL columns —
// TODO confirm against the MySQL versions in use.
type ExplainResult struct {
Id *string `db:"id"`
SelectType *string `db:"select_type"` // compared with the expected statement type
Table *string `db:"table"`
Partitions *string `db:"partitions"`
Type *string `db:"type"`
PossibleKeys *string `db:"possible_keys"`
Key *string `db:"key"`
KeyLen *string `db:"key_len"`
Ref *string `db:"ref"`
Rows *string `db:"rows"`
Filtered *string `db:"filtered"`
Extra *string `db:"Extra"`
}
// ShowCreateTableResult is a row of SHOW CREATE TABLE output: the table name
// and its CREATE TABLE statement (columns "Table" and "Create Table").
type ShowCreateTableResult struct {
Table string `db:"Table"`
CreateTable string `db:"Create Table"`
}
func main() {
flag.StringVar(&Host, "host", "localhost", "MySQL hostname")
flag.StringVar(&User, "user", "root", "MySQL username")
flag.StringVar(&Pass, "pass", "", "MySQL password")
flag.IntVar(&Port, "port", 3306, "MySQL port")
flag.StringVar(&Db, "db", "", "MySQL database")
flag.StringVar(&DSN, "dsn", "", "MySQL DSN")
flag.BoolVar(&ReadOnly, "read-only", false, "Enable read-only mode")
flag.BoolVar(&WithExplainCheck, "with-explain-check", false, "Check query plan with `EXPLAIN` before executing")
flag.Parse()
if len(DSN) == 0 {
DSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local", User, Pass, Host, Port, Db)
}
s := server.NewMCPServer(
"go-mcp-mysql",
"1.0",
)
// 工具
listDatabaseTool := mcp.NewTool(
"list_database",
mcp.WithDescription("展示所有的数据库"),
)
listTableTool := mcp.NewTool(
"list_table",
mcp.WithDescription("展示所有的表"),
)
createTableTool := mcp.NewTool(
"create_table",
mcp.WithDescription("创建一个新表"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("查询刚才创建的表"),
),
)
alterTableTool := mcp.NewTool(
"alter_table",
mcp.WithDescription("修改表结构"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("查询刚才的修改"),
),
)
descTableTool := mcp.NewTool(
"desc_table",
mcp.WithDescription("查看建表语句"),
mcp.WithString("name",
mcp.Required(),
mcp.Description("The name of the table to describe"),
),
)
// Data Tools
readQueryTool := mcp.NewTool(
"read_query",
mcp.WithDescription("查询"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The SQL query to execute"),
),
)
writeQueryTool := mcp.NewTool(
"write_query",
mcp.WithDescription("新增"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The SQL query to execute"),
),
)
updateQueryTool := mcp.NewTool(
"update_query",
mcp.WithDescription("修改"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The SQL query to execute"),
),
)
deleteQueryTool := mcp.NewTool(
"delete_query",
mcp.WithDescription("删除"),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The SQL query to execute"),
),
)
s.AddTool(listDatabaseTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SHOW DATABASES", StatementTypeNoExplainCheck)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(listTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SHOW TABLES", StatementTypeNoExplainCheck)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
if !ReadOnly {
s.AddTool(createTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleExec(args["query"].(string), StatementTypeNoExplainCheck)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(alterTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleExec(args["query"].(string), StatementTypeNoExplainCheck)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
}
s.AddTool(descTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleDescTable(args["name"].(string))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(readQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleQuery(args["query"].(string), StatementTypeSelect)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
if !ReadOnly {
s.AddTool(writeQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleExec(args["query"].(string), StatementTypeInsert)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(updateQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleExec(args["query"].(string), StatementTypeUpdate)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(deleteQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := request.Params.Arguments.(map[string]interface{})
result, err := HandleExec(args["query"].(string), StatementTypeDelete)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(result), nil
})
}
if err := server.ServeStdio(s); err != nil {
log.Fatalf("Server error: %v", err)
}
}
// GetDB returns the shared database handle, connecting with DSN on first use.
//
// The handle is cached in the package-level DB variable, so later calls reuse
// it instead of reconnecting. A failed connection attempt is not cached; the
// next call tries again.
//
// NOTE(review): the lazy initialisation is not guarded by a lock. If the MCP
// server dispatches tool calls concurrently, two first calls could both
// connect. Confirm the server's dispatch model before relying on this.
func GetDB() (*sqlx.DB, error) {
	if DB != nil {
		return DB, nil
	}
	db, err := sqlx.Connect("mysql", DSN)
	if err != nil {
		// %w (not %v) keeps the driver error inspectable with errors.Is/As.
		return nil, fmt.Errorf("failed to establish database connection: %w", err)
	}
	DB = db
	return DB, nil
}
// HandleQuery runs a row-returning statement through DoQuery and renders the
// result set as CSV text, with the column names as the first record.
//
// expect is forwarded to DoQuery. Pass StatementTypeNoExplainCheck to skip
// the EXPLAIN-based check, or a StatementType* value to request it.
func HandleQuery(query, expect string) (string, error) {
	records, columnNames, queryErr := DoQuery(query, expect)
	if queryErr != nil {
		return "", queryErr
	}
	// MapToCSV already returns "" alongside any error, so its result can be
	// passed straight through.
	return MapToCSV(records, columnNames)
}
// DoQuery executes query and returns every row as a column-name → value map,
// together with the column names in result-set order.
//
// If expect is non-empty the statement is first passed to HandleExplain, which
// is a no-op unless --with-explain-check is set. []byte values from the driver
// are converted to string so they print as text; other values pass through.
//
// NOTE(review): nothing here restricts the statement to reads. Unless the
// EXPLAIN check is enabled, read_query runs whatever SQL it is given, even in
// --read-only mode.
// NOTE(review): rows are keyed by column name, so a result with duplicate
// column names (e.g. a.id, b.id) keeps only the last value for that name.
func DoQuery(query, expect string) ([]map[string]interface{}, []string, error) {
	db, err := GetDB()
	if err != nil {
		return nil, nil, err
	}
	if len(expect) > 0 {
		if err := HandleExplain(query, expect); err != nil {
			return nil, nil, err
		}
	}
	rows, err := db.Queryx(query)
	if err != nil {
		return nil, nil, err
	}
	// Close on every path. Without this, the connection is held whenever we
	// return early on a Columns or SliceScan error.
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, nil, err
	}
	result := []map[string]interface{}{}
	for rows.Next() {
		row, err := rows.SliceScan()
		if err != nil {
			return nil, nil, err
		}
		resultRow := make(map[string]interface{}, len(cols))
		for i, col := range cols {
			switch v := row[i].(type) {
			case []byte:
				resultRow[col] = string(v)
			default:
				resultRow[col] = v
			}
		}
		result = append(result, resultRow)
	}
	// Next returns false both at end-of-data and on failure. Err separates the
	// two, so an error part-way through is not reported as a short result.
	if err := rows.Err(); err != nil {
		return nil, nil, err
	}
	return result, cols, nil
}
// HandleExec runs a statement that does not return rows (DDL or DML) and
// reports how many rows it affected.
//
// If expect is non-empty the statement is first vetted by HandleExplain. For
// StatementTypeInsert the message also includes the last insert id.
func HandleExec(query, expect string) (string, error) {
	db, err := GetDB()
	if err != nil {
		return "", err
	}
	if expect != "" {
		if planErr := HandleExplain(query, expect); planErr != nil {
			return "", planErr
		}
	}

	res, err := db.Exec(query)
	if err != nil {
		return "", err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return "", err
	}

	// Only INSERT reports the generated id; everything else gets the
	// plain count.
	if expect != StatementTypeInsert {
		return fmt.Sprintf("%d rows affected", affected), nil
	}
	lastID, err := res.LastInsertId()
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%d rows affected, last insert id: %d", affected, lastID), nil
}
// HandleExplain vets query by running EXPLAIN on it before the real execution.
// It returns an error if the plan does not match the kind of statement the
// calling tool is meant to run.
//
// It does nothing (returns nil) unless --with-explain-check was given. The
// check is strict and fails closed:
//   - the plan must consist of exactly one row with a non-NULL select_type
//     (so statements whose plan has several rows — typically joins or
//     subqueries — are denied);
//   - for StatementTypeInsert/Update/Delete, select_type must equal expect;
//   - for any other expect (i.e. StatementTypeSelect), select_type must NOT be
//     INSERT, UPDATE or DELETE.
//
// NOTE(review): query is appended to "EXPLAIN " verbatim; it is the same
// untrusted SQL the caller is about to execute anyway.
func HandleExplain(query, expect string) error {
	if !WithExplainCheck {
		return nil
	}
	db, err := GetDB()
	if err != nil {
		return err
	}
	rows, err := db.Queryx(fmt.Sprintf("EXPLAIN %s", query))
	if err != nil {
		return err
	}
	// Release the connection even if StructScan fails part-way through.
	defer rows.Close()

	result := []ExplainResult{}
	for rows.Next() {
		var row ExplainResult
		if err := rows.StructScan(&row); err != nil {
			return err
		}
		result = append(result, row)
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// A NULL select_type cannot be classified; deny instead of dereferencing
	// a nil pointer.
	if len(result) != 1 || result[0].SelectType == nil {
		return fmt.Errorf("unable to check query plan, denied")
	}
	selectType := *result[0].SelectType

	var match bool
	switch expect {
	case StatementTypeInsert, StatementTypeUpdate, StatementTypeDelete:
		match = selectType == expect
	default:
		// For SELECT-type queries select_type can take several values, so only
		// require that it is not one of the write types.
		match = selectType != StatementTypeInsert &&
			selectType != StatementTypeUpdate &&
			selectType != StatementTypeDelete
	}
	if !match {
		return fmt.Errorf("query plan does not match expected pattern, denied")
	}
	return nil
}
// HandleDescTable returns the CREATE TABLE statement for the named table, as
// reported by SHOW CREATE TABLE.
//
// name may be a bare table name or "schema.table", optionally with each part
// already backtick-quoted. It is re-quoted as a MySQL identifier by
// quoteTableName, so the LLM-supplied value cannot inject SQL into the
// statement.
func HandleDescTable(name string) (string, error) {
	db, err := GetDB()
	if err != nil {
		return "", err
	}
	rows, err := db.Queryx(fmt.Sprintf("SHOW CREATE TABLE %s", quoteTableName(name)))
	if err != nil {
		return "", err
	}
	defer rows.Close()

	result := []ShowCreateTableResult{}
	for rows.Next() {
		var row ShowCreateTableResult
		if err := rows.StructScan(&row); err != nil {
			return "", err
		}
		result = append(result, row)
	}
	if err := rows.Err(); err != nil {
		return "", err
	}
	if len(result) == 0 {
		return "", fmt.Errorf("table %s does not exist", name)
	}
	return result[0].CreateTable, nil
}

// quoteTableName renders name as a backtick-quoted MySQL identifier.
//
// The first "." separates an optional schema from the table. For each part,
// surrounding whitespace is trimmed. If the part is already wrapped in
// backticks, the wrapper is removed and doubled backticks inside it are
// un-escaped, so pre-quoted input round-trips. Every remaining backtick is
// then doubled, which is how a backtick is escaped inside a quoted identifier.
func quoteTableName(name string) string {
	parts := strings.SplitN(name, ".", 2)
	for i, part := range parts {
		part = strings.TrimSpace(part)
		if len(part) >= 2 && part[0] == '`' && part[len(part)-1] == '`' {
			part = strings.ReplaceAll(part[1:len(part)-1], "``", "`")
		}
		parts[i] = "`" + strings.ReplaceAll(part, "`", "``") + "`"
	}
	return strings.Join(parts, ".")
}
// MapToCSV serialises a result set as CSV text.
//
// The first record is headers. Each element of m then becomes one record.
// Its fields are m[i][header] for every header, in header order, formatted
// with the %v verb, so a nil value prints as "<nil>". If any row lacks one of
// the header keys, the whole conversion fails. Quoting follows encoding/csv
// rules (e.g. fields containing commas are double-quoted).
func MapToCSV(m []map[string]interface{}, headers []string) (string, error) {
	var out strings.Builder
	w := csv.NewWriter(&out)

	if werr := w.Write(headers); werr != nil {
		return "", fmt.Errorf("failed to write headers: %v", werr)
	}

	for idx := 0; idx < len(m); idx++ {
		rec := make([]string, 0, len(headers))
		for _, key := range headers {
			val, found := m[idx][key]
			if !found {
				return "", fmt.Errorf("key '%s' not found in map", key)
			}
			rec = append(rec, fmt.Sprint(val))
		}
		if werr := w.Write(rec); werr != nil {
			return "", fmt.Errorf("failed to write row: %v", werr)
		}
	}

	// csv.Writer buffers internally; flush before reading the builder.
	w.Flush()
	if ferr := w.Error(); ferr != nil {
		return "", fmt.Errorf("error flushing CSV writer: %v", ferr)
	}
	return out.String(), nil
}
执行 `go mod tidy` 安装所需的依赖包
编译：`go build`
然后将编译得到的可执行文件添加到 MCP 客户端的自定义 MCP 服务配置中：
{
"mcpServers": {
"mysql": {
"command": "xxx.exe",
"args": [
"--dsn", "username:password@tcp(localhost:3306)/mydb?parseTime=true&loc=Local"
]
}
}
}
使用


因为上面的代码通过 mcp.WithDescription("展示所有的数据库") 等为工具设置了中文描述，所以我们可以直接用中文提问，例如"展示所有的数据库"，大模型会据此匹配到对应的工具。
Schema Tools
list_database
List all databases in the MySQL server.
Parameters: None
Returns: A list of all database names (the SHOW DATABASES result, as CSV text).
list_table
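List all tables in the current database of the MySQL server.
Parameters:
name: If provided, list tables whose names contain the given text, same as SQL SHOW TABLES LIKE '%name%'. Otherwise, list all tables. (Note: the simplified code above does not implement this parameter; its list_table always runs a plain SHOW TABLES.)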
Returns: A list of matching table names.
create_table
Create a new table in the MySQL server.
Parameters:
query: The SQL query to create the table.
Returns: x rows affected.
alter_table
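Alter an existing table in the MySQL server. The LLM is informed not to drop an existing table or column. (This describes the original project; in the simplified code above the tool description is just "修改表结构" and contains no such instruction.)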
Parameters:
query: The SQL query to alter the table.
Returns: x rows affected.
desc_table
Describe the structure of a table.
Parameters:
name: The name of the table to describe.
Returns: The structure of the table, as the CREATE TABLE statement reported by SHOW CREATE TABLE.
Data Tools
read_query
Execute a read-only SQL query.
Parameters:
query: The SQL query to execute.
Returns: The result of the query, as CSV text whose first row is the column names.
write_query
Execute a write SQL query.
Parameters:
query: The SQL query to execute.
Returns: x rows affected, last insert id: <last_insert_id>.
update_query
Execute an update SQL query.
Parameters:
query: The SQL query to execute.
Returns: x rows affected.
delete_query
Execute a delete SQL query.
Parameters:
query: The SQL query to execute.
Returns: x rows affected.

浙公网安备 33010602011771号