100 lines
2.1 KiB
Go
100 lines
2.1 KiB
Go
package AI
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/pkoukk/tiktoken-go"
|
|
"github.com/sashabaranov/go-openai"
|
|
"os"
|
|
)
|
|
|
|
// This package provides AI services: text embeddings via the OpenAI API and local token counting via tiktoken.
|
|
|
|
type AI interface {
|
|
// Get Embedding
|
|
GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error)
|
|
GetTokenCount(input string) (int, error)
|
|
}
|
|
|
|
type ai struct {
|
|
apiKey string
|
|
baseURL string
|
|
encodingName string
|
|
model string
|
|
client *openai.Client
|
|
}
|
|
|
|
// AIOption mutates an ai under construction; NewAI applies each option, in
// order, before resolving environment-variable fallbacks.
type AIOption func(*ai)
|
|
|
|
func NewAI(otps ...AIOption) (AI, error) {
|
|
a := ai{
|
|
//baseURL: "https://api.openai.com",
|
|
encodingName: "gpt-4o",
|
|
model: openai.GPT4oMini,
|
|
}
|
|
|
|
for _, opt := range otps {
|
|
opt(&a)
|
|
}
|
|
|
|
if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" {
|
|
a.apiKey = os.Getenv("OPENAI_API_KEY")
|
|
}
|
|
if a.apiKey == "" {
|
|
return nil, fmt.Errorf("api key is required")
|
|
}
|
|
|
|
config := openai.DefaultConfig(a.apiKey)
|
|
if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" {
|
|
a.baseURL = os.Getenv("OPENAI_BASE_URL")
|
|
}
|
|
|
|
if a.baseURL != "" {
|
|
config.BaseURL = a.baseURL
|
|
}
|
|
|
|
a.client = openai.NewClientWithConfig(config)
|
|
|
|
return a, nil
|
|
}
|
|
|
|
func (a ai) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) {
|
|
embeddingRequest := openai.EmbeddingRequest{
|
|
Input: text,
|
|
Model: "text-embedding-3-small",
|
|
}
|
|
|
|
embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest)
|
|
if err != nil {
|
|
return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err)
|
|
}
|
|
return embeddings, nil
|
|
}
|
|
|
|
func WithAPIKey(apiKey string) AIOption {
|
|
return func(a *ai) {
|
|
a.apiKey = apiKey
|
|
}
|
|
}
|
|
|
|
func WithBaseURL(baseURL string) AIOption {
|
|
return func(a *ai) {
|
|
a.baseURL = baseURL
|
|
}
|
|
}
|
|
|
|
func WithEncodingName(encodingName string) AIOption {
|
|
return func(a *ai) {
|
|
a.encodingName = encodingName
|
|
}
|
|
}
|
|
|
|
func (a ai) GetTokenCount(input string) (int, error) {
|
|
tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR"
|
|
if err != nil {
|
|
return 0, fmt.Errorf("error getting encoding: %w", err)
|
|
}
|
|
token := tke.Encode(input, nil, nil)
|
|
return len(token), nil
|
|
}
|