build backend to collect and search using embeddings

This commit is contained in:
2025-01-04 00:15:29 -07:00
parent c4521af5c2
commit 3be1648ee4
16 changed files with 679 additions and 489 deletions

99
backend/AI/ai.go Normal file
View File

@ -0,0 +1,99 @@
package AI
import (
"context"
"fmt"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
"os"
)
// This package should use the OpenAI API to provide AI services.
type AI interface {
// Get Embedding
GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error)
GetTokenCount(input string) (int, error)
}
type ai struct {
apiKey string
baseURL string
encodingName string
model string
client *openai.Client
}
type AIOption func(*ai)
func NewAI(otps ...AIOption) (AI, error) {
a := ai{
//baseURL: "https://api.openai.com",
encodingName: "gpt-4o",
model: openai.GPT4oMini,
}
for _, opt := range otps {
opt(&a)
}
if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" {
a.apiKey = os.Getenv("OPENAI_API_KEY")
}
if a.apiKey == "" {
return nil, fmt.Errorf("api key is required")
}
config := openai.DefaultConfig(a.apiKey)
if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" {
a.baseURL = os.Getenv("OPENAI_BASE_URL")
}
if a.baseURL != "" {
config.BaseURL = a.baseURL
}
a.client = openai.NewClientWithConfig(config)
return a, nil
}
func (a ai) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) {
embeddingRequest := openai.EmbeddingRequest{
Input: text,
Model: "text-embedding-3-small",
}
embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest)
if err != nil {
return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err)
}
return embeddings, nil
}
func WithAPIKey(apiKey string) AIOption {
return func(a *ai) {
a.apiKey = apiKey
}
}
func WithBaseURL(baseURL string) AIOption {
return func(a *ai) {
a.baseURL = baseURL
}
}
func WithEncodingName(encodingName string) AIOption {
return func(a *ai) {
a.encodingName = encodingName
}
}
func (a ai) GetTokenCount(input string) (int, error) {
tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR"
if err != nil {
return 0, fmt.Errorf("error getting encoding: %w", err)
}
token := tke.Encode(input, nil, nil)
return len(token), nil
}

73
backend/Leg/utah.go Normal file
View File

@ -0,0 +1,73 @@
package Leg
import (
"encoding/json"
"fmt"
"os"
"time"
"git.sa.vin/legislature-tracker/backend/cachedAPI"
"git.sa.vin/legislature-tracker/backend/types"
)
type UtahLeg interface {
GetBillList(year, session string) (types.UtahBillList, error)
GetBillDetails(year, session, billID string) (types.UtahBill, error)
}
type utahLeg struct {
cache cachedAPI.CachedAPI
}
var developerToken string
func NewUtahLeg(cache cachedAPI.CachedAPI) UtahLeg {
developerToken = os.Getenv("UTAH_DEV_TOKEN")
return &utahLeg{
cache: cache,
}
}
// GetBillList gets the list of bills for a given year and session,
// session should be one of "GS", "S#" where # is the session number
func (u utahLeg) GetBillList(year, session string) (types.UtahBillList, error) {
// if session is not GS it must start with S and end with a number
if session != "GS" && (session[0] != 'S' || session[1] < '0' || session[1] > '9') {
return types.UtahBillList{}, fmt.Errorf("session must be one of GS or S with some number")
}
respString, err := u.cache.Get(fmt.Sprintf("https://glen.le.utah.gov/bills/%v%v/billlist/%v", year, session, developerToken), time.Hour)
if err != nil {
return types.UtahBillList{}, fmt.Errorf("error getting bill list: %w", err)
}
if respString == "Invalid request" {
return types.UtahBillList{}, fmt.Errorf("invalid request")
}
var billList types.UtahBillList
err = json.Unmarshal([]byte(respString), &billList)
if err != nil {
return types.UtahBillList{}, fmt.Errorf("error unmarshalling bill list: %w", err)
}
return billList, nil
}
// GetBillDetails gets the details of a bill for a given year, session, and billID
// session should be one of "GS", "S2"
func (u utahLeg) GetBillDetails(year, session, billID string) (types.UtahBill, error) {
// if session is not GS it must start with S and end with a number
if session != "GS" && (session[0] != 'S' || session[1] < '0' || session[1] > '9') {
return types.UtahBill{}, fmt.Errorf("session must be one of GS or S with some number")
}
respString, err := u.cache.Get(fmt.Sprintf("https://glen.le.utah.gov/bills/%v%v/%v/%v", year, session, billID, developerToken), time.Hour)
if err != nil {
return types.UtahBill{}, fmt.Errorf("error getting bill details: %w", err)
}
if respString == "Invalid request" {
return types.UtahBill{}, fmt.Errorf("invalid request")
}
var bill types.UtahBill
err = json.Unmarshal([]byte(respString), &bill)
if err != nil {
return types.UtahBill{}, fmt.Errorf("error unmarshalling bill details: %w", err)
}
return bill, nil
}

