Files
generic-rag/backend/datastore/mapper.go

168 lines
5.1 KiB
Go

package datastore
import (
"database/sql"
"fmt"
"generic-rag/backend/types"
"strings"
"time"
)
// CacheStore persists API responses keyed by request URL, each stored with
// a time-to-live that governs how long it is returned from the cache.
type CacheStore interface {
	// SaveAPIResponse records response for url together with cacheTTL.
	SaveAPIResponse(url, response string, cacheTTL time.Duration) error
	// CachedAPI looks up the stored response for url. The boolean reports
	// whether a usable entry was found; in Mapper's implementation a missing
	// or expired entry yields false with a nil error.
	CachedAPI(url string) (string, bool, error)
}
// SearchStore stores text content alongside its embedding vector and looks
// content up either by embedding similarity or by tracking ID.
type SearchStore interface {
	// GetContentByID returns the stored content whose tracking ID equals id.
	GetContentByID(id string) ([]types.SearchResponse, error)
	// FindRelevantContent returns at most limit stored entries selected by
	// comparing stored embeddings against queryEmbeddings.
	FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error)
	// SaveEmbeddings creates or replaces the content and embedding kept
	// under id.
	SaveEmbeddings(id, content string, embeddings []float32) error
}
// Mapper is the SQL-backed implementation of CacheStore and SearchStore.
// Its only state is the database handle.
type Mapper struct {
	db *sql.DB // handle used by every query the Mapper issues
}

