diff --git a/go.mod b/go.mod index 6c1fc5c..a7870fa 100644 --- a/go.mod +++ b/go.mod @@ -17,4 +17,5 @@ require ( modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect modernc.org/sqlite v1.48.0 // indirect + somegit.dev/vikingowl/mistral-go-sdk v1.2.0 // indirect ) diff --git a/go.sum b/go.sum index f0eae39..7410652 100644 --- a/go.sum +++ b/go.sum @@ -31,3 +31,5 @@ modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4= modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +somegit.dev/vikingowl/mistral-go-sdk v1.2.0 h1:9NEGCKzw1Bu2c8LaSEKNlpj08iMsU0fkDFJO6W1Zh+Y= +somegit.dev/vikingowl/mistral-go-sdk v1.2.0/go.mod h1:pN7nQdOIYYEMRdwye5cSfymtwhZJHd+caK6J69Z4XMY= diff --git a/internal/llm/mistral.go b/internal/llm/mistral.go new file mode 100644 index 0000000..b1bb901 --- /dev/null +++ b/internal/llm/mistral.go @@ -0,0 +1,98 @@ +package llm + +import ( + "context" + "fmt" + "strconv" + "strings" + + mistral "somegit.dev/vikingowl/mistral-go-sdk" + "somegit.dev/vikingowl/mistral-go-sdk/chat" + + "somegit.dev/vikingowl/reddit-reader/internal/domain" +) + +// MistralClient wraps the Mistral AI SDK and implements the Summarizer interface. +type MistralClient struct { + client *mistral.Client + model string +} + +// MistralOption configures a MistralClient. +type MistralOption func(*mistralOpts) + +type mistralOpts struct { + baseURL string +} + +// WithMistralBaseURL overrides the Mistral API base URL (useful for tests). +func WithMistralBaseURL(url string) MistralOption { + return func(o *mistralOpts) { o.baseURL = url } +} + +// NewMistralClient creates a MistralClient using the given API key and model. +func NewMistralClient(apiKey, model string, opts ...MistralOption) *MistralClient { + var mo mistralOpts + for _, o := range opts { + o(&mo) + } + + var clientOpts []mistral.Option + if mo.baseURL != "" { + clientOpts = append(clientOpts, mistral.WithBaseURL(mo.baseURL)) + } + + return &MistralClient{ + client: mistral.NewClient(apiKey, clientOpts...), + model: model, + } +} + +// Score returns a relevance score in [0.0, 1.0] for the given post against the user's interests. +func (m *MistralClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) { + systemPrompt := buildScorePrompt(interests) + userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, truncate(post.SelfText, 500)) + + resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{ + Model: m.model, + Messages: []chat.Message{ + &chat.SystemMessage{Content: chat.TextContent(systemPrompt)}, + &chat.UserMessage{Content: chat.TextContent(userPrompt)}, + }, + }) + if err != nil { + return 0, fmt.Errorf("mistral score: %w", err) + } + if len(resp.Choices) == 0 { + return 0, fmt.Errorf("mistral score: no choices") + } + + text := strings.TrimSpace(resp.Choices[0].Message.Content.String()) + score, err := strconv.ParseFloat(text, 64) + if err != nil { + return 0, fmt.Errorf("mistral parse score %q: %w", text, err) + } + return score, nil +} + +// Summarize produces a 5-bullet summary of the given post. +func (m *MistralClient) Summarize(ctx context.Context, post domain.Post) (string, error) { + systemPrompt := "You are a concise summarizer. Given a Reddit post, produce exactly 5 bullet points summarizing the key information. Each bullet starts with '- '. No other text." + userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, post.SelfText) + + resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{ + Model: m.model, + Messages: []chat.Message{ + &chat.SystemMessage{Content: chat.TextContent(systemPrompt)}, + &chat.UserMessage{Content: chat.TextContent(userPrompt)}, + }, + }) + if err != nil { + return "", fmt.Errorf("mistral summarize: %w", err) + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("mistral summarize: no choices") + } + + return strings.TrimSpace(resp.Choices[0].Message.Content.String()), nil +} diff --git a/internal/llm/mistral_test.go b/internal/llm/mistral_test.go new file mode 100644 index 0000000..56dc5ad --- /dev/null +++ b/internal/llm/mistral_test.go @@ -0,0 +1,57 @@ +package llm_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/reddit-reader/internal/domain" + "somegit.dev/vikingowl/reddit-reader/internal/llm" +) + +func TestMistralScore(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890, + "choices": []map[string]any{ + {"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": "0.72"}}, + }, + "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }) + })) + defer srv.Close() + + client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL)) + score, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{Description: "Go"}) + if err != nil { + t.Fatalf("Score: %v", err) + } + if score != 0.72 { + t.Errorf("score = %f, want 0.72", score) + } +} + +func TestMistralSummarize(t *testing.T) { + want := "- bullet one\n- bullet two\n- bullet three\n- bullet four\n- bullet five" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890, + "choices": []map[string]any{ + {"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": want}}, + }, + "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }) + })) + defer srv.Close() + + client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL)) + got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "content"}) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + if got != want { + t.Errorf("summary = %q, want %q", got, want) + } +}