create a tool that uses AI to categorize banking transactions

This commit is contained in:
2023-05-30 23:12:12 -06:00
commit 2004ad0f44
10 changed files with 11094 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

9
.idea/budgetTool.iml generated Normal file
View File

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/budgetTool.iml" filepath="$PROJECT_DIR$/.idea/budgetTool.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

1534
Transactions-2023-05-03.csv Normal file

File diff suppressed because it is too large Load Diff

9206
completed.json Normal file

File diff suppressed because it is too large Load Diff

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module budgetTool
go 1.19
require github.com/sashabaranov/go-openai v1.9.2

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/sashabaranov/go-openai v1.9.2 h1:7//Glm9EiMBjelgmBb00yYzKYqm1jckHWWTDLahfeuQ=
github.com/sashabaranov/go-openai v1.9.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=

315
main.go Normal file
View File

@ -0,0 +1,315 @@
package main
import (
"bufio"
"context"
"encoding/csv"
"encoding/json"
"fmt"
"github.com/sashabaranov/go-openai"
"os"
"strconv"
"strings"
"time"
)
// Category groups the transactions that share one category name and keeps a
// running sum of their amounts.
type Category struct {
// Name is the category label, copied from Transaction.Category.
Name string
// Transactions lists the transactions collected under this category.
Transactions []Transaction
// Total is the sum of the transaction amounts that parsed as floats.
Total float64
}
// main categorizes the bank export via getCategories, groups the returned
// transactions by category, and prints one summary line per category
// (name, transaction count, total) followed by the number of categories and
// the number of transactions. It panics if getCategories fails.
func main() {
	// Alternative inputs, kept for reference:
	// read csv line by line
	//lines, err := readLinesFromFile("./Transactions-2023-05-03.csv")
	//if err != nil {
	//	panic(err)
	//}
	//transactions, err := getJSONFromFile("./completed.json")
	//if err != nil {
	//	panic(err)
	//}
	transactions, err := getCategories()
	if err != nil {
		panic(err)
	}
	fmt.Println(len(transactions))
	// Only print a sample when there is one; indexing an empty slice panics.
	if len(transactions) > 0 {
		fmt.Println(transactions[0])
	}

	categories := make(map[string]Category)
	for _, transaction := range transactions {
		// A missing key yields the zero Category, which is a valid empty
		// starting point (nil slice, zero total).
		category := categories[transaction.Category]
		category.Name = transaction.Category
		category.Transactions = append(category.Transactions, transaction)

		// parse amount from string to float
		amount, err := strconv.ParseFloat(transaction.Amount, 64)
		if err != nil {
			// Keep the transaction in its category; only its amount is
			// left out of the total.
			fmt.Fprintf(os.Stderr, "could not parse amount %q for transaction %q: %v\n",
				transaction.Amount, transaction.TransactionName, err)
		} else {
			category.Total += amount
		}

		// Category is stored by value, so it must be written back on every
		// path or the append above is lost.
		categories[transaction.Category] = category
	}

	for _, category := range categories {
		fmt.Printf("%s: Count: %v Total: %.2f\n", category.Name, len(category.Transactions), category.Total)
	}
	fmt.Println(len(categories))
	fmt.Println(len(transactions))
}
// getJSONFromFile opens filePath and decodes the first JSON value in it into
// a []Transaction. On failure it returns an empty slice together with an
// error that wraps the open or decode failure.
func getJSONFromFile(filePath string) ([]Transaction, error) {
	src, openErr := os.Open(filePath)
	if openErr != nil {
		return []Transaction{}, fmt.Errorf("error opening file: %w", openErr)
	}
	defer src.Close()

	var decoded []Transaction
	if decodeErr := json.NewDecoder(src).Decode(&decoded); decodeErr != nil {
		return []Transaction{}, fmt.Errorf("error decoding json: %w", decodeErr)
	}
	return decoded, nil
}
// getCategories reads the bank export at ./Transactions-2023-05-03.csv, asks
// parseTransaction to categorize each row by its transaction name, and merges
// the returned category and sub-category with the row's amount, name and
// date. The merged list is printed as JSON, written to ./raw-categories.json,
// and returned.
//
// A row whose categorization fails is reported on stderr and kept with an
// empty category and sub-category, so it still appears in the output.
// Reading the CSV, marshalling, or writing the output file failing aborts the
// whole run with a wrapped error and an empty slice.
func getCategories() ([]Transaction, error) {
	transactions, err := readCSV("./Transactions-2023-05-03.csv")
	if err != nil {
		return []Transaction{}, fmt.Errorf("error reading csv: %w", err)
	}

	// Allocate up front: this avoids regrowth and, unlike a nil slice, makes
	// an empty result marshal as "[]" rather than "null".
	finalTransactions := make([]Transaction, 0, len(transactions))
	for _, line := range transactions {
		fmt.Println(line)
		transaction, err := parseTransaction(line.TransactionName)
		if err != nil {
			// Deliberately not skipping the row: the zero Transaction
			// returned on error leaves the category fields empty below.
			fmt.Fprintln(os.Stderr, err)
		}
		finalTransactions = append(finalTransactions, Transaction{
			Category:        transaction.Category,
			SubCategory:     transaction.SubCategory,
			Amount:          line.Amount,
			TransactionName: line.TransactionName,
			IsoDate:         line.IsoDate,
		})
	}

	jsonTransactions, err := json.Marshal(finalTransactions)
	if err != nil {
		return []Transaction{}, fmt.Errorf("error marshalling json: %w", err)
	}
	fmt.Println(string(jsonTransactions))

	if err := writeBytesToFile("./raw-categories.json", jsonTransactions); err != nil {
		return []Transaction{}, fmt.Errorf("error writing to file: %w", err)
	}
	return finalTransactions, nil
}
// writeBytesToFile creates filePath (truncating it if it already exists) and
// writes bytes to it. It returns a wrapped error if the file cannot be
// created, written, or closed.
func writeBytesToFile(filePath string, bytes []byte) error {
	file, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("error creating file: %w", err)
	}

	if _, err := file.Write(bytes); err != nil {
		// The write failure is the error worth reporting; the close result
		// would add nothing useful here.
		_ = file.Close()
		return fmt.Errorf("error writing to file: %w", err)
	}

	// Close can report a write that failed late, so its error must not be
	// dropped for a file we just wrote.
	if err := file.Close(); err != nil {
		return fmt.Errorf("error closing file: %w", err)
	}
	return nil
}
// parseTransaction asks the OpenAI chat completion API (gpt-3.5-turbo,
// temperature 0) to categorize rawTransaction, the transaction name as it
// appears in the bank export. The API key is read from the OPENAI_API_KEY
// environment variable.
//
// The reply is parsed with parseAIOutput and accepted only if its category is
// one of the names in the allowed list below. If it is not, the reply and a
// "Please provide a valid category." instruction are appended to the
// conversation and the request is repeated; at most three requests are made.
//
// It returns a zero Transaction and an error if the request fails, the
// response has no choices, the reply cannot be parsed, or no valid category
// was produced within the retry budget.
func parseTransaction(rawTransaction string) (Transaction, error) {
	// save for later
	// category should be one of the following: "Income", "Food", "Shopping", "Transportation", "Housing", "Utilities", "Insurance", "Medical", "Saving and Investing", "Debt Payments", "Personal Spending", "Entertainment", "Other"
	categories := []string{
		"Donations",
		"Banking/Finance",
		"Mortgage",
		"Credit Cards",
		"Loans",
		"Bills/Utilities",
		"Food/Groceries",
		"Pet Care",
		"Housing",
		"Cloud Services",
		"Health and Beauty",
		"Entertainment/Shopping",
		"Transportation",
		"Uncategorized",
	}
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	messages := []openai.ChatCompletionMessage{
		{
			Role: openai.ChatMessageRoleSystem,
			Content: `
You are a helpful bot that categorizes transactions. You are given a transaction name raw from the bank and you need to categorize it. Please only respond in json format.
category should be one of the following:
` + strings.Join(categories, ",\n") + `
subCategory should provide more detail on the category. For example, if the category is "Food/Groceries", the subCategory could be "Groceries" or "Restaurants".
Expected output format:
{
"category": "",
"subCategory": "",
"transactionName": ""
}`,
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: rawTransaction,
		},
	}

	for {
		resp, err := client.CreateChatCompletion(
			context.Background(),
			openai.ChatCompletionRequest{
				Model:       openai.GPT3Dot5Turbo,
				Messages:    messages,
				MaxTokens:   256,
				Temperature: 0,
			},
		)
		if err != nil {
			return Transaction{}, fmt.Errorf("ChatCompletion request error: %w", err)
		}
		// Indexing Choices[0] on an empty slice would panic.
		if len(resp.Choices) == 0 {
			return Transaction{}, fmt.Errorf("ChatCompletion returned no choices for transaction: %v", rawTransaction)
		}
		reply := resp.Choices[0].Message
		fmt.Println(reply.Content)

		transaction, err := parseAIOutput(reply.Content)
		if err != nil {
			return Transaction{}, fmt.Errorf("ChatCompletion parsing error: %w, got response: %v", err, reply.Content)
		}

		// check if category is in the categories slice
		if contains(categories, transaction.Category) {
			fmt.Println("Found category!")
			return transaction, nil
		}

		// The conversation starts with 2 messages and grows by 2 per retry,
		// so this bound allows two retries after the first attempt.
		if len(messages) >= 5 {
			break
		}
		fmt.Println("Did not find category, trying again...")
		messages = append(messages,
			openai.ChatCompletionMessage{
				Role:    reply.Role,
				Content: reply.Content,
				Name:    "",
			},
			openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleSystem,
				Content: "Please provide a valid category.",
				Name:    "",
			},
		)
	}
	return Transaction{}, fmt.Errorf("could not find category for transaction: %v\nMessages: %+v", rawTransaction, messages)
}
// contains reports whether category is an exact, case-sensitive match for at
// least one element of categories.
func contains(categories []string, category string) bool {
	for idx := range categories {
		if categories[idx] != category {
			continue
		}
		return true
	}
	return false
}
// Transaction is one bank transaction together with the category information
// assigned to it. Amount and IsoDate are kept as strings; readCSV formats
// IsoDate as "2006-01-02".
type Transaction struct {
	Category        string `json:"category"`
	SubCategory     string `json:"subCategory"`
	Amount          string `json:"amount"`
	TransactionName string `json:"transactionName"`
	IsoDate         string `json:"isodate"`
}

