initial commit: create basic rag search and ingest
This commit is contained in:
167
backend/datastore/mapper.go
Normal file
167
backend/datastore/mapper.go
Normal file
@ -0,0 +1,167 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"generic-rag/backend/types"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheStore is the persistence contract for cached API responses keyed by
// URL. Mapper in this package provides an implementation backed by the
// `cache` table.
type CacheStore interface {
	// CachedAPI looks up the response cached for url. The boolean result
	// reports whether a usable entry was found; the string is the cached
	// response when it was.
	CachedAPI(url string) (string, bool, error)

	// SaveAPIResponse stores response for url together with the time-to-live
	// that the entry should be considered valid for.
	SaveAPIResponse(url string, response string, cacheTTL time.Duration) error
}
|
||||
|
||||
type SearchStore interface {
|
||||
SaveEmbeddings(id, content string, embeddings []float32) error
|
||||
FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error)
|
||||
GetContentByID(id string) ([]types.SearchResponse, error)
|
||||
}
|
||||
|
||||
// Mapper is the SQL-backed data-access object of this package. Its methods
// read and write the `cache` and `searchable_content` tables through the
// wrapped database handle.
type Mapper struct {
	// db is the handle every query in this package is executed on.
	db *sql.DB
}
|
||||
|
||||
func NewMapper(db *sql.DB) *Mapper {
|
||||
return &Mapper{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CachedAPI returns the cached API response for the given URL
|
||||
// If the URL is not in the cache it returns an empty string and false
|
||||
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
|
||||
// Check the cache for the URL
|
||||
// If the URL is in the cache, return the cached response
|
||||
// Otherwise, call the API and cache the response
|
||||
|
||||
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
|
||||
rows, err := m.db.Query(query, url)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var response struct {
|
||||
Response string
|
||||
CreatedAt time.Time
|
||||
TTL time.Duration
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("error scanning cache response: %w", err)
|
||||
}
|
||||
// Check if the cache is expired
|
||||
if time.Since(response.CreatedAt) > response.TTL {
|
||||
return "", false, nil
|
||||
}
|
||||
return response.Response, true, nil
|
||||
}
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// SaveAPIResponse saves the API response to the cache
|
||||
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
|
||||
// Insert the response into the cache
|
||||
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
|
||||
_, err := m.db.Exec(query, url, response, cacheTTL)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
|
||||
// Update the existing row if there is a UNIQUE constraint error
|
||||
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
|
||||
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
|
||||
if updateErr != nil {
|
||||
return fmt.Errorf("error updating cache response: %w", updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("error inserting cache response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
|
||||
// Try to update the embeddings in the database
|
||||
updateQuery := `UPDATE searchable_content SET content = ?, full_emb = vector32(?), modified_at = ? WHERE trackingid = ?`
|
||||
result, err := m.db.Exec(updateQuery, content, serializeEmbeddings(embeddings), time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Check if any rows were updated
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking rows affected: %w", err)
|
||||
}
|
||||
|
||||
// If no rows were updated, insert the embeddings
|
||||
if rowsAffected == 0 {
|
||||
insertQuery := `INSERT INTO searchable_content (trackingid, content, full_emb, modified_at) VALUES (?, ?, vector32(?), ?)`
|
||||
_, err = m.db.Exec(insertQuery, id, content, serializeEmbeddings(embeddings), time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error inserting embeddings: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serializeEmbeddings renders embeddings as a bracketed, comma-separated list
// such as "[1, 2.5, -3]"; a nil or empty slice yields "[]". The callers in
// this file pass the result as the argument of vector32(?).
func serializeEmbeddings(embeddings []float32) string {
	// fmt prints a slice as "[a b c]"; turning each space into ", " gives the
	// comma-separated form.
	spaced := fmt.Sprint(embeddings)
	return strings.ReplaceAll(spaced, " ", ", ")
}
|
||||
|
||||
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error) {
|
||||
// Find the relevant content in the database
|
||||
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), ?) JOIN searchable_content ON id = searchable_content.rowid`
|
||||
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings), limit)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error querying embeddings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []types.SearchResponse
|
||||
for rows.Next() {
|
||||
var result types.SearchResponse
|
||||
err = rows.Scan(&result.TrackingID, &result.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error scanning embeddings: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *Mapper) GetContentByID(id string) ([]types.SearchResponse, error) {
|
||||
// Get the content by ID
|
||||
query := `SELECT trackingid, content FROM searchable_content WHERE trackingid = ?`
|
||||
rows, err := m.db.Query(query, id)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error querying content by id: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []types.SearchResponse
|
||||
for rows.Next() {
|
||||
var result types.SearchResponse
|
||||
err = rows.Scan(&result.TrackingID, &result.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error scanning content by id: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
Reference in New Issue
Block a user