initial commit: create basic rag search and ingest

This commit is contained in:
2025-05-15 02:10:14 -06:00
commit 0d072032d9
16 changed files with 744 additions and 0 deletions

128
backend/AI/ai.go Normal file
View File

@ -0,0 +1,128 @@
package AI
import (
"context"
"fmt"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
"os"
)
// This package should use the OpenAI API to provide AI services.
type AI interface {
// Get Embedding
GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error)
GetTokenCount(input string) (int, error)
}
type Service struct {
apiKey string
baseURL string
encodingName string
model string
client *openai.Client
}
type AIOption func(*Service)
func NewAI(otps ...AIOption) (*Service, error) {
a := Service{
//baseURL: "https://api.openai.com",
encodingName: "gpt-4o",
model: openai.GPT4oMini,
}
for _, opt := range otps {
opt(&a)
}
if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" {
a.apiKey = os.Getenv("OPENAI_API_KEY")
}
if a.apiKey == "" {
return nil, fmt.Errorf("api key is required")
}
config := openai.DefaultConfig(a.apiKey)
if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" {
a.baseURL = os.Getenv("OPENAI_BASE_URL")
}
if a.baseURL != "" {
config.BaseURL = a.baseURL
}
a.client = openai.NewClientWithConfig(config)
return &a, nil
}
func (a *Service) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) {
embeddingRequest := openai.EmbeddingRequest{
Input: text,
Model: "text-embedding-3-small",
}
embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest)
if err != nil {
return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err)
}
return embeddings, nil
}
func WithAPIKey(apiKey string) AIOption {
return func(a *Service) {
a.apiKey = apiKey
}
}
func WithBaseURL(baseURL string) AIOption {
return func(a *Service) {
a.baseURL = baseURL
}
}
func WithEncodingName(encodingName string) AIOption {
return func(a *Service) {
a.encodingName = encodingName
}
}
func (a *Service) GetTokenCount(input string) (int, error) {
tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR"
if err != nil {
return 0, fmt.Errorf("error getting encoding: %w", err)
}
token := tke.Encode(input, nil, nil)
return len(token), nil
}
// WIP: I'm still getting responses with limited utility. Probably don't use this function until it is figured out.
func (a *Service) PreReason(ctx context.Context, text string) (string, error) {
systemPrompt := "When you receive user content, your job is to systematically break down the content you receive. We will store your reasoning and any output you generate to help with future queries. This way we can more quickly respond to the user having already thought about the subject matter. You should consider as many angles as possible and be as thorough as possible in your reasoning, potentially add summaries of parts that seem important to you or interesting. You should also consider the potential for future queries and how you might be able to help with those."
client := openai.NewClientWithConfig(openai.DefaultConfig(a.apiKey))
preReasoningRequest := openai.ChatCompletionRequest{
Model: openai.O4Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleDeveloper,
Content: systemPrompt,
},
{
Role: openai.ChatMessageRoleUser,
Content: text,
},
},
}
resp, err := client.CreateChatCompletion(ctx, preReasoningRequest)
if err != nil {
return "", fmt.Errorf("error creating chat completion: %w", err)
}
respText := resp.Choices[0].Message.Content
if respText == "" {
return "", fmt.Errorf("no response text")
}
fmt.Printf("%+v\n", resp.Choices[0])
return respText, nil
}

View File

@ -0,0 +1,52 @@
package cachedAPI
import (
"fmt"
"generic-rag/backend/datastore"
"io"
"net/http"
"time"
)
// This package behaves like an API but uses libSQL as a cache that gets checked before the actual API is called.
type CachedAPI interface {
Get(url string, cacheTTL time.Duration) (string, error)
}
type cachedAPI struct {
mapper datastore.CacheStore
}
func NewCachedAPI(mapper datastore.CacheStore) CachedAPI {
return &cachedAPI{
mapper: mapper,
}
}
func (c cachedAPI) Get(url string, cacheTTL time.Duration) (string, error) {
response, found, err := c.mapper.CachedAPI(url)
if err != nil {
return "", fmt.Errorf("error getting cached API response: %w", err)
}
if found {
return response, nil
}
// Call the actual API
resp, err := http.Get(url)
if err != nil {
return "", fmt.Errorf("error calling API: %w", err)
}
defer resp.Body.Close()
// Read the response
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("error reading API response: %w", err)
}
// Save the response to the cache
err = c.mapper.SaveAPIResponse(url, string(bodyBytes), cacheTTL)
if err != nil {
return "", fmt.Errorf("error saving API response: %w", err)
}
return string(bodyBytes), nil
}