// parseAIOutput extracts a Transaction from a chat model reply.
//
// Parsing starts at the last "{" in aiOutput, so any text before the JSON
// object is ignored. If the reply contains no "{" at all, one is prepended.
// Only the first JSON value from that point is decoded, so text after the
// closing brace (a closing code fence, a trailing remark) does not cause a
// failure. It returns a zero Transaction and the decode error if no valid
// JSON object can be read.
func parseAIOutput(aiOutput string) (Transaction, error) {
	// LastIndex returns -1 when there is no "{", in which case the slice
	// below is the whole reply and a brace is simply prepended.
	start := strings.LastIndex(aiOutput, "{")
	aiJson := "{" + aiOutput[start+1:]

	var transaction Transaction
	// A Decoder stops after the first complete value; json.Unmarshal would
	// reject the input because of anything following the object.
	if err := json.NewDecoder(strings.NewReader(aiJson)).Decode(&transaction); err != nil {
		return Transaction{}, err
	}
	return transaction, nil
}
// readLinesFromFile returns the contents of filename split into lines, as
// produced by bufio.Scanner with its default line splitting. It returns a nil
// slice and the underlying error if the file cannot be opened or scanning
// fails.
func readLinesFromFile(filename string) ([]string, error) {
	in, openErr := os.Open(filename)
	if openErr != nil {
		return nil, openErr
	}
	defer in.Close()

	var collected []string
	lineScanner := bufio.NewScanner(in)
	for {
		if !lineScanner.Scan() {
			break
		}
		collected = append(collected, lineScanner.Text())
	}

	// Scan returning false covers both EOF and failure; Err tells them apart.
	if scanErr := lineScanner.Err(); scanErr != nil {
		return nil, scanErr
	}
	return collected, nil
}
// readCSV loads the CSV file at filePath and converts each usable row into a
// Transaction with IsoDate, TransactionName and Amount set (the category
// fields are left empty).
//
// The date is read from column 0 in MM/DD/YYYY form and re-formatted as
// YYYY-MM-DD; the name comes from column 2 and the amount from column 7.
// Rows with too few columns or an unparseable date (for example a header
// row) are reported on stdout and skipped. It returns an empty slice and a
// wrapped error if the file cannot be opened or is not valid CSV.
func readCSV(filePath string) ([]Transaction, error) {
	const (
		dateColumn   = 0
		nameColumn   = 2
		amountColumn = 7
	)

	file, err := os.Open(filePath)
	if err != nil {
		return []Transaction{}, fmt.Errorf("error opening file: %w", err)
	}
	defer file.Close()

	records, err := csv.NewReader(file).ReadAll()
	if err != nil {
		return []Transaction{}, fmt.Errorf("error reading csv: %w", err)
	}

	transactions := make([]Transaction, 0, len(records))
	for _, record := range records {
		// Indexing a short row would panic, so skip it like any other
		// unusable row.
		if len(record) <= amountColumn {
			fmt.Println(fmt.Errorf("error parsing record: expected at least %d fields, got %d", amountColumn+1, len(record)))
			continue
		}
		isoDate, err := time.Parse("01/02/2006", record[dateColumn])
		if err != nil {
			fmt.Println(fmt.Errorf("error parsing date: %w", err))
			continue
		}
		transactions = append(transactions, Transaction{
			IsoDate:         isoDate.Format("2006-01-02"),
			TransactionName: record[nameColumn],
			Amount:          record[amountColumn],
		})
	}
	return transactions, nil
}

1
raw-categories.json Normal file

File diff suppressed because one or more lines are too long