generic-rag/backend/AI/ai.go
// Package AI uses the OpenAI API to provide AI services.
package AI

import (
	"context"
	"fmt"
	"os"

	"github.com/pkoukk/tiktoken-go"
	"github.com/sashabaranov/go-openai"
)
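// AI is the interface the rest of the backend uses for embeddings and token counting.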
type AI interface {
	// GetEmbeddings returns the embedding response for the given text.
	GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error)
	GetTokenCount(input string) (int, error)
}
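// Service implements AI using the OpenAI API via sashabaranov/go-openai.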
type Service struct {
	apiKey       string
	baseURL      string
	encodingName string
	model        string
	client       *openai.Client
}
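// AIOption configures a Service during construction.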
type AIOption func(*Service)
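// NewAI applies the given options, falling back to the OPENAI_API_KEY and
// OPENAI_BASE_URL environment variables, and returns an error if no API key is available.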
func NewAI(opts ...AIOption) (*Service, error) {
	a := Service{
		//baseURL: "https://api.openai.com",
		encodingName: "gpt-4o",
		model:        openai.GPT4oMini,
	}
	for _, opt := range opts {
		opt(&a)
	}
	if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" {
		a.apiKey = os.Getenv("OPENAI_API_KEY")
	}
	if a.apiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}
	config := openai.DefaultConfig(a.apiKey)
	if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" {
		a.baseURL = os.Getenv("OPENAI_BASE_URL")
	}
	if a.baseURL != "" {
		config.BaseURL = a.baseURL
	}
	a.client = openai.NewClientWithConfig(config)
	return &a, nil
}
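// GetEmbeddings requests an embedding for text from the text-embedding-3-small model.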
func (a *Service) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) {
	embeddingRequest := openai.EmbeddingRequest{
		Input: text,
		Model: "text-embedding-3-small",
	}
	embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest)
	if err != nil {
		return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err)
	}
	return embeddings, nil
}
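// WithAPIKey sets the OpenAI API key explicitly instead of reading it from the environment.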
func WithAPIKey(apiKey string) AIOption {
	return func(a *Service) {
		a.apiKey = apiKey
	}
}
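// WithBaseURL points the client at a non-default OpenAI-compatible endpoint.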
func WithBaseURL(baseURL string) AIOption {
	return func(a *Service) {
		a.baseURL = baseURL
	}
}
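// WithEncodingName sets the model name used to pick a tiktoken encoding in GetTokenCount.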
func WithEncodingName(encodingName string) AIOption {
	return func(a *Service) {
		a.encodingName = encodingName
	}
}
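// GetTokenCount returns the number of tokens the configured encoding produces for input.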
func (a *Service) GetTokenCount(input string) (int, error) {
	tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR"
	if err != nil {
		return 0, fmt.Errorf("error getting encoding: %w", err)
	}
	tokens := tke.Encode(input, nil, nil)
	return len(tokens), nil
}
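// PreReason asks the model to break down and summarize text up front so the reasoning can be stored and reused for later queries.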
// WIP: I'm still getting responses with limited utility. Probably don't use this function until it is figured out.
func (a *Service) PreReason(ctx context.Context, text string) (string, error) {
	systemPrompt := "When you receive user content, your job is to systematically break down the content you receive. We will store your reasoning and any output you generate to help with future queries. This way we can more quickly respond to the user having already thought about the subject matter. You should consider as many angles as possible and be as thorough as possible in your reasoning, potentially add summaries of parts that seem important to you or interesting. You should also consider the potential for future queries and how you might be able to help with those."
	preReasoningRequest := openai.ChatCompletionRequest{
		Model: openai.O4Mini,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleDeveloper,
				Content: systemPrompt,
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: text,
			},
		},
	}
	// Use the configured client so the base URL and API key from NewAI are respected.
	resp, err := a.client.CreateChatCompletion(ctx, preReasoningRequest)
	if err != nil {
		return "", fmt.Errorf("error creating chat completion: %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("no choices in response")
	}
	respText := resp.Choices[0].Message.Content
	if respText == "" {
		return "", fmt.Errorf("no response text")
	}
	fmt.Printf("%+v\n", resp.Choices[0])
	return respText, nil
}
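For reference, a minimal usage sketch of this package as it stands; the import path is an assumption derived from the file location and should be adjusted to the real module path:

package main

import (
	"context"
	"fmt"
	"log"

	ai "generic-rag/backend/AI" // assumed import path, taken from the file location
)

func main() {
	// NewAI falls back to OPENAI_API_KEY (and optionally OPENAI_BASE_URL) from the environment.
	svc, err := ai.NewAI()
	if err != nil {
		log.Fatal(err)
	}

	// Count tokens first, e.g. to stay under the embedding model's input limit.
	count, err := svc.GetTokenCount("Hello, world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("tokens:", count)

	resp, err := svc.GetEmbeddings(context.Background(), "Hello, world")
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Data) > 0 {
		fmt.Println("embedding dimensions:", len(resp.Data[0].Embedding))
	}
}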