package LLMMapper

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/pkoukk/tiktoken-go"
	"github.com/sashabaranov/go-openai"
)

const encodingName = "gpt-4"
const model = openai.GPT4TurboPreview
const maxTokens = 4096

// const maxTokens = 128000
const temperature = 0.3
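
// GetTokenCount returns the number of tokens the input string encodes to
// under the gpt-4 tokenizer.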
func GetTokenCount(input string) (int, error) {
	tke, err := tiktoken.EncodingForModel(encodingName) // cached in "TIKTOKEN_CACHE_DIR"
	if err != nil {
		return 0, fmt.Errorf("error getting encoding: %w", err)
	}

	tokens := tke.Encode(input, nil, nil)
	return len(tokens), nil
}

// SinglePromptInteraction calls the OpenAI chat endpoint with just a system
// prompt and a user prompt and returns the response. It retries up to 5 times
// on rate-limit errors.
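//
// A minimal usage sketch (assumes OPENAI_API_KEY is set in the environment;
// the prompt strings are illustrative, not part of this package):
//
//	resp, err := SinglePromptInteraction("You are a terse assistant.", "Summarize Go's memory model in one sentence.")
//	if err != nil {
//		// handle the error
//	}
//	fmt.Println(resp.Choices[0].Message.Content)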
func SinglePromptInteraction(systemPrompt, prompt string) (openai.ChatCompletionResponse, error) {
	return singlePromptInteraction(systemPrompt, prompt, 5)
}

// singlePromptInteraction calls the OpenAI chat endpoint with just a system
// prompt and a user prompt and returns the response. On rate-limit (429)
// errors it retries up to `retries` more times with an increasing backoff
// (12, 15, 20, 30, then 60 seconds).
func singlePromptInteraction(systemPrompt, prompt string, retries int) (openai.ChatCompletionResponse, error) {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	messages := []openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleSystem,
			Content: systemPrompt,
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: prompt,
		},
	}

	previousTokenCount, err := GetPreviousTokenUsage(messages)
	if err != nil {
		return openai.ChatCompletionResponse{}, fmt.Errorf("error getting previous token usage: %w", err)
	}

	// Give the completion whatever room the prompt leaves under maxTokens.
	// If the prompt alone exceeds maxTokens this goes negative and the
	// request will fail.
	messageTokenSize := maxTokens - previousTokenCount

	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       model, // switch to the configured Model
			Messages:    messages,
			MaxTokens:   messageTokenSize, // might want to think about how to reduce this
			Temperature: temperature,
		},
	)
	if err != nil {
		// If rate limited (429), wait and try again.
		if strings.Contains(err.Error(), "429") && retries > 0 {
			seconds := (1 / float64(retries)) * 60 // back off longer on each retry: 12, 15, 20, 30, 60
			fmt.Printf("429 error, waiting %v seconds...\n", seconds)
			time.Sleep(time.Duration(seconds) * time.Second)
			return singlePromptInteraction(systemPrompt, prompt, retries-1) // recursion stops once retries reaches 0
		}
		return openai.ChatCompletionResponse{}, fmt.Errorf("ChatCompletion request error: %w", err)
	}

	return resp, nil
}
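
// GetPreviousTokenUsage estimates how many tokens a message history consumes
// by tokenizing its JSON encoding. This is an approximation: the JSON
// envelope (field names, quotes, braces) is counted too, so the result
// differs somewhat from what the API meters for the same messages.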
func GetPreviousTokenUsage(messages []openai.ChatCompletionMessage) (int, error) {
	messagesRaw, err := json.Marshal(messages)
	if err != nil {
		return 0, fmt.Errorf("error marshalling messages: %w", err)
	}

	length, err := GetTokenCount(string(messagesRaw))
	if err != nil {
		return 0, fmt.Errorf("error getting token count: %w", err)
	}

	//fmt.Printf("Token Count: %v\n", length)
	return length, nil
}

// SendPrompt calls the OpenAI chat endpoint with a list of messages and
// returns the response. It retries up to 5 times on rate-limit errors.
func SendPrompt(messages []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, error) {
	return sendPrompt(messages, 5)
}
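
// sendPrompt does the work for SendPrompt: it budgets the completion against
// maxTokens and, on rate-limit (429) errors, retries up to `retries` more
// times with the same increasing backoff as singlePromptInteraction.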
func sendPrompt(messages []openai.ChatCompletionMessage, retries int) (openai.ChatCompletionResponse, error) {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	previousTokenCount, err := GetPreviousTokenUsage(messages)
	if err != nil {
		return openai.ChatCompletionResponse{}, fmt.Errorf("error getting previous token usage: %w", err)
	}

	messageTokenSize := maxTokens - previousTokenCount
	fmt.Println("messageTokenSize: ", messageTokenSize)

	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       model, // switch to the configured Model
			Messages:    messages,
			MaxTokens:   messageTokenSize,
			Temperature: temperature,
		},
	)
	if err != nil {
		// If rate limited (429), wait and try again.
		if strings.Contains(err.Error(), "429") && retries > 0 {
			seconds := (1 / float64(retries)) * 60 // back off longer on each retry: 12, 15, 20, 30, 60
			fmt.Printf("429 error, waiting %v seconds...\n", seconds)
			time.Sleep(time.Duration(seconds) * time.Second)
			return sendPrompt(messages, retries-1) // recursion stops once retries reaches 0
		}
		return openai.ChatCompletionResponse{}, fmt.Errorf("ChatCompletion request error: %w", err)
	}

	return resp, nil
}
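
// A multi-turn sketch with SendPrompt (message content is illustrative, not
// part of this package): append the assistant's reply before the next call so
// the model sees the whole conversation.
//
//	messages := []openai.ChatCompletionMessage{
//		{Role: openai.ChatMessageRoleSystem, Content: "You are a helpful assistant."},
//		{Role: openai.ChatMessageRoleUser, Content: "What is a goroutine?"},
//	}
//	resp, err := SendPrompt(messages)
//	if err != nil {
//		// handle the error
//	}
//	messages = append(messages, resp.Choices[0].Message)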