167
backend/datastore/mapper.go Normal file
View File

@ -0,0 +1,167 @@
package datastore
import (
"database/sql"
"fmt"
"generic-rag/backend/types"
"strings"
"time"
)
type CacheStore interface {
CachedAPI(url string) (string, bool, error)
SaveAPIResponse(url, response string, cacheTTL time.Duration) error
}
type SearchStore interface {
SaveEmbeddings(id, content string, embeddings []float32) error
FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error)
GetContentByID(id string) ([]types.SearchResponse, error)
}
type Mapper struct {
db *sql.DB
}
func NewMapper(db *sql.DB) *Mapper {
return &Mapper{
db: db,
}
}
// CachedAPI returns the cached API response for the given URL
// If the URL is not in the cache it returns an empty string and false
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
// Check the cache for the URL
// If the URL is in the cache, return the cached response
// Otherwise, call the API and cache the response
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
rows, err := m.db.Query(query, url)
if err != nil {
// norows error is not an error
if err == sql.ErrNoRows {
return "", false, nil
}
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
}
defer rows.Close()
var response struct {
Response string
CreatedAt time.Time
TTL time.Duration
}
for rows.Next() {
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
if err != nil {
return "", false, fmt.Errorf("error scanning cache response: %w", err)
}
// Check if the cache is expired
if time.Since(response.CreatedAt) > response.TTL {
return "", false, nil
}
return response.Response, true, nil
}
return "", false, nil
}
// SaveAPIResponse saves the API response to the cache
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
// Insert the response into the cache
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
_, err := m.db.Exec(query, url, response, cacheTTL)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
// Update the existing row if there is a UNIQUE constraint error
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
if updateErr != nil {
return fmt.Errorf("error updating cache response: %w", updateErr)
}
return nil
}
return fmt.Errorf("error inserting cache response: %w", err)
}
return nil
}
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
// Try to update the embeddings in the database
updateQuery := `UPDATE searchable_content SET content = ?, full_emb = vector32(?), modified_at = ? WHERE trackingid = ?`
result, err := m.db.Exec(updateQuery, content, serializeEmbeddings(embeddings), time.Now(), id)
if err != nil {
return fmt.Errorf("error updating embeddings: %w", err)
}
// Check if any rows were updated
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("error checking rows affected: %w", err)
}
// If no rows were updated, insert the embeddings
if rowsAffected == 0 {
insertQuery := `INSERT INTO searchable_content (trackingid, content, full_emb, modified_at) VALUES (?, ?, vector32(?), ?)`
_, err = m.db.Exec(insertQuery, id, content, serializeEmbeddings(embeddings), time.Now())
if err != nil {
return fmt.Errorf("error inserting embeddings: %w", err)
}
}
return nil
}
func serializeEmbeddings(embeddings []float32) string {
return strings.Join(strings.Split(fmt.Sprintf("%v", embeddings), " "), ", ")
}
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error) {
// Find the relevant content in the database
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), ?) JOIN searchable_content ON id = searchable_content.rowid`
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings), limit)
if err != nil {
// norows error is not an error
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("error querying embeddings: %w", err)
}
defer rows.Close()
var results []types.SearchResponse
for rows.Next() {
var result types.SearchResponse
err = rows.Scan(&result.TrackingID, &result.Content)
if err != nil {
return nil, fmt.Errorf("error scanning embeddings: %w", err)
}
results = append(results, result)
}
return results, nil
}
func (m *Mapper) GetContentByID(id string) ([]types.SearchResponse, error) {
// Get the content by ID
query := `SELECT trackingid, content FROM searchable_content WHERE trackingid = ?`
rows, err := m.db.Query(query, id)
if err != nil {
// norows error is not an error
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("error querying content by id: %w", err)
}
defer rows.Close()
var results []types.SearchResponse
for rows.Next() {
var result types.SearchResponse
err = rows.Scan(&result.TrackingID, &result.Content)
if err != nil {
return nil, fmt.Errorf("error scanning content by id: %w", err)
}
results = append(results, result)
}
return results, nil
}

View File

@ -0,0 +1,36 @@
package datastore
import "testing"
func Benchmark_mySerializedEmbeddings(b *testing.B) {
type args struct {
embeddings []float32
}
tests := []struct {
name string
args args
want string
}{
{
name: "Test 1",
args: args{
embeddings: []float32{0.1, 0.2, 0.3},
},
want: "[0.1, 0.2, 0.3]",
},
{
name: "Crazy long test",
args: args{
embeddings: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
},
want: "[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]",
},
}
for _, tt := range tests {
b.Run(tt.name, func(t *testing.B) {
if got := serializeEmbeddings(tt.args.embeddings); got != tt.want {
t.Errorf("mySerializedEmbeddings() = %v, want %v", got, tt.want)
}
})
}
}

155
backend/main.go Normal file

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,8 @@
CREATE TABLE searchable_content (
trackingid TEXT NOT NULL,
content TEXT NOT NULL,
full_emb F32_BLOB(1536) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX emb_idx ON searchable_content (libsql_vector_idx(full_emb));

View File

@ -0,0 +1,11 @@
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY,
url TEXT NOT NULL UNIQUE,
response TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ttl INTEGER DEFAULT 0
);
CREATE INDEX idx_url ON cache (url);
CREATE INDEX idx_created_at ON cache (created_at);

View File

@ -0,0 +1,2 @@
-- Adds a modified date to the searchable_content table
ALTER TABLE searchable_content ADD COLUMN modified_at TIMESTAMP;

82
backend/search/search.go Normal file
View File

@ -0,0 +1,82 @@
package search
import (
"context"
"fmt"
"generic-rag/backend/AI"
"generic-rag/backend/datastore"
"generic-rag/backend/types"
)
type Search interface {
Search(query string, limit int) ([]types.SearchResponse, error)
InsertContent(ctx context.Context, id string, content string) error
GetContentByID(ctx context.Context, id string) ([]types.SearchResponse, error)
}
type SearchOption func(s *search)
func NewSearch(opts ...SearchOption) (Search, error) {
s := &search{}
for _, opt := range opts {
opt(s)
}
if s.ai == nil {
return nil, fmt.Errorf("AI is required")
}
if s.mapper == nil {
return nil, fmt.Errorf("mapper is required")
}
return s, nil
}
func WithMapper(mapper datastore.SearchStore) func(s *search) {
return func(s *search) {
s.mapper = mapper
}
}
func WithAI(ai AI.AI) func(s *search) {
return func(s *search) {
s.ai = ai
}
}
type search struct {
ai AI.AI
mapper datastore.SearchStore
}
func (s *search) Search(query string, limit int) ([]types.SearchResponse, error) {
// get embeddings for the query
embeddings, err := s.ai.GetEmbeddings(context.Background(), query)
if err != nil {
return nil, fmt.Errorf("error getting embeddings: %w", err)
}
if len(embeddings.Data) == 0 {
return nil, fmt.Errorf("no embeddings returned")
}
// find relevant content in the database
return s.mapper.FindRelevantContent(embeddings.Data[0].Embedding, limit)
}
func (s *search) InsertContent(ctx context.Context, id string, content string) error {
// get embeddings for the content
embeddings, err := s.ai.GetEmbeddings(ctx, content)
if err != nil {
return fmt.Errorf("error getting embeddings: %w", err)
}
if len(embeddings.Data) == 0 {
return fmt.Errorf("no embeddings returned")
}
// save the embeddings to the database
err = s.mapper.SaveEmbeddings(id, content, embeddings.Data[0].Embedding)
if err != nil {
return fmt.Errorf("error saving embeddings: %w", err)
}
return nil
}
func (s *search) GetContentByID(ctx context.Context, id string) ([]types.SearchResponse, error) {
return s.mapper.GetContentByID(id)
}

6
backend/types/search.go Normal file
View File

@ -0,0 +1,6 @@
package types
type SearchResponse struct {
TrackingID string
Content string
}