create a tool that uses AI to categorize banking transactions
This commit is contained in:
315
main.go
Normal file
315
main.go
Normal file
@ -0,0 +1,315 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Category struct {
|
||||
Name string
|
||||
Transactions []Transaction
|
||||
Total float64
|
||||
}
|
||||
|
||||
func main() {
|
||||
// read csv line by line
|
||||
//lines, err := readLinesFromFile("./Transactions-2023-05-03.csv")
|
||||
//if err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
|
||||
//transactions, err := getJSONFromFile("./completed.json")
|
||||
//if err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
|
||||
transactions, err := getCategories()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(len(transactions))
|
||||
fmt.Println(transactions[0])
|
||||
|
||||
categories := make(map[string]Category)
|
||||
|
||||
for _, transaction := range transactions {
|
||||
category, ok := categories[transaction.Category]
|
||||
if !ok {
|
||||
category = Category{
|
||||
Name: transaction.Category,
|
||||
Transactions: []Transaction{},
|
||||
Total: 0,
|
||||
}
|
||||
}
|
||||
|
||||
category.Name = transaction.Category
|
||||
category.Transactions = append(category.Transactions, transaction)
|
||||
|
||||
// parse amount from string to float
|
||||
amount, err := strconv.ParseFloat(transaction.Amount, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
category.Total += amount
|
||||
categories[transaction.Category] = category
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, category := range categories {
|
||||
fmt.Printf("%s: Count: %v Total: %.2f\n", category.Name, len(category.Transactions), category.Total)
|
||||
count += 1
|
||||
}
|
||||
fmt.Println(count)
|
||||
fmt.Println(len(transactions))
|
||||
}
|
||||
|
||||
func getJSONFromFile(filePath string) ([]Transaction, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error opening file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var transactions []Transaction
|
||||
err = json.NewDecoder(file).Decode(&transactions)
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error decoding json: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
func getCategories() ([]Transaction, error) {
|
||||
transactions, err := readCSV("./Transactions-2023-05-03.csv")
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error reading csv: %w", err)
|
||||
}
|
||||
|
||||
var finalTransactions []Transaction
|
||||
count := 0
|
||||
for _, line := range transactions {
|
||||
//if count > 100 {
|
||||
// break
|
||||
//}
|
||||
fmt.Println(line)
|
||||
transaction, err := parseTransaction(line.TransactionName)
|
||||
if err != nil {
|
||||
// continue if parsing error
|
||||
fmt.Println(err)
|
||||
}
|
||||
//fmt.Printf("%+v\n", transaction)
|
||||
combinedTransaction := Transaction{
|
||||
Category: transaction.Category,
|
||||
SubCategory: transaction.SubCategory,
|
||||
Amount: line.Amount,
|
||||
TransactionName: line.TransactionName,
|
||||
IsoDate: line.IsoDate,
|
||||
}
|
||||
finalTransactions = append(finalTransactions, combinedTransaction)
|
||||
count++
|
||||
}
|
||||
|
||||
jsonTransactions, err := json.Marshal(finalTransactions)
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error marshalling json: %w", err)
|
||||
}
|
||||
fmt.Println(string(jsonTransactions))
|
||||
err = writeBytesToFile("./raw-categories.json", jsonTransactions)
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error writing to file: %w", err)
|
||||
}
|
||||
return finalTransactions, nil
|
||||
}
|
||||
|
||||
func writeBytesToFile(filePath string, bytes []byte) error {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = file.Write(bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseTransaction(rawTransaction string) (Transaction, error) {
|
||||
// save for later
|
||||
// category should be one of the following: "Income", "Food", "Shopping", "Transportation", "Housing", "Utilities", "Insurance", "Medical", "Saving and Investing", "Debt Payments", "Personal Spending", "Entertainment", "Other"
|
||||
categories := []string{
|
||||
"Donations",
|
||||
"Banking/Finance",
|
||||
"Mortgage",
|
||||
"Credit Cards",
|
||||
"Loans",
|
||||
"Bills/Utilities",
|
||||
"Food/Groceries",
|
||||
"Pet Care",
|
||||
"Housing",
|
||||
"Cloud Services",
|
||||
"Health and Beauty",
|
||||
"Entertainment/Shopping",
|
||||
"Transportation",
|
||||
"Uncategorized",
|
||||
}
|
||||
|
||||
client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: `
|
||||
You are a helpful bot that categorizes transactions. You are given a transaction name raw from the bank and you need to categorize it. Please only respond in json format.
|
||||
category should be one of the following:
|
||||
` + strings.Join(categories, ",\n") + `
|
||||
|
||||
subCategory should provide more detail on the category. For example, if the category is "Food/Groceries", the subCategory could be "Groceries" or "Restaurants".
|
||||
|
||||
Expected output format:
|
||||
{
|
||||
"category": "",
|
||||
"subCategory": "",
|
||||
"transactionName": ""
|
||||
}`,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: rawTransaction,
|
||||
},
|
||||
}
|
||||
|
||||
hasCategory := false
|
||||
for !hasCategory {
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: messages,
|
||||
MaxTokens: 256,
|
||||
Temperature: 0,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return Transaction{}, fmt.Errorf("ChatCompletion request error: %w", err)
|
||||
}
|
||||
fmt.Println(resp.Choices[0].Message.Content)
|
||||
transaction, err := parseAIOutput(resp.Choices[0].Message.Content)
|
||||
if err != nil {
|
||||
return Transaction{}, fmt.Errorf("ChatCompletion parsing error: %w, got response: %v", err, resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
// check if category is in the categories slice
|
||||
if contains(categories, transaction.Category) {
|
||||
fmt.Println("Found category!")
|
||||
hasCategory = true
|
||||
return transaction, nil
|
||||
} else if len(messages) < 5 {
|
||||
fmt.Println("Did not find category, trying again...")
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: resp.Choices[0].Message.Role,
|
||||
Content: resp.Choices[0].Message.Content,
|
||||
Name: "",
|
||||
})
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: "Please provide a valid category.",
|
||||
Name: "",
|
||||
})
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return Transaction{}, fmt.Errorf("could not find category for transaction: %v\nMessages: %+v", rawTransaction, messages)
|
||||
}
|
||||
|
||||
func contains(categories []string, category string) bool {
|
||||
for _, c := range categories {
|
||||
if c == category {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Transaction struct {
|
||||
Category string `json:"category"`
|
||||
SubCategory string `json:"subCategory"`
|
||||
Amount string `json:"amount"`
|
||||
TransactionName string `json:"transactionName"`
|
||||
IsoDate string `json:"isodate"`
|
||||
}
|
||||
|
||||
func parseAIOutput(aiOutput string) (Transaction, error) {
|
||||
var transaction Transaction
|
||||
|
||||
aiOutputParts := strings.Split(aiOutput, "{")
|
||||
aiJson := "{" + aiOutputParts[len(aiOutputParts)-1]
|
||||
|
||||
err := json.Unmarshal([]byte(aiJson), &transaction)
|
||||
if err != nil {
|
||||
return Transaction{}, err
|
||||
}
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func readLinesFromFile(filename string) ([]string, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var lines []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
func readCSV(filePath string) ([]Transaction, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error opening file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := csv.NewReader(file)
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return []Transaction{}, fmt.Errorf("error reading csv: %w", err)
|
||||
}
|
||||
|
||||
var transactions []Transaction
|
||||
// Do something with the records
|
||||
for _, record := range records {
|
||||
//fmt.Println(record)
|
||||
isoDate, err := time.Parse("01/02/2006", record[0])
|
||||
if err != nil {
|
||||
fmt.Println(fmt.Errorf("error parsing date: %w", err))
|
||||
continue
|
||||
}
|
||||
t := Transaction{
|
||||
IsoDate: isoDate.Format("2006-01-02"),
|
||||
TransactionName: record[2],
|
||||
Amount: record[7],
|
||||
}
|
||||
transactions = append(transactions, t)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
Reference in New Issue
Block a user