Go 调用 LangChain（使用 langchaingo 库）

OpenAI 版本（运行前需设置环境变量 OPENAI_API_KEY）

package main

import (
  "context"
  "fmt"
  "log"

  "github.com/tmc/langchaingo/llms"
  "github.com/tmc/langchaingo/llms/openai"
)

func main() {
	// openai.New picks up its credentials from the environment; this example
	// expects OPENAI_API_KEY to be set before it runs.
	model, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	const question = "What would be a good company name for a company that makes colorful socks?"

	// Send the single prompt and print the model's full answer on one line.
	answer, err := llms.GenerateFromSinglePrompt(context.Background(), model, question)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(answer)
}

Ollama 版本（使用 mistral 模型，流式输出对话结果）

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

// main sends a system instruction and a user question to the Ollama
// "mistral" model and prints the reply incrementally as chunks stream back.
func main() {
	// Name (version) of the Ollama model to use.
	llm, err := ollama.New(ollama.WithModel("mistral"))
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	content := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeSystem, "You are a company branding design wizard."),
		llms.TextParts(llms.ChatMessageTypeHuman, "What would be a good company name for a company that makes colorful socks?"),
	}

	// The streaming callback already prints every chunk as it arrives, so the
	// aggregated response returned by GenerateContent is not needed here.
	_, err = llm.GenerateContent(ctx, content, llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
		fmt.Print(string(chunk))
		return nil
	}))
	if err != nil {
		log.Fatal(err)
	}
	// fmt.Print adds no newline, so terminate the streamed output explicitly.
	fmt.Println()
}

流式返回（Ollama，llama2 模型，temperature = 0.8）

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

// main streams an answer from the Ollama "llama2" model to stdout, using a
// raw "Human:/Assistant:" prompt and a sampling temperature of 0.8.
func main() {
	llm, err := ollama.New(ollama.WithModel("llama2"))
	if err != nil {
		log.Fatal(err)
	}

	const prompt = "Human: Who was the first man to walk on the moon?\nAssistant:"
	opts := []llms.CallOption{
		llms.WithTemperature(0.8),
		// Print each chunk immediately instead of waiting for the full completion.
		llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error {
			fmt.Print(string(chunk))
			return nil
		}),
	}

	// The text was already printed chunk by chunk, so the returned
	// completion is not needed.
	if _, err := llms.GenerateFromSinglePrompt(context.Background(), llm, prompt, opts...); err != nil {
		log.Fatal(err)
	}
}

Agent 版本（OpenAI + 计算器 + SerpAPI 搜索工具）

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/tmc/langchaingo/agents"
	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/tools"
	"github.com/tmc/langchaingo/tools/serpapi"
)

// main runs the example; on failure it prints the error to stderr and exits
// with status 1.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}

// run builds a zero-shot ReAct agent that may use a calculator and SerpAPI
// web search, asks it a two-part question, and prints the final answer.
//
// NOTE(review): openai.New and serpapi.New read their credentials from their
// own configuration (presumably OPENAI_API_KEY / SERPAPI_API_KEY environment
// variables) — confirm against the langchaingo documentation.
func run() error {
	llm, err := openai.New()
	if err != nil {
		return fmt.Errorf("creating OpenAI client: %w", err)
	}
	search, err := serpapi.New()
	if err != nil {
		return fmt.Errorf("creating SerpAPI tool: %w", err)
	}
	agentTools := []tools.Tool{
		tools.Calculator{},
		search,
	}
	// Limit the agent to at most 3 iterations of its reasoning/tool loop.
	executor, err := agents.Initialize(
		llm,
		agentTools,
		agents.ZeroShotReactDescription,
		agents.WithMaxIterations(3),
	)
	if err != nil {
		return fmt.Errorf("initializing agent: %w", err)
	}
	question := "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?"
	answer, err := chains.Run(context.Background(), executor, question)
	if err != nil {
		// Report the failure only; there is no meaningful answer to print.
		return fmt.Errorf("running agent: %w", err)
	}
	fmt.Println(answer)
	return nil
}

SQLite 版本

package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/tools/sqldatabase"
	_ "github.com/tmc/langchaingo/tools/sqldatabase/sqlite3"
)

// main runs the example; on failure it prints the error to stderr and exits
// with status 1.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}