// NewMapper wraps db in a Mapper. The handle is stored as given; Mapper
// neither opens nor closes it.
func NewMapper(db *sql.DB) *Mapper {
	m := new(Mapper)
	m.db = db
	return m
}
// CachedAPI returns the cached API response for the given URL.
//
// The boolean result is true only when a cache row for url exists and its
// age (time since created_at) does not exceed its stored ttl. A missing or
// expired entry yields ("", false, nil); a query or scan failure yields
// ("", false, err). CachedAPI only reads the cache: fetching from the API
// and storing the result via SaveAPIResponse is left to the caller.
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
	const query = `SELECT response, created_at, ttl FROM cache WHERE url = ?`

	var (
		response  string
		createdAt time.Time
		ttl       time.Duration
	)
	// QueryRow reports an absent row as sql.ErrNoRows from Scan (db.Query
	// never does), and Scan also surfaces any error from running the query.
	err := m.db.QueryRow(query, url).Scan(&response, &createdAt, &ttl)
	if err == sql.ErrNoRows {
		// A cache miss is not an error.
		return "", false, nil
	}
	if err != nil {
		return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
	}

	// Entries older than their TTL are reported as misses so the caller
	// refreshes them.
	if time.Since(createdAt) > ttl {
		return "", false, nil
	}
	return response, true, nil
}
// SaveAPIResponse stores response in the cache under url with the given TTL.
//
// It first tries to insert a new row. If url already has a row (detected by
// the "UNIQUE constraint failed: cache.url" error text), it updates that row
// in place instead. The update also resets created_at, so the refreshed
// entry's TTL is measured from the time of this save. Without that reset, an
// entry that expired once would keep its original timestamp, and CachedAPI
// would keep reporting it as expired no matter how often it is re-saved.
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
	const insertQuery = `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
	_, err := m.db.Exec(insertQuery, url, response, cacheTTL)
	if err == nil {
		return nil
	}
	// NOTE(review): matching on driver error text is SQLite-specific; an
	// INSERT ... ON CONFLICT(url) DO UPDATE upsert would avoid it.
	if !strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
		return fmt.Errorf("error inserting cache response: %w", err)
	}

	// NOTE(review): assumes created_at is populated by a DEFAULT
	// CURRENT_TIMESTAMP on insert (the INSERT above never sets it) — confirm
	// against the schema so both paths use the same timestamp format.
	const updateQuery = `UPDATE cache SET response = ?, ttl = ?, created_at = CURRENT_TIMESTAMP WHERE url = ?`
	if _, err := m.db.Exec(updateQuery, response, cacheTTL, url); err != nil {
		return fmt.Errorf("error updating cache response: %w", err)
	}
	return nil
}
// SaveEmbeddings stores content and its embedding vector under the tracking
// ID id. It updates the existing searchable_content row when there is one and
// inserts a new row otherwise.
//
// The vector is bound as the text produced by serializeEmbeddings and passed
// through the vector32() SQL function; modified_at is set to the current time.
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
	// Serialize and timestamp once. The fallback insert then binds exactly
	// the same values as the update without redoing the work.
	vector := serializeEmbeddings(embeddings)
	now := time.Now()

	const updateQuery = `UPDATE searchable_content SET content = ?, full_emb = vector32(?), modified_at = ? WHERE trackingid = ?`
	result, err := m.db.Exec(updateQuery, content, vector, now, id)
	if err != nil {
		return fmt.Errorf("error updating embeddings: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("error checking rows affected: %w", err)
	}
	if rowsAffected != 0 {
		return nil
	}

	// No row with this tracking ID yet, so create one.
	// NOTE(review): update-then-insert is not atomic; two concurrent callers
	// saving the same new id could both reach this insert.
	const insertQuery = `INSERT INTO searchable_content (trackingid, content, full_emb, modified_at) VALUES (?, ?, vector32(?), ?)`
	if _, err := m.db.Exec(insertQuery, id, content, vector, now); err != nil {
		return fmt.Errorf("error inserting embeddings: %w", err)
	}
	return nil
}
// serializeEmbeddings renders embeddings as a bracketed, comma-separated list
// such as "[0.1, -2, 3.5]". This is the text form passed to the vector32()
// SQL function. Each element uses fmt's default (%v) float32 formatting, and
// a nil or empty slice yields "[]".
func serializeEmbeddings(embeddings []float32) string {
	var sb strings.Builder
	sb.WriteByte('[')
	for i, v := range embeddings {
		if i > 0 {
			sb.WriteString(", ")
		}
		// Writing to a strings.Builder cannot fail, so the result is ignored.
		fmt.Fprint(&sb, v)
	}
	sb.WriteByte(']')
	return sb.String()
}
// FindRelevantContent returns up to limit searchable_content entries, as
// (trackingid, content) pairs, selected by querying the 'emb_idx' vector
// index with queryEmbeddings. No particular result order is imposed by the
// SQL. It returns a nil slice and nil error when nothing matches.
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error) {
	// vector_top_k yields an id column that is joined against
	// searchable_content.rowid to recover the stored content.
	const query = `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), ?) JOIN searchable_content ON id = searchable_content.rowid`
	rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings), limit)
	if err != nil {
		// db.Query reports an empty result as zero rows, never as
		// sql.ErrNoRows, so any error here is a real failure.
		return nil, fmt.Errorf("error querying embeddings: %w", err)
	}
	defer rows.Close()

	var results []types.SearchResponse
	for rows.Next() {
		var result types.SearchResponse
		if err := rows.Scan(&result.TrackingID, &result.Content); err != nil {
			return nil, fmt.Errorf("error scanning embeddings: %w", err)
		}
		results = append(results, result)
	}
	// rows.Next returns false on failure as well as at the end. Without this
	// check, a mid-iteration error would be returned as a short, successful
	// result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating embeddings: %w", err)
	}
	return results, nil
}
// GetContentByID returns every searchable_content row whose trackingid
// equals id, as (trackingid, content) pairs. It returns a nil slice and nil
// error when no row matches.
func (m *Mapper) GetContentByID(id string) ([]types.SearchResponse, error) {
	const query = `SELECT trackingid, content FROM searchable_content WHERE trackingid = ?`
	rows, err := m.db.Query(query, id)
	if err != nil {
		// db.Query reports an empty result as zero rows, never as
		// sql.ErrNoRows, so any error here is a real failure.
		return nil, fmt.Errorf("error querying content by id: %w", err)
	}
	defer rows.Close()

	var results []types.SearchResponse
	for rows.Next() {
		var result types.SearchResponse
		if err := rows.Scan(&result.TrackingID, &result.Content); err != nil {
			return nil, fmt.Errorf("error scanning content by id: %w", err)
		}
		results = append(results, result)
	}
	// Surface errors that ended iteration early instead of returning a
	// truncated result as success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating content by id: %w", err)
	}
	return results, nil
}