feat(llm): Mistral SDK backend
This commit is contained in:
1
go.mod
1
go.mod
@@ -17,4 +17,5 @@ require (
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.48.0 // indirect
|
||||
somegit.dev/vikingowl/mistral-go-sdk v1.2.0 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -31,3 +31,5 @@ modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
|
||||
modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
somegit.dev/vikingowl/mistral-go-sdk v1.2.0 h1:9NEGCKzw1Bu2c8LaSEKNlpj08iMsU0fkDFJO6W1Zh+Y=
|
||||
somegit.dev/vikingowl/mistral-go-sdk v1.2.0/go.mod h1:pN7nQdOIYYEMRdwye5cSfymtwhZJHd+caK6J69Z4XMY=
|
||||
|
||||
98
internal/llm/mistral.go
Normal file
98
internal/llm/mistral.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
mistral "somegit.dev/vikingowl/mistral-go-sdk"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
||||
)
|
||||
|
||||
// MistralClient wraps the Mistral AI SDK and implements the Summarizer interface.
|
||||
type MistralClient struct {
|
||||
client *mistral.Client
|
||||
model string
|
||||
}
|
||||
|
||||
// MistralOption configures a MistralClient.
|
||||
type MistralOption func(*mistralOpts)
|
||||
|
||||
type mistralOpts struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// WithMistralBaseURL overrides the Mistral API base URL (useful for tests).
|
||||
func WithMistralBaseURL(url string) MistralOption {
|
||||
return func(o *mistralOpts) { o.baseURL = url }
|
||||
}
|
||||
|
||||
// NewMistralClient creates a MistralClient using the given API key and model.
|
||||
func NewMistralClient(apiKey, model string, opts ...MistralOption) *MistralClient {
|
||||
var mo mistralOpts
|
||||
for _, o := range opts {
|
||||
o(&mo)
|
||||
}
|
||||
|
||||
var clientOpts []mistral.Option
|
||||
if mo.baseURL != "" {
|
||||
clientOpts = append(clientOpts, mistral.WithBaseURL(mo.baseURL))
|
||||
}
|
||||
|
||||
return &MistralClient{
|
||||
client: mistral.NewClient(apiKey, clientOpts...),
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
// Score returns a relevance score in [0.0, 1.0] for the given post against the user's interests.
|
||||
func (m *MistralClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) {
|
||||
systemPrompt := buildScorePrompt(interests)
|
||||
userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, truncate(post.SelfText, 500))
|
||||
|
||||
resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
|
||||
Model: m.model,
|
||||
Messages: []chat.Message{
|
||||
&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
|
||||
&chat.UserMessage{Content: chat.TextContent(userPrompt)},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("mistral score: %w", err)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return 0, fmt.Errorf("mistral score: no choices")
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Choices[0].Message.Content.String())
|
||||
score, err := strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("mistral parse score %q: %w", text, err)
|
||||
}
|
||||
return score, nil
|
||||
}
|
||||
|
||||
// Summarize produces a 5-bullet summary of the given post.
|
||||
func (m *MistralClient) Summarize(ctx context.Context, post domain.Post) (string, error) {
|
||||
systemPrompt := "You are a concise summarizer. Given a Reddit post, produce exactly 5 bullet points summarizing the key information. Each bullet starts with '- '. No other text."
|
||||
userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, post.SelfText)
|
||||
|
||||
resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
|
||||
Model: m.model,
|
||||
Messages: []chat.Message{
|
||||
&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
|
||||
&chat.UserMessage{Content: chat.TextContent(userPrompt)},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("mistral summarize: %w", err)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("mistral summarize: no choices")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(resp.Choices[0].Message.Content.String()), nil
|
||||
}
|
||||
57
internal/llm/mistral_test.go
Normal file
57
internal/llm/mistral_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package llm_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/llm"
|
||||
)
|
||||
|
||||
func TestMistralScore(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": "0.72"}},
|
||||
},
|
||||
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
|
||||
score, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{Description: "Go"})
|
||||
if err != nil {
|
||||
t.Fatalf("Score: %v", err)
|
||||
}
|
||||
if score != 0.72 {
|
||||
t.Errorf("score = %f, want 0.72", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMistralSummarize(t *testing.T) {
|
||||
want := "- bullet one\n- bullet two\n- bullet three\n- bullet four\n- bullet five"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": want}},
|
||||
},
|
||||
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
|
||||
got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "content"})
|
||||
if err != nil {
|
||||
t.Fatalf("Summarize: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("summary = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user