ctxGPT/cmd/chat/chat.go

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/sashabaranov/go-openai"

	"ctxGPT/LLMMapper"
	"ctxGPT/promptBuilder"
)
func main() {
	reader := bufio.NewReader(os.Stdin)
	fmt.Println("Go CLI Chat App")
	fmt.Println("---------------------")

	systemPrompt := ""
	messages := []openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleSystem,
			Content: systemPrompt,
		},
	}

	for {
		fmt.Print("You: ")
		text, err := reader.ReadString('\n')
		if err != nil {
			// Exit cleanly on EOF (e.g. Ctrl+D) or any other read error instead of looping forever.
			fmt.Println("\nExiting chat...")
			break
		}
		text = strings.TrimSpace(text) // Remove leading and trailing whitespace

		// Check if the user wants to exit.
		if text == "exit" {
			fmt.Println("Exiting chat...")
			break
		}

		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleUser,
			Content: text,
		})

		resp, err := LLMMapper.SendPrompt(messages)
		if err != nil {
			fmt.Printf("error sending prompt | %v\n", err)
			continue
		}
		if len(resp.Choices) == 0 {
			fmt.Println("No choices returned")
			continue
		}

		fmt.Println("AI: ", resp.Choices[0].Message.Content)
		//fmt.Println("Finish Reason: ", resp.Choices[0].FinishReason)
		messages = append(messages, resp.Choices[0].Message)

		currLength := estimateTokenCount(messages)
		if currLength > 3000 {
			fmt.Println("Token count exceeded 3000, summarizing context")
			summarized, err := summarizeChatSoFar(messages)
			if err != nil {
				fmt.Printf("error summarizing chat so far | %v\n", err)
				continue
			}
			fmt.Printf("Summarized: %v\n", summarized)

			// Reset messages to the system prompt and the summarized conversation.
			messages = []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleSystem,
					Content: systemPrompt,
				},
				{
					Role:    openai.ChatMessageRoleAssistant,
					Content: summarized,
				},
			}
		}
	}
}

// estimateTokenCount returns an estimate of the token count for the JSON-encoded
// conversation, or 0 if the estimate fails.
func estimateTokenCount(messages []openai.ChatCompletionMessage) int {
	messagesRaw, err := json.Marshal(messages)
	if err != nil {
		fmt.Printf("error marshalling messages for token size estimation | %v\n", err)
		return 0
	}

	length, err := LLMMapper.GetTokenCount(string(messagesRaw))
	if err != nil {
		fmt.Printf("error getting token count | %v\n", err)
		return 0
	}

	fmt.Printf("Token Count: %v\n", length)
	return length
}

// summarizeChatSoFar asks the model to compress the conversation so far into a
// short summary that can replace the full message history in the prompt.
func summarizeChatSoFar(messages []openai.ChatCompletionMessage) (string, error) {
	messagesRaw, err := json.Marshal(messages)
	if err != nil {
		return "", fmt.Errorf("error marshalling messages for summarization | %w", err)
	}

	summarizeConvoPrompt, err := promptBuilder.BuildPrompt("summarize.tmpl", struct{ WordLimit int }{WordLimit: 100})
	if err != nil {
		return "", fmt.Errorf("error building prompt for summarization | %w", err)
	}

	resp, err := LLMMapper.SendPrompt([]openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleSystem,
			Content: summarizeConvoPrompt,
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: string(messagesRaw),
		},
	})
	if err != nil {
		return "", fmt.Errorf("error summarizing conversation | %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("no choices returned for summarization")
	}

	return resp.Choices[0].Message.Content, nil
}

// TODO: anything stored in the database should be chunked into pieces of between 512 and 1024 tokens,
// with each chunk overlapping the previous one by 100-200 tokens.
// When the LLM asks for more context, it should be able to use the database to find the most relevant chunks. Here is how:
//   - we will get the embeddings for each prompt and use those embeddings to search for the closest 6 chunks
//   - we will use a separate LLM prompt to attempt to select and sort the chunks based on the user's input
//   - then we will add the best-matched chunks to the main prompt as further context for the given prompt
// A rough sketch of the chunking step follows below.
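
// The chunking step from the TODO above might look something like the sketch below.
// This is only an illustration of the described approach, not code the project uses:
// chunkTokens and its parameters are assumed names, and it operates on an
// already-tokenized document rather than on raw text.

// chunkTokens splits a tokenized document into overlapping chunks. chunkSize is
// expected to fall in the 512-1024 range and overlap in the 100-200 range,
// matching the TODO above.
func chunkTokens(tokens []string, chunkSize, overlap int) [][]string {
	if chunkSize <= 0 || overlap < 0 || overlap >= chunkSize {
		return nil // invalid parameters would otherwise cause an infinite loop
	}

	step := chunkSize - overlap // advance by the non-overlapping portion each time
	var chunks [][]string
	for start := 0; start < len(tokens); start += step {
		end := start + chunkSize
		if end > len(tokens) {
			end = len(tokens)
		}
		chunks = append(chunks, tokens[start:end])
		if end == len(tokens) {
			break
		}
	}
	return chunks
}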