make the chat work
LLMMapper/llmMapper.go (new file, 134 lines)
@@ -0,0 +1,134 @@
package LLMMapper

import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/pkoukk/tiktoken-go"
	"github.com/sashabaranov/go-openai"
	"os"
	"strings"
	"time"
)

const encodingName = "gpt-4"
const model = openai.GPT4TurboPreview
const maxTokens = 4096

// const maxTokens = 128000
const temperature = 0.3

func GetTokenCount(input string) (int, error) {
	tke, err := tiktoken.EncodingForModel(encodingName) // cached in "TIKTOKEN_CACHE_DIR"
	if err != nil {
		return 0, fmt.Errorf("error getting encoding: %w", err)
	}

	token := tke.Encode(input, nil, nil)
	return len(token), nil
}
// SinglePromptInteraction calls the OpenAI chat endpoint with just a system prompt and a user prompt and returns the response.
func SinglePromptInteraction(systemPrompt, prompt string) (openai.ChatCompletionResponse, error) {
	return singlePromptInteraction(systemPrompt, prompt, 5)
}

// singlePromptInteraction calls the OpenAI chat endpoint with just a system prompt and a user prompt and returns the response.
// On rate-limiting (429) errors it retries up to 5 times, backing off longer on each attempt.
func singlePromptInteraction(systemPrompt, prompt string, retries int) (openai.ChatCompletionResponse, error) {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	messages := []openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleSystem,
			Content: systemPrompt,
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: prompt,
		},
	}

	previousTokenCount, err := GetPreviousTokenUsage(messages)
	if err != nil {
		return openai.ChatCompletionResponse{}, fmt.Errorf("error getting previous token usage: %w", err)
	}

	// leave whatever the prompt has not used as the completion budget
	messageTokenSize := maxTokens - previousTokenCount

	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       model, // switch to the configured Model
			Messages:    messages,
			MaxTokens:   messageTokenSize, // might want to think about how to reduce this
			Temperature: temperature,
		},
	)
	if err != nil {
		// if 429 (rate limited), wait and try again
		if strings.Contains(err.Error(), "429") && retries > 0 {
			seconds := 60 / retries // back off longer on each retry: 12, 15, 20, 30, 60
			fmt.Printf("429 error, waiting %v seconds...\n", seconds)
			time.Sleep(time.Duration(seconds) * time.Second)
			return singlePromptInteraction(systemPrompt, prompt, retries-1) // retries-1 plus the retries > 0 check stops the recursion
		}
		return openai.ChatCompletionResponse{}, fmt.Errorf("ChatCompletion request error: %w", err)
	}

	return resp, nil
}
// GetPreviousTokenUsage returns the token count of the conversation so far, measured on the JSON-marshalled messages.
func GetPreviousTokenUsage(messages []openai.ChatCompletionMessage) (int, error) {
	messagesRaw, err := json.Marshal(messages)
	if err != nil {
		return 0, fmt.Errorf("error marshalling messages: %w", err)
	}

	length, err := GetTokenCount(string(messagesRaw))
	if err != nil {
		return 0, fmt.Errorf("error getting token count: %w", err)
	}
	//fmt.Printf("Token Count: %v\n", length)
	return length, nil
}
// SendPrompt calls the OpenAI chat endpoint with a list of messages and returns the response.
func SendPrompt(messages []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, error) {
	return sendPrompt(messages, 5)
}

// sendPrompt calls the OpenAI chat endpoint with a list of messages and returns the response.
// On rate-limiting (429) errors it retries up to 5 times, backing off longer on each attempt.
func sendPrompt(messages []openai.ChatCompletionMessage, retries int) (openai.ChatCompletionResponse, error) {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	previousTokenCount, err := GetPreviousTokenUsage(messages)
	if err != nil {
		return openai.ChatCompletionResponse{}, fmt.Errorf("error getting previous token usage: %w", err)
	}

	// leave whatever the conversation has not used as the completion budget
	messageTokenSize := maxTokens - previousTokenCount
	fmt.Println("messageTokenSize: ", messageTokenSize)

	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       model, // switch to the configured Model
			Messages:    messages,
			MaxTokens:   messageTokenSize,
			Temperature: temperature,
		},
	)
	if err != nil {
		// if 429 (rate limited), wait and try again
		if strings.Contains(err.Error(), "429") && retries > 0 {
			seconds := 60 / retries // back off longer on each retry: 12, 15, 20, 30, 60
			fmt.Printf("429 error, waiting %v seconds...\n", seconds)
			time.Sleep(time.Duration(seconds) * time.Second)
			return sendPrompt(messages, retries-1) // retries-1 plus the retries > 0 check stops the recursion
		}
		return openai.ChatCompletionResponse{}, fmt.Errorf("ChatCompletion request error: %w", err)
	}

	return resp, nil
}
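A minimal caller sketch, not part of this commit: it assumes OPENAI_API_KEY is set in the environment and uses a hypothetical module path ("yourmodule") and illustrative prompt strings.

package main

import (
	"fmt"
	"log"

	"yourmodule/LLMMapper" // hypothetical import path for the package above
)

func main() {
	// one-shot question: system prompt + user prompt, no history
	resp, err := LLMMapper.SinglePromptInteraction(
		"You are a helpful assistant.",
		"Say hello in one short sentence.",
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}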
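For multi-turn chat, the caller owns the message history and passes the whole slice to SendPrompt on every turn. A sketch under the same assumptions (hypothetical module path and prompts):

package main

import (
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
	"yourmodule/LLMMapper" // hypothetical import path
)

func main() {
	// the caller keeps the conversation history
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: "You are a helpful assistant."},
		{Role: openai.ChatMessageRoleUser, Content: "What is a goroutine?"},
	}

	resp, err := LLMMapper.SendPrompt(messages)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)

	// append the assistant reply and the next user turn, then send the full history again
	messages = append(messages, resp.Choices[0].Message)
	messages = append(messages, openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: "Show a one-line example.",
	})

	resp, err = LLMMapper.SendPrompt(messages)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}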