138 lines
3.7 KiB
Go
138 lines
3.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"ctxGPT/LLMMapper"
|
|
"ctxGPT/promptBuilder"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/sashabaranov/go-openai"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
func main() {
|
|
reader := bufio.NewReader(os.Stdin)
|
|
fmt.Println("Go CLI Chat App")
|
|
fmt.Println("---------------------")
|
|
|
|
systemPrompt := ""
|
|
|
|
messages := []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: systemPrompt,
|
|
},
|
|
}
|
|
|
|
for {
|
|
fmt.Print("You: ")
|
|
text, _ := reader.ReadString('\n')
|
|
text = strings.TrimSpace(text) // Remove leading and trailing whitespace
|
|
|
|
// Check if the user wants to exit.
|
|
if text == "exit" {
|
|
fmt.Println("Exiting chat...")
|
|
break
|
|
}
|
|
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: text,
|
|
})
|
|
|
|
resp, err := LLMMapper.SendPrompt(messages)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
|
|
if len(resp.Choices) == 0 {
|
|
fmt.Println("No choices returned")
|
|
continue
|
|
}
|
|
|
|
fmt.Println("AI: ", resp.Choices[0].Message.Content)
|
|
//fmt.Println("Finish Reason: ", resp.Choices[0].FinishReason)
|
|
messages = append(messages, resp.Choices[0].Message)
|
|
|
|
currLength := estimateTokenCount(messages)
|
|
|
|
if currLength > 3000 {
|
|
fmt.Println("Token count exceeded 3000, summarizing context")
|
|
summarized, err := summarizeChatSoFar(messages)
|
|
if err != nil {
|
|
fmt.Printf("error summarizing chat so far | %v\n", err)
|
|
continue
|
|
} else {
|
|
fmt.Printf("Summarized: %v\n", summarized)
|
|
}
|
|
// reset messages to the system prompt and the summarized prompt
|
|
messages = []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: openai.ChatMessageRoleAssistant,
|
|
Content: summarized,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func estimateTokenCount(messages []openai.ChatCompletionMessage) int {
|
|
messagesRaw, err := json.Marshal(messages)
|
|
if err != nil {
|
|
fmt.Printf("error marshalling messages for token size estimation | %v\n", err)
|
|
return 0
|
|
}
|
|
|
|
length, err := LLMMapper.GetTokenCount(string(messagesRaw))
|
|
if err != nil {
|
|
fmt.Printf("error getting token count | %v\n", err)
|
|
return 0
|
|
}
|
|
fmt.Printf("Token Count: %v\n", length)
|
|
return length
|
|
}
|
|
|
|
func summarizeChatSoFar(messages []openai.ChatCompletionMessage) (string, error) {
|
|
messagesRaw, err := json.Marshal(messages)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error marshalling messages for token size estimation | %w", err)
|
|
}
|
|
|
|
summarizeConvoPrompt, err := promptBuilder.BuildPrompt("summarize.tmpl", struct{ WordLimit int }{WordLimit: 100})
|
|
if err != nil {
|
|
return "", fmt.Errorf("error building prompt for summarization | %w", err)
|
|
}
|
|
|
|
resp, err := LLMMapper.SendPrompt([]openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: summarizeConvoPrompt,
|
|
},
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: string(messagesRaw),
|
|
},
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("error summarizing conversation | %w", err)
|
|
}
|
|
|
|
if len(resp.Choices) == 0 {
|
|
return "", fmt.Errorf("no choices returned for summarization")
|
|
}
|
|
|
|
return resp.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
// TODO: anything stored in the database should be chunked to sizes between 512 and 1024 tokens,
// with each chunk overlapping the previous one by 100-200 tokens.
// When the LLM asks for more context, the database should be used to find the most relevant chunks as follows:
// 1. Get the embeddings for each prompt and use them to search for the closest 6 chunks.
// 2. Use a separate LLM prompt to select and sort those chunks based on the user's input.
// 3. Add the best-matched chunks to the main prompt as further context for the given prompt.
|