126 lines
3.8 KiB
Go
126 lines
3.8 KiB
Go
package datastore
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"git.sa.vin/legislature-tracker/backend/types"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// CacheStore is a URL-keyed cache for raw API responses.
type CacheStore interface {
	// CachedAPI looks up the stored response for url. The boolean reports
	// whether a usable entry was found.
	CachedAPI(url string) (string, bool, error)

	// SaveAPIResponse stores response for url, valid for cacheTTL.
	SaveAPIResponse(url, response string, cacheTTL time.Duration) error
}
|
|
|
|
type SearchStore interface {
|
|
SaveEmbeddings(id, content string, embeddings []float32) error
|
|
FindRelevantContent(queryEmbeddings []float32) ([]types.SearchResponse, error)
|
|
}
|
|
|
|
// Mapper implements CacheStore and SearchStore on top of a *sql.DB.
// Construct it with NewMapper; the zero value has a nil db and cannot run
// queries.
type Mapper struct {
	db *sql.DB // handle used for every query this Mapper issues
}
|
|
|
|
func NewMapper(db *sql.DB) *Mapper {
|
|
return &Mapper{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// CachedAPI returns the cached API response for the given URL
|
|
// If the URL is not in the cache it returns an empty string and false
|
|
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
|
|
// Check the cache for the URL
|
|
// If the URL is in the cache, return the cached response
|
|
// Otherwise, call the API and cache the response
|
|
|
|
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
|
|
rows, err := m.db.Query(query, url)
|
|
if err != nil {
|
|
// norows error is not an error
|
|
if err == sql.ErrNoRows {
|
|
return "", false, nil
|
|
}
|
|
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var response struct {
|
|
Response string
|
|
CreatedAt time.Time
|
|
TTL time.Duration
|
|
}
|
|
for rows.Next() {
|
|
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("error scanning cache response: %w", err)
|
|
}
|
|
// Check if the cache is expired
|
|
if time.Since(response.CreatedAt) > response.TTL {
|
|
return "", false, nil
|
|
}
|
|
return response.Response, true, nil
|
|
}
|
|
return "", false, nil
|
|
}
|
|
|
|
// SaveAPIResponse saves the API response to the cache
|
|
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
|
|
// Insert the response into the cache
|
|
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
|
|
_, err := m.db.Exec(query, url, response, cacheTTL)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
|
|
// Update the existing row if there is a UNIQUE constraint error
|
|
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
|
|
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
|
|
if updateErr != nil {
|
|
return fmt.Errorf("error updating cache response: %w", updateErr)
|
|
}
|
|
return nil
|
|
}
|
|
return fmt.Errorf("error inserting cache response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
|
|
// Insert the embeddings into the database
|
|
query := `INSERT INTO searchable_content (trackingid, content, full_emb) VALUES (?, ?, vector32(?))`
|
|
_, err := m.db.Exec(query, id, content, serializeEmbeddings(embeddings))
|
|
if err != nil {
|
|
return fmt.Errorf("error inserting embeddings: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// serializeEmbeddings renders embeddings as a bracketed, comma-separated
// list, e.g. []float32{0.5, -1} -> "[0.5, -1]". Each element uses fmt's
// default %v formatting; an empty or nil slice yields "[]".
func serializeEmbeddings(embeddings []float32) string {
	parts := make([]string, len(embeddings))
	for i, v := range embeddings {
		parts[i] = fmt.Sprint(v)
	}
	return "[" + strings.Join(parts, ", ") + "]"
}
|
|
|
|
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32) ([]types.SearchResponse, error) {
|
|
// Find the relevant content in the database
|
|
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), 10) JOIN searchable_content ON id = searchable_content.rowid`
|
|
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings))
|
|
if err != nil {
|
|
// norows error is not an error
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("error querying embeddings: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []types.SearchResponse
|
|
for rows.Next() {
|
|
var result types.SearchResponse
|
|
err = rows.Scan(&result.TrackingID, &result.Content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error scanning embeddings: %w", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
return results, nil
|
|
}
|