diff --git a/internal/cli/root.go b/internal/cli/root.go index 52d20b7..9fed4fb 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -65,7 +65,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&dbPath, "db", "", "path to SQLite database") rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "enable verbose output") rootCmd.PersistentFlags().StringVar(&profileName, "profile", "", "profile name to use") - rootCmd.PersistentFlags().StringVar(&llmFlag, "llm", "", "LLM provider (anthropic, openai, ollama, none)") + rootCmd.PersistentFlags().StringVar(&llmFlag, "llm", "", "LLM provider (anthropic, openai, gemini, ollama, none)") } // Execute runs the root command. @@ -105,6 +105,13 @@ func getLLMProvider() llm.Provider { return llm.NewNoop() } return llm.NewOpenAI(key, cfg.LLM.Model, nil) + case "gemini": + key := os.Getenv("GEMINI_API_KEY") + if key == "" { + fmt.Fprintln(os.Stderr, "Warning: GEMINI_API_KEY not set, LLM features disabled") + return llm.NewNoop() + } + return llm.NewGemini(key, cfg.LLM.Model, nil) case "ollama": return llm.NewOllama(cfg.LLM.Model, cfg.LLM.Endpoint, nil) default: diff --git a/internal/llm/gemini.go b/internal/llm/gemini.go new file mode 100644 index 0000000..8b5cb1d --- /dev/null +++ b/internal/llm/gemini.go @@ -0,0 +1,109 @@ +package llm + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +// Gemini implements Provider using the Google Gemini API. +type Gemini struct { + apiKey string + model string + client *http.Client + baseURL string +} + +// NewGemini creates a new Gemini provider. +func NewGemini(apiKey, model string, client *http.Client) *Gemini { + if client == nil { + client = &http.Client{Timeout: 60 * time.Second} + } + if model == "" { + model = "gemini-2.0-flash" + } + return &Gemini{apiKey: apiKey, model: model, client: client, baseURL: "https://generativelanguage.googleapis.com"} +} + +func (g *Gemini) Name() string { return "gemini" } + +type geminiRequest struct { + SystemInstruction *geminiContent `json:"systemInstruction,omitempty"` + Contents []geminiContent `json:"contents"` +} + +type geminiContent struct { + Parts []geminiPart `json:"parts"` +} + +type geminiPart struct { + Text string `json:"text"` +} + +type geminiResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"content"` + } `json:"candidates"` + Error *struct { + Message string `json:"message"` + } `json:"error"` +} + +func (g *Gemini) call(ctx context.Context, systemPrompt, userMessage string) (string, error) { + reqBody := geminiRequest{ + SystemInstruction: &geminiContent{ + Parts: []geminiPart{{Text: systemPrompt}}, + }, + Contents: []geminiContent{ + {Parts: []geminiPart{{Text: userMessage}}}, + }, + } + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshal request: %w", err) + } + + url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", g.baseURL, g.model, g.apiKey) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(body))) + if err != nil { + return "", fmt.Errorf("build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := g.client.Do(req) + if err != nil { + return "", fmt.Errorf("gemini call: %w", err) + } + defer resp.Body.Close() + + var result geminiResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + if result.Error != nil { + return "", fmt.Errorf("gemini error: %s", result.Error.Message) + } + if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 { + return "", fmt.Errorf("empty response from gemini") + } + return result.Candidates[0].Content.Parts[0].Text, nil +} + +func (g *Gemini) Summarize(ctx context.Context, input SummaryInput) (string, error) { + return g.call(ctx, SummarizeSystemPrompt(), BuildSummaryPrompt(input)) +} + +func (g *Gemini) RewriteAction(ctx context.Context, input ActionInput) (string, error) { + return g.call(ctx, RewriteActionSystemPrompt(), BuildRewriteActionPrompt(input)) +} + +func (g *Gemini) GenerateHeatPlan(ctx context.Context, input HeatPlanInput) (string, error) { + return g.call(ctx, HeatPlanSystemPrompt(), BuildHeatPlanPrompt(input)) +}