// makeSample creates the sample SQLite database at dsn with two tables:
// foo, holding 100 rows ("Foo 000".."Foo 099"), and foo1, holding 200 rows
// ("Foo1 000".."Foo1 199"). Any failure terminates the program via log.Fatal.
//
// The schema uses plain CREATE TABLE (no IF NOT EXISTS), so the tables must
// not already exist in the database at dsn.
func makeSample(dsn string) {
	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	const schema = `
	create table foo (id integer not null primary key, name text);
	delete from foo;
	create table foo1 (id integer not null primary key, name text);
	delete from foo1;
	`
	if _, err := db.Exec(schema); err != nil {
		log.Fatal(err)
	}

	// All inserts run inside one transaction and are committed together.
	tx, err := db.Begin()
	if err != nil {
		log.Fatal(err)
	}

	seeds := []struct {
		query  string
		rows   int
		format string
	}{
		{"insert into foo(id, name) values(?, ?)", 100, "Foo %03d"},
		{"insert into foo1(id, name) values(?, ?)", 200, "Foo1 %03d"},
	}
	for _, seed := range seeds {
		stmt, err := tx.Prepare(seed.query)
		if err != nil {
			log.Fatal(err)
		}
		// Deferred deliberately: the statements are closed when makeSample
		// returns (after the commit), in reverse order of preparation.
		defer stmt.Close()
		for id := 0; id < seed.rows; id++ {
			if _, err := stmt.Exec(id, fmt.Sprintf(seed.format, id)); err != nil {
				log.Fatal(err)
			}
		}
	}

	if err := tx.Commit(); err != nil {
		log.Fatal(err)
	}
}

// run recreates the sample SQLite database ./foo.db, wraps it in an
// OpenAI-backed SQL-database chain, asks three questions about the data and
// prints each answer. The database file is removed again when run returns.
func run() error {
	llm, err := openai.New()
	if err != nil {
		return err
	}

	const dsn = "./foo.db"
	// Best-effort cleanup: the file may not exist, so os.Remove's error is
	// deliberately ignored both before seeding and on exit.
	os.Remove(dsn)
	defer os.Remove(dsn)

	makeSample(dsn)

	db, err := sqldatabase.NewSQLDatabaseWithDSN("sqlite3", dsn, nil)
	if err != nil {
		return err
	}
	defer db.Close()

	// NOTE(review): 100 is the chain's topK argument — presumably the row
	// limit for generated queries; confirm against the langchaingo docs.
	chain := chains.NewSQLDatabaseChain(llm, 100, db)
	ctx := context.Background()

	// 1. Free-form question naming the table in the text.
	answer, err := chains.Run(ctx, chain, "Return all rows from the foo table where the ID is less than 23.")
	if err != nil {
		return err
	}
	fmt.Println(answer)

	// 2. Same kind of question, with the table to use passed explicitly in
	//    the input map instead of in the question text.
	answer, err = chains.Predict(ctx, chain, map[string]any{
		"query":              "Return all rows that the ID is less than 23.",
		"table_names_to_use": []string{"foo"},
	})
	if err != nil {
		return err
	}
	fmt.Println(answer)

	// 3. A question that needs both tables.
	answer, err = chains.Run(ctx, chain, "Which table has more data, foo or foo1?")
	if err != nil {
		return err
	}
	fmt.Println(answer)
	return nil
}

PostgreSQL 版本（连接串从环境变量 LANGCHAINGO_POSTGRESQL 读取）

package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/tools/sqldatabase"
	_ "github.com/tmc/langchaingo/tools/sqldatabase/postgresql"
)

// main runs the example; on failure it prints the error to stderr and exits
// with status 1.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}

// makeSample prepares the sample PostgreSQL tables reachable through dsn:
// foo with 100 rows ("Foo 000".."Foo 099") and foo1 with 200 rows
// ("Foo1 000".."Foo1 199"). Tables are created if missing and emptied first,
// so the function can be re-run against the same database. Any failure
// terminates the program via log.Fatal.
func makeSample(dsn string) {
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	const schema = `
	CREATE TABLE IF NOT EXISTS foo (id integer not null primary key, name text);
	delete from foo;
	CREATE TABLE IF NOT EXISTS foo1 (id integer not null primary key, name text);
	delete from foo1;
	`
	if _, err := db.Exec(schema); err != nil {
		log.Fatal(err)
	}

	// All inserts run inside one transaction and are committed together.
	tx, err := db.Begin()
	if err != nil {
		log.Fatal(err)
	}

	seeds := []struct {
		query  string
		rows   int
		format string
	}{
		{"insert into foo(id, name) values($1, $2)", 100, "Foo %03d"},
		{"insert into foo1(id, name) values($1, $2)", 200, "Foo1 %03d"},
	}
	for _, seed := range seeds {
		stmt, err := tx.Prepare(seed.query)
		if err != nil {
			log.Fatal(err)
		}
		// Deferred deliberately: the statements are closed when makeSample
		// returns (after the commit), in reverse order of preparation.
		defer stmt.Close()
		for id := 0; id < seed.rows; id++ {
			if _, err := stmt.Exec(id, fmt.Sprintf(seed.format, id)); err != nil {
				log.Fatal(err)
			}
		}
	}

	if err := tx.Commit(); err != nil {
		log.Fatal(err)
	}
}