View File

@ -0,0 +1,52 @@
package cachedAPI
import (
"fmt"
"git.sa.vin/legislature-tracker/backend/datastore"
"io"
"net/http"
"time"
)
// This package behaves like an API but uses libSQL as a cache that gets checked before the actual API is called.
type CachedAPI interface {
Get(url string, cacheTTL time.Duration) (string, error)
}
type cachedAPI struct {
mapper datastore.CacheStore
}
func NewCachedAPI(mapper datastore.CacheStore) CachedAPI {
return &cachedAPI{
mapper: mapper,
}
}
func (c cachedAPI) Get(url string, cacheTTL time.Duration) (string, error) {
response, found, err := c.mapper.CachedAPI(url)
if err != nil {
return "", fmt.Errorf("error getting cached API response: %w", err)
}
if found {
return response, nil
}
// Call the actual API
resp, err := http.Get(url)
if err != nil {
return "", fmt.Errorf("error calling API: %w", err)
}
defer resp.Body.Close()
// Read the response
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("error reading API response: %w", err)
}
// Save the response to the cache
err = c.mapper.SaveAPIResponse(url, string(bodyBytes), cacheTTL)
if err != nil {
return "", fmt.Errorf("error saving API response: %w", err)
}
return string(bodyBytes), nil
}

125
backend/datastore/mapper.go Normal file
View File

