168 lines
5.1 KiB
Go
168 lines
5.1 KiB
Go
package datastore
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"generic-rag/backend/types"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// CacheStore is a URL-keyed store for API responses, where each saved entry
// carries its own time-to-live.
type CacheStore interface {
	// SaveAPIResponse records response as the cached result for url, to be
	// kept for cacheTTL.
	SaveAPIResponse(url, response string, cacheTTL time.Duration) error

	// CachedAPI looks up the cached response for url. The boolean reports
	// whether a cached response was returned; the error reports a failed lookup.
	CachedAPI(url string) (string, bool, error)
}
|
|
|
|
type SearchStore interface {
|
|
SaveEmbeddings(id, content string, embeddings []float32) error
|
|
FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error)
|
|
GetContentByID(id string) ([]types.SearchResponse, error)
|
|
}
|
|
|
|
// Mapper implements the cache and search operations in this file on top of
// a database/sql handle. Mapper never closes the handle; that is left to
// whoever opened it.
type Mapper struct {
	db *sql.DB
}

// NewMapper returns a Mapper that runs all of its queries against db.
func NewMapper(db *sql.DB) *Mapper {
	m := new(Mapper)
	m.db = db
	return m
}
|
|
|
|
// CachedAPI returns the cached API response for the given URL
|
|
// If the URL is not in the cache it returns an empty string and false
|
|
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
|
|
// Check the cache for the URL
|
|
// If the URL is in the cache, return the cached response
|
|
// Otherwise, call the API and cache the response
|
|
|
|
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
|
|
rows, err := m.db.Query(query, url)
|
|
if err != nil {
|
|
// norows error is not an error
|
|
if err == sql.ErrNoRows {
|
|
return "", false, nil
|
|
}
|
|
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var response struct {
|
|
Response string
|
|
CreatedAt time.Time
|
|
TTL time.Duration
|
|
}
|
|
for rows.Next() {
|
|
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("error scanning cache response: %w", err)
|
|
}
|
|
// Check if the cache is expired
|
|
if time.Since(response.CreatedAt) > response.TTL {
|
|
return "", false, nil
|
|
}
|
|
return response.Response, true, nil
|
|
}
|
|
return "", false, nil
|
|
}
|
|
|
|
// SaveAPIResponse saves the API response to the cache
|
|
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
|
|
// Insert the response into the cache
|
|
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
|
|
_, err := m.db.Exec(query, url, response, cacheTTL)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
|
|
// Update the existing row if there is a UNIQUE constraint error
|
|
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
|
|
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
|
|
if updateErr != nil {
|
|
return fmt.Errorf("error updating cache response: %w", updateErr)
|
|
}
|
|
return nil
|
|
}
|
|
return fmt.Errorf("error inserting cache response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
|
|
// Try to update the embeddings in the database
|
|
updateQuery := `UPDATE searchable_content SET content = ?, full_emb = vector32(?), modified_at = ? WHERE trackingid = ?`
|
|
result, err := m.db.Exec(updateQuery, content, serializeEmbeddings(embeddings), time.Now(), id)
|
|
if err != nil {
|
|
return fmt.Errorf("error updating embeddings: %w", err)
|
|
}
|
|
|
|
// Check if any rows were updated
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("error checking rows affected: %w", err)
|
|
}
|
|
|
|
// If no rows were updated, insert the embeddings
|
|
if rowsAffected == 0 {
|
|
insertQuery := `INSERT INTO searchable_content (trackingid, content, full_emb, modified_at) VALUES (?, ?, vector32(?), ?)`
|
|
_, err = m.db.Exec(insertQuery, id, content, serializeEmbeddings(embeddings), time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("error inserting embeddings: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// serializeEmbeddings renders embeddings as the bracketed, comma-separated
// text passed to vector32() in the SQL above, e.g. []float32{1, 0.5} becomes
// "[1, 0.5]".
func serializeEmbeddings(embeddings []float32) string {
	// fmt prints a float32 slice as "[a b c]"; turning every space into ", "
	// yields "[a, b, c]".
	return strings.ReplaceAll(fmt.Sprint(embeddings), " ", ", ")
}
|
|
|
|
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error) {
|
|
// Find the relevant content in the database
|
|
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), ?) JOIN searchable_content ON id = searchable_content.rowid`
|
|
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings), limit)
|
|
if err != nil {
|
|
// norows error is not an error
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("error querying embeddings: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []types.SearchResponse
|
|
for rows.Next() {
|
|
var result types.SearchResponse
|
|
err = rows.Scan(&result.TrackingID, &result.Content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error scanning embeddings: %w", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (m *Mapper) GetContentByID(id string) ([]types.SearchResponse, error) {
|
|
// Get the content by ID
|
|
query := `SELECT trackingid, content FROM searchable_content WHERE trackingid = ?`
|
|
rows, err := m.db.Query(query, id)
|
|
if err != nil {
|
|
// norows error is not an error
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("error querying content by id: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []types.SearchResponse
|
|
for rows.Next() {
|
|
var result types.SearchResponse
|
|
err = rows.Scan(&result.TrackingID, &result.Content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error scanning content by id: %w", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
return results, nil
|
|
}
|