// run seeds the sample tables in the PostgreSQL database named by the
// LANGCHAINGO_POSTGRESQL environment variable, wraps the database in an
// OpenAI-backed SQL-database chain, asks three questions about the data and
// prints each answer.
//
// NOTE(review): if LANGCHAINGO_POSTGRESQL is unset the DSN is empty; the pgx
// driver then presumably falls back to its defaults (PG* environment
// variables / localhost) — confirm this is the intended behavior.
func run() error {
	llm, err := openai.New()
	if err != nil {
		return err
	}

	dsn := os.Getenv("LANGCHAINGO_POSTGRESQL")

	makeSample(dsn)

	db, err := sqldatabase.NewSQLDatabaseWithDSN("pgx", dsn, nil)
	if err != nil {
		return err
	}
	defer db.Close()

	// NOTE(review): 100 is the chain's topK argument — presumably the row
	// limit for generated queries; confirm against the langchaingo docs.
	sqlDatabaseChain := chains.NewSQLDatabaseChain(llm, 100, db)
	ctx := context.Background()
	out, err := chains.Run(ctx, sqlDatabaseChain, "Return all rows from the foo table where the ID is less than 23.")
	if err != nil {
		return err
	}
	fmt.Println(out)

	// Same kind of question, with the table to use passed explicitly.
	input := map[string]any{
		"query":              "Return all rows that the ID is less than 23.",
		"table_names_to_use": []string{"foo"},
	}
	out, err = chains.Predict(ctx, sqlDatabaseChain, input)
	if err != nil {
		return err
	}
	fmt.Println(out)

	// A question that needs both tables (previously ended in a stray "$").
	out, err = chains.Run(ctx, sqlDatabaseChain, "Which table has more data, foo or foo1?")
	if err != nil {
		return err
	}
	fmt.Println(out)
	return nil
}

文档问答版本（Stuff / Refine 两种 QA 链）

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/schema"
)

// main runs the example; on failure it prints the error to stderr and exits
// with status 1.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}

// run demonstrates two document question-answering chains over the same two
// in-memory documents and prints each chain's output map. It returns the
// first error from client creation or from either chain call.
func run() error {
	llm, err := openai.New()
	if err != nil {
		return err
	}
	ctx := context.Background()

	// We can use LoadStuffQA to create a chain that takes input documents and a question,
	// stuffs all the documents into the prompt of the llm and returns an answer to the
	// question. It is suitable for a small number of documents.
	stuffQAChain := chains.LoadStuffQA(llm)
	docs := []schema.Document{
		{PageContent: "Harrison went to Harvard."},
		{PageContent: "Ankush went to Princeton."},
	}

	answer, err := chains.Call(ctx, stuffQAChain, map[string]any{
		"input_documents": docs,
		"question":        "Where did Harrison go to college?",
	})
	if err != nil {
		return err
	}
	fmt.Println(answer)

	// Another option is to use the refine documents chain for question answering. This
	// chain iterates over the input documents one by one, updating an intermediate answer
	// with each iteration. It uses the previous version of the answer and the next document
	// as context. The downside of this type of chain is that it uses multiple llm calls that
	// can't be done in parallel.
	refineQAChain := chains.LoadRefineQA(llm)
	answer, err = chains.Call(ctx, refineQAChain, map[string]any{
		"input_documents": docs,
		"question":        "Where did Ankush go to college?",
	})
	if err != nil {
		// Propagate the failure instead of printing an empty result and
		// reporting success.
		return err
	}
	fmt.Println(answer)

	return nil
}
posted @ 2024-04-27 15:00  朝阳1  阅读(493)  评论(0)    收藏  举报