@ -0,0 +1,125 @@
package datastore
import (
"database/sql"
"fmt"
"git.sa.vin/legislature-tracker/backend/types"
"strings"
"time"
)
type CacheStore interface {
CachedAPI(url string) (string, bool, error)
SaveAPIResponse(url, response string, cacheTTL time.Duration) error
}
type SearchStore interface {
SaveEmbeddings(id, content string, embeddings []float32) error
FindRelevantContent(queryEmbeddings []float32) ([]types.SearchResponse, error)
}
type Mapper struct {
db *sql.DB
}
func NewMapper(db *sql.DB) *Mapper {
return &Mapper{
db: db,
}
}
// CachedAPI returns the cached API response for the given URL
// If the URL is not in the cache it returns an empty string and false
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
// Check the cache for the URL
// If the URL is in the cache, return the cached response
// Otherwise, call the API and cache the response
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
rows, err := m.db.Query(query, url)
if err != nil {
// norows error is not an error
if err == sql.ErrNoRows {
return "", false, nil
}
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
}
defer rows.Close()
var response struct {
Response string
CreatedAt time.Time
TTL time.Duration
}
for rows.Next() {
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
if err != nil {
return "", false, fmt.Errorf("error scanning cache response: %w", err)
}
// Check if the cache is expired
if time.Since(response.CreatedAt) > response.TTL {
return "", false, nil
}
return response.Response, true, nil
}
return "", false, nil
}
// SaveAPIResponse saves the API response to the cache
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
// Insert the response into the cache
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
_, err := m.db.Exec(query, url, response, cacheTTL)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
// Update the existing row if there is a UNIQUE constraint error
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
if updateErr != nil {
return fmt.Errorf("error updating cache response: %w", updateErr)
}
return nil
}
return fmt.Errorf("error inserting cache response: %w", err)
}
return nil
}
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
// Insert the embeddings into the database
query := `INSERT INTO searchable_content (trackingid, content, full_emb) VALUES (?, ?, vector32(?))`
_, err := m.db.Exec(query, id, content, serializeEmbeddings(embeddings))
if err != nil {
return fmt.Errorf("error inserting embeddings: %w", err)
}
return nil
}
func serializeEmbeddings(embeddings []float32) string {
return strings.Join(strings.Split(fmt.Sprintf("%v", embeddings), " "), ", ")
}
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32) ([]types.SearchResponse, error) {
// Find the relevant content in the database
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), 10) JOIN searchable_content ON id = searchable_content.rowid`
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings))
if err != nil {
// norows error is not an error
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("error querying embeddings: %w", err)
}
defer rows.Close()
var results []types.SearchResponse
for rows.Next() {
var result types.SearchResponse
err = rows.Scan(&result.TrackingID, &result.Content)
if err != nil {
return nil, fmt.Errorf("error scanning embeddings: %w", err)
}
results = append(results, result)
}
return results, nil
}

View File

@ -0,0 +1,36 @@
package datastore
import "testing"
func Benchmark_mySerializedEmbeddings(b *testing.B) {
type args struct {
embeddings []float32
}
tests := []struct {
name string
args args
want string
}{
{
name: "Test 1",
args: args{
embeddings: []float32{0.1, 0.2, 0.3},
},
want: "[0.1, 0.2, 0.3]",
},
{
name: "Crazy long test",
args: args{
embeddings: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
},
want: "[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]",
},
}
for _, tt := range tests {
b.Run(tt.name, func(t *testing.B) {
if got := serializeEmbeddings(tt.args.embeddings); got != tt.want {
t.Errorf("mySerializedEmbeddings() = %v, want %v", got, tt.want)
}
})
}
}

79
backend/main.go Normal file
View File

@ -0,0 +1,79 @@
package main
import (
"embed"
"git.sa.vin/legislature-tracker/backend/AI"
"git.sa.vin/legislature-tracker/backend/search"
"log"
"os"
"git.sa.vin/legislature-tracker/backend/Leg"
"git.sa.vin/legislature-tracker/backend/cachedAPI"
"git.sa.vin/legislature-tracker/backend/datastore"
"github.com/payne8/go-libsql-dual-driver"
)
//go:embed migrations/*.sql
var migrationFiles embed.FS
func main() {
logger := log.New(os.Stdout, "any-remark", log.LstdFlags)
primaryUrl := os.Getenv("LIBSQL_DATABASE_URL")
authToken := os.Getenv("LIBSQL_AUTH_TOKEN")
tdb, err := libsqldb.NewLibSqlDB(
primaryUrl,
libsqldb.WithMigrationFiles(migrationFiles),
libsqldb.WithAuthToken(authToken),
libsqldb.WithLocalDBName("local.db"), // will not be used for remote-only
)
if err != nil {
logger.Printf("failed to open db %s: %s", primaryUrl, err)
log.Fatalln(err)
return
}
err = tdb.Migrate()
if err != nil {
logger.Printf("failed to migrate db %s: %s", primaryUrl, err)
log.Fatalln(err)
return
}
mapper := datastore.NewMapper(tdb.DB)
api := cachedAPI.NewCachedAPI(mapper)
utah := Leg.NewUtahLeg(api)
ai, err := AI.NewAI()
if err != nil {
log.Fatalf("error creating AI: %v", err)
}
searchService, err := search.NewSearch(search.WithAI(ai), search.WithMapper(mapper))
if err != nil {
log.Fatalf("error creating search: %v", err)
}
test, err := utah.GetBillList("2024", "GS")
if err != nil {
log.Fatalf("error getting bill list: %v", err)
}
log.Printf("bill list: %+v", test)
test2, err := utah.GetBillDetails("2024", "GS", "HB0001")
if err != nil {
log.Fatalf("error getting bill details: %v", err)
}
log.Printf("bill details: %+v", test2)
//err = searchService.InsertContent(context.Background(), test2.TrackingID, test2.GeneralProvisions+" "+test2.HilightedProvisions)
//if err != nil {
// log.Fatalf("error inserting content: %v", err)
//}
results, err := searchService.Search("I'm looking for a bill that affects public education")
if err != nil {
log.Fatalf("error searching: %v", err)
}
log.Printf("search results: %+v", results)
}

View File

@ -0,0 +1,8 @@
CREATE TABLE searchable_content (
trackingid TEXT NOT NULL,
content TEXT NOT NULL,
full_emb F32_BLOB(1536) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX emb_idx ON searchable_content (libsql_vector_idx(full_emb));

View File

@ -0,0 +1,11 @@
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY,
url TEXT NOT NULL UNIQUE,
response TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ttl INTEGER DEFAULT 0
);
CREATE INDEX idx_url ON cache (url);
CREATE INDEX idx_created_at ON cache (created_at);

77
backend/search/search.go Normal file
View File

@ -0,0 +1,77 @@
package search
import (
"context"
"fmt"
"git.sa.vin/legislature-tracker/backend/AI"
"git.sa.vin/legislature-tracker/backend/datastore"
"git.sa.vin/legislature-tracker/backend/types"
)
type Search interface {
Search(query string) ([]types.SearchResponse, error)
InsertContent(ctx context.Context, id string, content string) error
}
type SearchOption func(s *search)
func NewSearch(opts ...SearchOption) (Search, error) {
s := &search{}
for _, opt := range opts {
opt(s)
}
if s.ai == nil {
return nil, fmt.Errorf("AI is required")
}
if s.mapper == nil {
return nil, fmt.Errorf("mapper is required")
}
return s, nil
}
func WithMapper(mapper datastore.SearchStore) func(s *search) {
return func(s *search) {
s.mapper = mapper
}
}
func WithAI(ai AI.AI) func(s *search) {
return func(s *search) {
s.ai = ai
}
}
type search struct {
ai AI.AI
mapper datastore.SearchStore
}
func (s search) Search(query string) ([]types.SearchResponse, error) {
// get embeddings for the query
embeddings, err := s.ai.GetEmbeddings(context.Background(), query)
if err != nil {
return nil, fmt.Errorf("error getting embeddings: %w", err)
}
if len(embeddings.Data) == 0 {
return nil, fmt.Errorf("no embeddings returned")
}
// find relevant content in the database
return s.mapper.FindRelevantContent(embeddings.Data[0].Embedding)
}
func (s search) InsertContent(ctx context.Context, id string, content string) error {
// get embeddings for the content
embeddings, err := s.ai.GetEmbeddings(ctx, content)
if err != nil {
return fmt.Errorf("error getting embeddings: %w", err)
}
if len(embeddings.Data) == 0 {
return fmt.Errorf("no embeddings returned")
}
// save the embeddings to the database
err = s.mapper.SaveEmbeddings(id, content, embeddings.Data[0].Embedding)
if err != nil {
return fmt.Errorf("error saving embeddings: %w", err)
}
return nil
}

6
backend/types/search.go Normal file
View File

@ -0,0 +1,6 @@
package types
type SearchResponse struct {
TrackingID string
Content string
}

33
backend/types/utah.go Normal file
View File

@ -0,0 +1,33 @@
package types
// UtahBill is a struct that represents a bill in the Utah legislature
type UtahBill struct {
Bill string `json:"bill"`
Version string `json:"version"`
ShortTitle string `json:"shorttitle"`
Sponsor string `json:"sponsor"`
FloorSponsor string `json:"floorsponsor"`
GeneralProvisions string `json:"generalprovisions"`
HilightedProvisions string `json:"hilightedprovisions"`
Monies string `json:"monies"`
Attorney string `json:"attorney"`
FiscalAnalyst string `json:"fiscalanalyst"`
LastAction string `json:"lastaction"`
LastActionOwner string `json:"lastactionowner"`
LastActionTime string `json:"lastactiontime"`
TrackingID string `json:"trackingid"`
Subjects []string `json:"subjects"`
CodeSections []string `json:"codesections"`
Agendas []string `json:"agendas"`
}
// UtahBillListItem is a struct that represents a bill in a list of bills
type UtahBillListItem struct {
Number string `json:"number"`
UpdateTime string `json:"updatetime"`
}
// UtahBillList is a struct that represents a list of bills in the Utah legislature
type UtahBillList struct {
Bills []UtahBillListItem `json:"bills"`
}