initial commit: create basic rag search and ingest
This commit is contained in:
128
backend/AI/ai.go
Normal file
128
backend/AI/ai.go
Normal file
@ -0,0 +1,128 @@
|
||||
package AI
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"os"
|
||||
)
|
||||
|
||||
// This package should use the OpenAI API to provide AI services.
|
||||
|
||||
type AI interface {
|
||||
// Get Embedding
|
||||
GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error)
|
||||
GetTokenCount(input string) (int, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
encodingName string
|
||||
model string
|
||||
client *openai.Client
|
||||
}
|
||||
|
||||
// AIOption customises a Service while NewAI is constructing it.
type AIOption func(*Service)
|
||||
|
||||
func NewAI(otps ...AIOption) (*Service, error) {
|
||||
a := Service{
|
||||
//baseURL: "https://api.openai.com",
|
||||
encodingName: "gpt-4o",
|
||||
model: openai.GPT4oMini,
|
||||
}
|
||||
|
||||
for _, opt := range otps {
|
||||
opt(&a)
|
||||
}
|
||||
|
||||
if a.apiKey == "" && os.Getenv("OPENAI_API_KEY") != "" {
|
||||
a.apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
if a.apiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
|
||||
config := openai.DefaultConfig(a.apiKey)
|
||||
if a.baseURL == "" && os.Getenv("OPENAI_BASE_URL") != "" {
|
||||
a.baseURL = os.Getenv("OPENAI_BASE_URL")
|
||||
}
|
||||
|
||||
if a.baseURL != "" {
|
||||
config.BaseURL = a.baseURL
|
||||
}
|
||||
|
||||
a.client = openai.NewClientWithConfig(config)
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func (a *Service) GetEmbeddings(ctx context.Context, text string) (openai.EmbeddingResponse, error) {
|
||||
embeddingRequest := openai.EmbeddingRequest{
|
||||
Input: text,
|
||||
Model: "text-embedding-3-small",
|
||||
}
|
||||
|
||||
embeddings, err := a.client.CreateEmbeddings(ctx, embeddingRequest)
|
||||
if err != nil {
|
||||
return openai.EmbeddingResponse{}, fmt.Errorf("error creating embeddings: %w", err)
|
||||
}
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func WithAPIKey(apiKey string) AIOption {
|
||||
return func(a *Service) {
|
||||
a.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
func WithBaseURL(baseURL string) AIOption {
|
||||
return func(a *Service) {
|
||||
a.baseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
func WithEncodingName(encodingName string) AIOption {
|
||||
return func(a *Service) {
|
||||
a.encodingName = encodingName
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Service) GetTokenCount(input string) (int, error) {
|
||||
tke, err := tiktoken.EncodingForModel(a.encodingName) // cached in "TIKTOKEN_CACHE_DIR"
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting encoding: %w", err)
|
||||
}
|
||||
token := tke.Encode(input, nil, nil)
|
||||
return len(token), nil
|
||||
}
|
||||
|
||||
// WIP: I'm still getting responses with limited utility. Probably don't use this function until it is figured out.
|
||||
func (a *Service) PreReason(ctx context.Context, text string) (string, error) {
|
||||
systemPrompt := "When you receive user content, your job is to systematically break down the content you receive. We will store your reasoning and any output you generate to help with future queries. This way we can more quickly respond to the user having already thought about the subject matter. You should consider as many angles as possible and be as thorough as possible in your reasoning, potentially add summaries of parts that seem important to you or interesting. You should also consider the potential for future queries and how you might be able to help with those."
|
||||
client := openai.NewClientWithConfig(openai.DefaultConfig(a.apiKey))
|
||||
preReasoningRequest := openai.ChatCompletionRequest{
|
||||
Model: openai.O4Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleDeveloper,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: text,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := client.CreateChatCompletion(ctx, preReasoningRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating chat completion: %w", err)
|
||||
}
|
||||
respText := resp.Choices[0].Message.Content
|
||||
if respText == "" {
|
||||
return "", fmt.Errorf("no response text")
|
||||
}
|
||||
fmt.Printf("%+v\n", resp.Choices[0])
|
||||
return respText, nil
|
||||
}
|
52
backend/cachedAPI/cachedAPI.go
Normal file
52
backend/cachedAPI/cachedAPI.go
Normal file
@ -0,0 +1,52 @@
|
||||
package cachedAPI
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"generic-rag/backend/datastore"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This package behaves like an API but uses libSQL as a cache that gets checked before the actual API is called.
|
||||
// CachedAPI fetches the content of a URL, consulting a cache before the network.
type CachedAPI interface {
	// Get returns the response body for url. cacheTTL is stored alongside a
	// freshly fetched response.
	Get(url string, cacheTTL time.Duration) (string, error)
}
|
||||
|
||||
type cachedAPI struct {
|
||||
mapper datastore.CacheStore
|
||||
}
|
||||
|
||||
func NewCachedAPI(mapper datastore.CacheStore) CachedAPI {
|
||||
return &cachedAPI{
|
||||
mapper: mapper,
|
||||
}
|
||||
}
|
||||
|
||||
func (c cachedAPI) Get(url string, cacheTTL time.Duration) (string, error) {
|
||||
response, found, err := c.mapper.CachedAPI(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting cached API response: %w", err)
|
||||
}
|
||||
if found {
|
||||
return response, nil
|
||||
}
|
||||
// Call the actual API
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error calling API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// Read the response
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading API response: %w", err)
|
||||
}
|
||||
// Save the response to the cache
|
||||
err = c.mapper.SaveAPIResponse(url, string(bodyBytes), cacheTTL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error saving API response: %w", err)
|
||||
}
|
||||
|
||||
return string(bodyBytes), nil
|
||||
}
|
167
backend/datastore/mapper.go
Normal file
167
backend/datastore/mapper.go
Normal file
@ -0,0 +1,167 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"generic-rag/backend/types"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheStore is the persistence contract for cached API responses.
type CacheStore interface {
	// CachedAPI returns the stored response for url and whether a usable entry was found.
	CachedAPI(url string) (string, bool, error)
	// SaveAPIResponse stores response for url together with its time-to-live.
	SaveAPIResponse(url, response string, cacheTTL time.Duration) error
}
|
||||
|
||||
type SearchStore interface {
|
||||
SaveEmbeddings(id, content string, embeddings []float32) error
|
||||
FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error)
|
||||
GetContentByID(id string) ([]types.SearchResponse, error)
|
||||
}
|
||||
|
||||
// Mapper runs the cache and search queries against a SQL database.
type Mapper struct {
	// db is the database handle every query is issued on.
	db *sql.DB
}
|
||||
|
||||
func NewMapper(db *sql.DB) *Mapper {
|
||||
return &Mapper{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CachedAPI returns the cached API response for the given URL
|
||||
// If the URL is not in the cache it returns an empty string and false
|
||||
func (m *Mapper) CachedAPI(url string) (string, bool, error) {
|
||||
// Check the cache for the URL
|
||||
// If the URL is in the cache, return the cached response
|
||||
// Otherwise, call the API and cache the response
|
||||
|
||||
query := `SELECT response, created_at, ttl FROM cache WHERE url = ?`
|
||||
rows, err := m.db.Query(query, url)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, fmt.Errorf("error reading from cache url: %v | %w", url, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var response struct {
|
||||
Response string
|
||||
CreatedAt time.Time
|
||||
TTL time.Duration
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&response.Response, &response.CreatedAt, &response.TTL)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("error scanning cache response: %w", err)
|
||||
}
|
||||
// Check if the cache is expired
|
||||
if time.Since(response.CreatedAt) > response.TTL {
|
||||
return "", false, nil
|
||||
}
|
||||
return response.Response, true, nil
|
||||
}
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// SaveAPIResponse saves the API response to the cache
|
||||
func (m *Mapper) SaveAPIResponse(url, response string, cacheTTL time.Duration) error {
|
||||
// Insert the response into the cache
|
||||
query := `INSERT INTO cache (url, response, ttl) VALUES (?, ?, ?)`
|
||||
_, err := m.db.Exec(query, url, response, cacheTTL)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed: cache.url") {
|
||||
// Update the existing row if there is a UNIQUE constraint error
|
||||
updateQuery := `UPDATE cache SET response = ?, ttl = ? WHERE url = ?`
|
||||
_, updateErr := m.db.Exec(updateQuery, response, cacheTTL, url)
|
||||
if updateErr != nil {
|
||||
return fmt.Errorf("error updating cache response: %w", updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("error inserting cache response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mapper) SaveEmbeddings(id, content string, embeddings []float32) error {
|
||||
// Try to update the embeddings in the database
|
||||
updateQuery := `UPDATE searchable_content SET content = ?, full_emb = vector32(?), modified_at = ? WHERE trackingid = ?`
|
||||
result, err := m.db.Exec(updateQuery, content, serializeEmbeddings(embeddings), time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Check if any rows were updated
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking rows affected: %w", err)
|
||||
}
|
||||
|
||||
// If no rows were updated, insert the embeddings
|
||||
if rowsAffected == 0 {
|
||||
insertQuery := `INSERT INTO searchable_content (trackingid, content, full_emb, modified_at) VALUES (?, ?, vector32(?), ?)`
|
||||
_, err = m.db.Exec(insertQuery, id, content, serializeEmbeddings(embeddings), time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error inserting embeddings: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serializeEmbeddings renders embeddings as a bracketed, comma-separated list
// such as "[0.1, 0.2, 0.3]". It takes Go's default slice formatting and turns
// each separating space into ", ".
func serializeEmbeddings(embeddings []float32) string {
	spaced := fmt.Sprint(embeddings)
	return strings.ReplaceAll(spaced, " ", ", ")
}
|
||||
|
||||
func (m *Mapper) FindRelevantContent(queryEmbeddings []float32, limit int) ([]types.SearchResponse, error) {
|
||||
// Find the relevant content in the database
|
||||
query := `SELECT searchable_content.trackingid, searchable_content.content FROM vector_top_k('emb_idx', vector32(?), ?) JOIN searchable_content ON id = searchable_content.rowid`
|
||||
rows, err := m.db.Query(query, serializeEmbeddings(queryEmbeddings), limit)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error querying embeddings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []types.SearchResponse
|
||||
for rows.Next() {
|
||||
var result types.SearchResponse
|
||||
err = rows.Scan(&result.TrackingID, &result.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error scanning embeddings: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *Mapper) GetContentByID(id string) ([]types.SearchResponse, error) {
|
||||
// Get the content by ID
|
||||
query := `SELECT trackingid, content FROM searchable_content WHERE trackingid = ?`
|
||||
rows, err := m.db.Query(query, id)
|
||||
if err != nil {
|
||||
// norows error is not an error
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error querying content by id: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []types.SearchResponse
|
||||
for rows.Next() {
|
||||
var result types.SearchResponse
|
||||
err = rows.Scan(&result.TrackingID, &result.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error scanning content by id: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
36
backend/datastore/mapper_test.go
Normal file
36
backend/datastore/mapper_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package datastore
|
||||
|
||||
import "testing"
|
||||
|
||||
func Benchmark_mySerializedEmbeddings(b *testing.B) {
|
||||
type args struct {
|
||||
embeddings []float32
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Test 1",
|
||||
args: args{
|
||||
embeddings: []float32{0.1, 0.2, 0.3},
|
||||
},
|
||||
want: "[0.1, 0.2, 0.3]",
|
||||
},
|
||||
{
|
||||
name: "Crazy long test",
|
||||
args: args{
|
||||
embeddings: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
|
||||
},
|
||||
want: "[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
b.Run(tt.name, func(t *testing.B) {
|
||||
if got := serializeEmbeddings(tt.args.embeddings); got != tt.want {
|
||||
t.Errorf("mySerializedEmbeddings() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
155
backend/main.go
Normal file
155
backend/main.go
Normal file
File diff suppressed because one or more lines are too long
@ -0,0 +1,8 @@
|
||||
-- searchable_content stores each piece of ingested content with its embedding.
CREATE TABLE searchable_content (
    -- Caller-supplied identifier for the content.
    trackingid TEXT NOT NULL,
    -- The text that was embedded.
    content TEXT NOT NULL,
    -- Embedding vector; presumably 1536 float32 values, to be confirmed against the embedding model.
    full_emb F32_BLOB(1536) NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Vector index over full_emb, referenced by name from vector_top_k('emb_idx', ...).
CREATE INDEX emb_idx ON searchable_content (libsql_vector_idx(full_emb));
|
11
backend/migrations/2025-01-03-init.sql
Normal file
11
backend/migrations/2025-01-03-init.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- cache stores one API response per URL.
CREATE TABLE IF NOT EXISTS cache (
    id INTEGER PRIMARY KEY,
    -- The requested URL; UNIQUE so each URL has at most one cached response.
    url TEXT NOT NULL UNIQUE,
    response TEXT NOT NULL,
    -- Compared with ttl to decide whether the entry has expired.
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- Lifetime of the entry, written from a Go time.Duration.
    ttl INTEGER DEFAULT 0
);

-- IF NOT EXISTS keeps the index statements re-runnable, matching the table
-- statement above; without it a second run fails on the existing index.
-- NOTE(review): the UNIQUE constraint on url already provides an index, so
-- idx_url looks redundant; it is kept so the schema stays unchanged.
CREATE INDEX IF NOT EXISTS idx_url ON cache (url);
CREATE INDEX IF NOT EXISTS idx_created_at ON cache (created_at);
|
||||
|
@ -0,0 +1,2 @@
|
||||
-- Add a modification timestamp to searchable_content. There is no default, so
-- the application supplies the value when it writes a row.
ALTER TABLE searchable_content
    ADD COLUMN modified_at TIMESTAMP;
|
82
backend/search/search.go
Normal file
82
backend/search/search.go
Normal file
@ -0,0 +1,82 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"generic-rag/backend/AI"
|
||||
"generic-rag/backend/datastore"
|
||||
"generic-rag/backend/types"
|
||||
)
|
||||
|
||||
type Search interface {
|
||||
Search(query string, limit int) ([]types.SearchResponse, error)
|
||||
InsertContent(ctx context.Context, id string, content string) error
|
||||
GetContentByID(ctx context.Context, id string) ([]types.SearchResponse, error)
|
||||
}
|
||||
|
||||
// SearchOption configures a search instance while NewSearch is constructing it.
type SearchOption func(s *search)
|
||||
|
||||
func NewSearch(opts ...SearchOption) (Search, error) {
|
||||
s := &search{}
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
if s.ai == nil {
|
||||
return nil, fmt.Errorf("AI is required")
|
||||
}
|
||||
if s.mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is required")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func WithMapper(mapper datastore.SearchStore) func(s *search) {
|
||||
return func(s *search) {
|
||||
s.mapper = mapper
|
||||
}
|
||||
}
|
||||
|
||||
func WithAI(ai AI.AI) func(s *search) {
|
||||
return func(s *search) {
|
||||
s.ai = ai
|
||||
}
|
||||
}
|
||||
|
||||
type search struct {
|
||||
ai AI.AI
|
||||
mapper datastore.SearchStore
|
||||
}
|
||||
|
||||
func (s *search) Search(query string, limit int) ([]types.SearchResponse, error) {
|
||||
// get embeddings for the query
|
||||
embeddings, err := s.ai.GetEmbeddings(context.Background(), query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting embeddings: %w", err)
|
||||
}
|
||||
if len(embeddings.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embeddings returned")
|
||||
}
|
||||
// find relevant content in the database
|
||||
return s.mapper.FindRelevantContent(embeddings.Data[0].Embedding, limit)
|
||||
}
|
||||
|
||||
func (s *search) InsertContent(ctx context.Context, id string, content string) error {
|
||||
// get embeddings for the content
|
||||
embeddings, err := s.ai.GetEmbeddings(ctx, content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting embeddings: %w", err)
|
||||
}
|
||||
if len(embeddings.Data) == 0 {
|
||||
return fmt.Errorf("no embeddings returned")
|
||||
}
|
||||
// save the embeddings to the database
|
||||
err = s.mapper.SaveEmbeddings(id, content, embeddings.Data[0].Embedding)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error saving embeddings: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *search) GetContentByID(ctx context.Context, id string) ([]types.SearchResponse, error) {
|
||||
return s.mapper.GetContentByID(id)
|
||||
}
|
6
backend/types/search.go
Normal file
6
backend/types/search.go
Normal file
@ -0,0 +1,6 @@
|
||||
package types
|
||||
|
||||
// SearchResponse is a single piece of stored content returned by a lookup.
type SearchResponse struct {
	// TrackingID is the identifier the content was stored under.
	TrackingID string
	// Content is the stored text.
	Content string
}
|
Reference in New Issue
Block a user