package AI import ( "context" "fmt" "github.com/pkoukk/tiktoken-go" "github.com/sashabaranov/go-openai" "os" ) // This package should use the OpenAI API to provide AI services. type AI interface { // Get Embedding GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) GetTokenCount(input string) (int, error) } type ai struct { apiKey string baseURL string encodingName string model string client *openai.Client } type AIOption func(*ai) func NewAI(otps ...AIOption) (AI, error) { a := ai{ //baseURL: "https://api.openai.com", encodingName: "gpt-4o", model: openai.GPT4oMini, } for _, opt := range otps { opt(&a) } if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" { a.apiKey = os.Getenv("OPENAI_API_KEY") } if a.apiKey == "" { return nil, fmt.Errorf("api key is required") } config := openai.DefaultConfig(a.apiKey) if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" { a.baseURL = os.Getenv("OPENAI_BASE_URL") } if a.baseURL != "" { config.BaseURL = a.baseURL } a.client = openai.NewClientWithConfig(config) return a, nil } func (a ai) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) { embeddingRequest := openai.EmbeddingRequest{ Input: text, Model: "text-embedding-3-small", } embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest) if err != nil { return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err) } return embeddings, nil } func WithAPIKey(apiKey string) AIOption { return func(a *ai) { a.apiKey = apiKey } } func WithBaseURL(baseURL string) AIOption { return func(a *ai) { a.baseURL = baseURL } } func WithEncodingName(encodingName string) AIOption { return func(a *ai) { a.encodingName = encodingName } } func (a ai) GetTokenCount(input string) (int, error) { tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR" if err != nil { return 0, fmt.Errorf("error getting encoding: %w", err) } token := tke.Encode(input, nil, nil) return len(token), nil }