diff --git a/README.md b/README.md
index e986512..d2633c8 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 
 Vessel
 
-  A modern, feature-rich web interface for Ollama
+  A modern, feature-rich web interface for local LLMs
 
@@ -28,13 +28,14 @@
 **Vessel** is intentionally focused on:
 
-- A clean, local-first UI for **Ollama**
+- A clean, local-first UI for **local LLMs**
+- **Multiple backends**: Ollama, llama.cpp, LM Studio
 - Minimal configuration
 - Low visual and cognitive overhead
 - Doing a small set of things well
 
 If you want a **universal, highly configurable platform** → [open-webui](https://github.com/open-webui/open-webui) is a great choice.
-If you want a **small, focused UI for local Ollama usage** → Vessel is built for that.
+If you want a **small, focused UI for local LLM usage** → Vessel is built for that.
 
 ---
 
@@ -65,7 +66,13 @@ If you want a **small, focused UI for local Ollama usage** → Vessel is built f
 - Agentic tool calling with chain-of-thought reasoning
 - Test tools before saving with the built-in testing panel
 
-### Models
+### LLM Backends
+- **Ollama** — Full model management, pull/delete/create custom models
+- **llama.cpp** — High-performance inference with GGUF models
+- **LM Studio** — Desktop app integration
+- Switch backends without restart, auto-detection of available backends
+
+### Models (Ollama)
 - Browse and pull models from ollama.com
 - Create custom models with embedded system prompts
 - **Per-model parameters** — customize temperature, context size, top_k/top_p
@@ -112,7 +119,10 @@ If you want a **small, focused UI for local Ollama usage** → Vessel is built f
 ### Prerequisites
 
 - [Docker](https://docs.docker.com/get-docker/) and Docker Compose
-- [Ollama](https://ollama.com/download) running locally
+- An LLM backend (at least one):
+  - [Ollama](https://ollama.com/download) (recommended)
+  - [llama.cpp](https://github.com/ggerganov/llama.cpp)
+  - [LM Studio](https://lmstudio.ai/)
 
 ### Configure Ollama
 
@@ -160,6 +170,7 @@ Full documentation is available on the **[GitHub Wiki](https://github.com/Viking
 | Guide | Description |
 |-------|-------------|
 | [Getting Started](https://github.com/VikingOwl91/vessel/wiki/Getting-Started) | Installation and configuration |
+| [LLM Backends](https://github.com/VikingOwl91/vessel/wiki/LLM-Backends) | Configure Ollama, llama.cpp, or LM Studio |
 | [Projects](https://github.com/VikingOwl91/vessel/wiki/Projects) | Organize conversations into projects |
 | [Knowledge Base](https://github.com/VikingOwl91/vessel/wiki/Knowledge-Base) | RAG with document upload and semantic search |
 | [Search](https://github.com/VikingOwl91/vessel/wiki/Search) | Semantic and content search across chats |
@@ -178,6 +189,7 @@ Full documentation is available on the **[GitHub Wiki](https://github.com/Viking
 Vessel prioritizes **usability and simplicity** over feature breadth.
 
 **Completed:**
+- [x] Multi-backend support (Ollama, llama.cpp, LM Studio)
 - [x] Model browser with filtering and update detection
 - [x] Custom tools (JavaScript, Python, HTTP)
 - [x] System prompt library with model-specific defaults
@@ -197,7 +209,7 @@ Vessel prioritizes **usability and simplicity** over feature breadth.
 - Multi-user systems
 - Cloud sync
 - Plugin ecosystems
-- Support for every LLM runtime
+- Cloud/API-based LLM providers (OpenAI, Anthropic, etc.)
 
 > *Do one thing well. Keep the UI out of the way.*
 
@@ -223,5 +235,5 @@ Contributions are welcome!
 
 GPL-3.0 — See [LICENSE](LICENSE) for details.
 
-  Made with Ollama and Svelte
+  Made with Svelte • Supports Ollama, llama.cpp, and LM Studio
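The README hunks above describe switching backends at runtime through the unified AI API that this patch introduces. As a point of reference (not part of the patch itself), a minimal Go client sketch against the `/api/v1/ai` routes registered in `routes.go` below might look like the following; it assumes the server is running on the default port 8080, uses an illustrative model name, and keeps error handling short:

```go
// Sketch of a client for the unified AI API added in this patch.
// Assumes Vessel is running locally on its default port 8080; the model
// name is illustrative. Field names follow the request shapes exercised
// in ai_handlers_test.go.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

const base = "http://localhost:8080/api/v1/ai"

func main() {
	// List configured backends and the currently active one.
	resp, err := http.Get(base + "/backends")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("backends:", string(body))

	// Make Ollama the active backend.
	payload, _ := json.Marshal(map[string]string{"type": "ollama"})
	resp, err = http.Post(base+"/backends/active", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()

	// Send a non-streaming chat request through whichever backend is active.
	chat, _ := json.Marshal(map[string]interface{}{
		"model": "llama3.2:8b",
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
	})
	resp, err = http.Post(base+"/chat", "application/json", bytes.NewReader(chat))
	if err != nil {
		panic(err)
	}
	answer, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("chat response:", string(answer))
}
```

When `"stream": true` is set, `ChatHandler` below writes newline-delimited JSON chunks (`application/x-ndjson`), so a streaming client would read the response body line by line rather than all at once.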

diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 9c0f94a..112e8a7 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -14,6 +14,9 @@ import ( "github.com/gin-gonic/gin" "vessel-backend/internal/api" + "vessel-backend/internal/backends" + "vessel-backend/internal/backends/ollama" + "vessel-backend/internal/backends/openai" "vessel-backend/internal/database" ) @@ -29,9 +32,11 @@ func getEnvOrDefault(key, defaultValue string) string { func main() { var ( - port = flag.String("port", getEnvOrDefault("PORT", "8080"), "Server port") - dbPath = flag.String("db", getEnvOrDefault("DB_PATH", "./data/vessel.db"), "Database file path") - ollamaURL = flag.String("ollama-url", getEnvOrDefault("OLLAMA_URL", "http://localhost:11434"), "Ollama API URL") + port = flag.String("port", getEnvOrDefault("PORT", "8080"), "Server port") + dbPath = flag.String("db", getEnvOrDefault("DB_PATH", "./data/vessel.db"), "Database file path") + ollamaURL = flag.String("ollama-url", getEnvOrDefault("OLLAMA_URL", "http://localhost:11434"), "Ollama API URL") + llamacppURL = flag.String("llamacpp-url", getEnvOrDefault("LLAMACPP_URL", "http://localhost:8081"), "llama.cpp server URL") + lmstudioURL = flag.String("lmstudio-url", getEnvOrDefault("LMSTUDIO_URL", "http://localhost:1234"), "LM Studio server URL") ) flag.Parse() @@ -47,6 +52,52 @@ func main() { log.Fatalf("Failed to run migrations: %v", err) } + // Initialize backend registry + registry := backends.NewRegistry() + + // Register Ollama backend + ollamaAdapter, err := ollama.NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: *ollamaURL, + }) + if err != nil { + log.Printf("Warning: Failed to create Ollama adapter: %v", err) + } else { + if err := registry.Register(ollamaAdapter); err != nil { + log.Printf("Warning: Failed to register Ollama backend: %v", err) + } + } + + // Register llama.cpp backend (if URL is configured) + if *llamacppURL != "" { + llamacppAdapter, err := openai.NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: *llamacppURL, + }) + if err != nil { + log.Printf("Warning: Failed to create llama.cpp adapter: %v", err) + } else { + if err := registry.Register(llamacppAdapter); err != nil { + log.Printf("Warning: Failed to register llama.cpp backend: %v", err) + } + } + } + + // Register LM Studio backend (if URL is configured) + if *lmstudioURL != "" { + lmstudioAdapter, err := openai.NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLMStudio, + BaseURL: *lmstudioURL, + }) + if err != nil { + log.Printf("Warning: Failed to create LM Studio adapter: %v", err) + } else { + if err := registry.Register(lmstudioAdapter); err != nil { + log.Printf("Warning: Failed to register LM Studio backend: %v", err) + } + } + } + // Setup Gin router gin.SetMode(gin.ReleaseMode) r := gin.New() @@ -64,7 +115,7 @@ func main() { })) // Register routes - api.SetupRoutes(r, db, *ollamaURL, Version) + api.SetupRoutes(r, db, *ollamaURL, Version, registry) // Create server srv := &http.Server{ @@ -79,8 +130,12 @@ func main() { // Graceful shutdown handling go func() { log.Printf("Server starting on port %s", *port) - log.Printf("Ollama URL: %s (using official Go client)", *ollamaURL) log.Printf("Database: %s", *dbPath) + log.Printf("Backends configured:") + log.Printf(" - Ollama: %s", *ollamaURL) + log.Printf(" - llama.cpp: %s", *llamacppURL) + log.Printf(" - LM Studio: %s", *lmstudioURL) + log.Printf("Active backend: %s", registry.ActiveType().String()) 
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Failed to start server: %v", err) } diff --git a/backend/internal/api/ai_handlers.go b/backend/internal/api/ai_handlers.go new file mode 100644 index 0000000..2fa4248 --- /dev/null +++ b/backend/internal/api/ai_handlers.go @@ -0,0 +1,275 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/gin-gonic/gin" + + "vessel-backend/internal/backends" +) + +// AIHandlers provides HTTP handlers for the unified AI API +type AIHandlers struct { + registry *backends.Registry +} + +// NewAIHandlers creates a new AIHandlers instance +func NewAIHandlers(registry *backends.Registry) *AIHandlers { + return &AIHandlers{ + registry: registry, + } +} + +// ListBackendsHandler returns information about all configured backends +func (h *AIHandlers) ListBackendsHandler() gin.HandlerFunc { + return func(c *gin.Context) { + infos := h.registry.AllInfo(c.Request.Context()) + + c.JSON(http.StatusOK, gin.H{ + "backends": infos, + "active": h.registry.ActiveType().String(), + }) + } +} + +// DiscoverBackendsHandler probes for available backends +func (h *AIHandlers) DiscoverBackendsHandler() gin.HandlerFunc { + return func(c *gin.Context) { + var req struct { + Endpoints []backends.DiscoveryEndpoint `json:"endpoints"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + // Use default endpoints if none provided + req.Endpoints = backends.DefaultDiscoveryEndpoints() + } + + if len(req.Endpoints) == 0 { + req.Endpoints = backends.DefaultDiscoveryEndpoints() + } + + results := h.registry.Discover(c.Request.Context(), req.Endpoints) + + c.JSON(http.StatusOK, gin.H{ + "results": results, + }) + } +} + +// SetActiveHandler sets the active backend +func (h *AIHandlers) SetActiveHandler() gin.HandlerFunc { + return func(c *gin.Context) { + var req struct { + Type string `json:"type" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "type is required"}) + return + } + + backendType, err := backends.ParseBackendType(req.Type) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.registry.SetActive(backendType); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "active": backendType.String(), + }) + } +} + +// HealthCheckHandler checks the health of a specific backend +func (h *AIHandlers) HealthCheckHandler() gin.HandlerFunc { + return func(c *gin.Context) { + typeParam := c.Param("type") + + backendType, err := backends.ParseBackendType(typeParam) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + backend, ok := h.registry.Get(backendType) + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "backend not registered"}) + return + } + + if err := backend.HealthCheck(c.Request.Context()); err != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "status": "unhealthy", + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + }) + } +} + +// ListModelsHandler returns models from the active backend +func (h *AIHandlers) ListModelsHandler() gin.HandlerFunc { + return func(c *gin.Context) { + active := h.registry.Active() + if active == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "no active backend"}) + return + } + + models, err := active.ListModels(c.Request.Context()) + if err != nil { + c.JSON(http.StatusBadGateway, 
gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "models": models, + "backend": active.Type().String(), + }) + } +} + +// ChatHandler handles chat requests through the active backend +func (h *AIHandlers) ChatHandler() gin.HandlerFunc { + return func(c *gin.Context) { + active := h.registry.Active() + if active == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "no active backend"}) + return + } + + var req backends.ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()}) + return + } + + if err := req.Validate(); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Check if streaming is requested + streaming := req.Stream != nil && *req.Stream + + if streaming { + h.handleStreamingChat(c, active, &req) + } else { + h.handleNonStreamingChat(c, active, &req) + } + } +} + +// handleNonStreamingChat handles non-streaming chat requests +func (h *AIHandlers) handleNonStreamingChat(c *gin.Context, backend backends.LLMBackend, req *backends.ChatRequest) { + resp, err := backend.Chat(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, resp) +} + +// handleStreamingChat handles streaming chat requests +func (h *AIHandlers) handleStreamingChat(c *gin.Context, backend backends.LLMBackend, req *backends.ChatRequest) { + // Set headers for NDJSON streaming + c.Header("Content-Type", "application/x-ndjson") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + ctx := c.Request.Context() + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"}) + return + } + + chunkCh, err := backend.StreamChat(ctx, req) + if err != nil { + errResp := gin.H{"error": err.Error()} + data, _ := json.Marshal(errResp) + c.Writer.Write(append(data, '\n')) + flusher.Flush() + return + } + + for chunk := range chunkCh { + select { + case <-ctx.Done(): + return + default: + } + + data, err := json.Marshal(chunk) + if err != nil { + continue + } + + _, err = c.Writer.Write(append(data, '\n')) + if err != nil { + return + } + flusher.Flush() + } +} + +// RegisterBackendHandler registers a new backend +func (h *AIHandlers) RegisterBackendHandler() gin.HandlerFunc { + return func(c *gin.Context) { + var req backends.BackendConfig + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()}) + return + } + + if err := req.Validate(); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Create adapter based on type + var backend backends.LLMBackend + var err error + + switch req.Type { + case backends.BackendTypeOllama: + // Would import ollama adapter + c.JSON(http.StatusNotImplemented, gin.H{"error": "use /api/v1/ai/backends/discover to register backends"}) + return + case backends.BackendTypeLlamaCpp, backends.BackendTypeLMStudio: + // Would import openai adapter + c.JSON(http.StatusNotImplemented, gin.H{"error": "use /api/v1/ai/backends/discover to register backends"}) + return + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "unknown backend type"}) + return + } + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.registry.Register(backend); err != nil 
{ + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{ + "type": req.Type.String(), + "baseUrl": req.BaseURL, + }) + } +} diff --git a/backend/internal/api/ai_handlers_test.go b/backend/internal/api/ai_handlers_test.go new file mode 100644 index 0000000..a70f14b --- /dev/null +++ b/backend/internal/api/ai_handlers_test.go @@ -0,0 +1,354 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "vessel-backend/internal/backends" +) + +func setupAITestRouter(registry *backends.Registry) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + + handlers := NewAIHandlers(registry) + + ai := r.Group("/api/v1/ai") + { + ai.GET("/backends", handlers.ListBackendsHandler()) + ai.POST("/backends/discover", handlers.DiscoverBackendsHandler()) + ai.POST("/backends/active", handlers.SetActiveHandler()) + ai.GET("/backends/:type/health", handlers.HealthCheckHandler()) + ai.POST("/chat", handlers.ChatHandler()) + ai.GET("/models", handlers.ListModelsHandler()) + } + + return r +} + +func TestAIHandlers_ListBackends(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + config: backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + info: backends.BackendInfo{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Status: backends.BackendStatusConnected, + Capabilities: backends.OllamaCapabilities(), + Version: "0.3.0", + }, + } + registry.Register(mock) + registry.SetActive(backends.BackendTypeOllama) + + router := setupAITestRouter(registry) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/ai/backends", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("ListBackends() status = %d, want %d", w.Code, http.StatusOK) + } + + var resp struct { + Backends []backends.BackendInfo `json:"backends"` + Active string `json:"active"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if len(resp.Backends) != 1 { + t.Errorf("ListBackends() returned %d backends, want 1", len(resp.Backends)) + } + + if resp.Active != "ollama" { + t.Errorf("ListBackends() active = %q, want %q", resp.Active, "ollama") + } +} + +func TestAIHandlers_SetActive(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + config: backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + } + registry.Register(mock) + + router := setupAITestRouter(registry) + + t.Run("set valid backend active", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"type": "ollama"}) + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/ai/backends/active", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("SetActive() status = %d, want %d", w.Code, http.StatusOK) + } + + if registry.ActiveType() != backends.BackendTypeOllama { + t.Errorf("Active backend = %v, want %v", registry.ActiveType(), backends.BackendTypeOllama) + } + }) + + t.Run("set invalid backend active", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"type": "llamacpp"}) + w := httptest.NewRecorder() + req, _ := 
http.NewRequest("POST", "/api/v1/ai/backends/active", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("SetActive() status = %d, want %d", w.Code, http.StatusBadRequest) + } + }) +} + +func TestAIHandlers_HealthCheck(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + config: backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + healthErr: nil, + } + registry.Register(mock) + + router := setupAITestRouter(registry) + + t.Run("healthy backend", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/ai/backends/ollama/health", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("HealthCheck() status = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("non-existent backend", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/ai/backends/llamacpp/health", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("HealthCheck() status = %d, want %d", w.Code, http.StatusNotFound) + } + }) +} + +func TestAIHandlers_ListModels(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + config: backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + models: []backends.Model{ + {ID: "llama3.2:8b", Name: "llama3.2:8b", Family: "llama"}, + {ID: "mistral:7b", Name: "mistral:7b", Family: "mistral"}, + }, + } + registry.Register(mock) + registry.SetActive(backends.BackendTypeOllama) + + router := setupAITestRouter(registry) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/ai/models", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("ListModels() status = %d, want %d", w.Code, http.StatusOK) + } + + var resp struct { + Models []backends.Model `json:"models"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if len(resp.Models) != 2 { + t.Errorf("ListModels() returned %d models, want 2", len(resp.Models)) + } +} + +func TestAIHandlers_ListModels_NoActiveBackend(t *testing.T) { + registry := backends.NewRegistry() + router := setupAITestRouter(registry) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/ai/models", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("ListModels() status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestAIHandlers_Chat(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + config: backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + chatResponse: &backends.ChatChunk{ + Model: "llama3.2:8b", + Message: &backends.ChatMessage{ + Role: "assistant", + Content: "Hello! 
How can I help?", + }, + Done: true, + }, + } + registry.Register(mock) + registry.SetActive(backends.BackendTypeOllama) + + router := setupAITestRouter(registry) + + t.Run("non-streaming chat", func(t *testing.T) { + chatReq := backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + } + body, _ := json.Marshal(chatReq) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/ai/chat", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Chat() status = %d, want %d, body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp backends.ChatChunk + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if !resp.Done { + t.Error("Chat() response.Done = false, want true") + } + + if resp.Message == nil || resp.Message.Content != "Hello! How can I help?" { + t.Errorf("Chat() unexpected response: %+v", resp) + } + }) +} + +func TestAIHandlers_Chat_InvalidRequest(t *testing.T) { + registry := backends.NewRegistry() + + mock := &mockAIBackend{ + backendType: backends.BackendTypeOllama, + } + registry.Register(mock) + registry.SetActive(backends.BackendTypeOllama) + + router := setupAITestRouter(registry) + + // Missing model + chatReq := map[string]interface{}{ + "messages": []map[string]string{ + {"role": "user", "content": "Hello"}, + }, + } + body, _ := json.Marshal(chatReq) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/ai/chat", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Chat() status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +// mockAIBackend implements backends.LLMBackend for testing +type mockAIBackend struct { + backendType backends.BackendType + config backends.BackendConfig + info backends.BackendInfo + healthErr error + models []backends.Model + chatResponse *backends.ChatChunk +} + +func (m *mockAIBackend) Type() backends.BackendType { + return m.backendType +} + +func (m *mockAIBackend) Config() backends.BackendConfig { + return m.config +} + +func (m *mockAIBackend) HealthCheck(ctx context.Context) error { + return m.healthErr +} + +func (m *mockAIBackend) ListModels(ctx context.Context) ([]backends.Model, error) { + return m.models, nil +} + +func (m *mockAIBackend) StreamChat(ctx context.Context, req *backends.ChatRequest) (<-chan backends.ChatChunk, error) { + ch := make(chan backends.ChatChunk, 1) + if m.chatResponse != nil { + ch <- *m.chatResponse + } + close(ch) + return ch, nil +} + +func (m *mockAIBackend) Chat(ctx context.Context, req *backends.ChatRequest) (*backends.ChatChunk, error) { + if m.chatResponse != nil { + return m.chatResponse, nil + } + return &backends.ChatChunk{Done: true}, nil +} + +func (m *mockAIBackend) Capabilities() backends.BackendCapabilities { + return backends.OllamaCapabilities() +} + +func (m *mockAIBackend) Info(ctx context.Context) backends.BackendInfo { + if m.info.Type != "" { + return m.info + } + return backends.BackendInfo{ + Type: m.backendType, + BaseURL: m.config.BaseURL, + Status: backends.BackendStatusConnected, + Capabilities: m.Capabilities(), + } +} diff --git a/backend/internal/api/routes.go b/backend/internal/api/routes.go index 8e76f22..fb192b0 100644 --- a/backend/internal/api/routes.go +++ 
b/backend/internal/api/routes.go @@ -5,10 +5,12 @@ import ( "log" "github.com/gin-gonic/gin" + + "vessel-backend/internal/backends" ) // SetupRoutes configures all API routes -func SetupRoutes(r *gin.Engine, db *sql.DB, ollamaURL string, appVersion string) { +func SetupRoutes(r *gin.Engine, db *sql.DB, ollamaURL string, appVersion string, registry *backends.Registry) { // Initialize Ollama service with official client ollamaService, err := NewOllamaService(ollamaURL) if err != nil { @@ -97,6 +99,24 @@ func SetupRoutes(r *gin.Engine, db *sql.DB, ollamaURL string, appVersion string) models.GET("/remote/status", modelRegistry.SyncStatusHandler()) } + // Unified AI routes (multi-backend support) + if registry != nil { + aiHandlers := NewAIHandlers(registry) + ai := v1.Group("/ai") + { + // Backend management + ai.GET("/backends", aiHandlers.ListBackendsHandler()) + ai.POST("/backends/discover", aiHandlers.DiscoverBackendsHandler()) + ai.POST("/backends/active", aiHandlers.SetActiveHandler()) + ai.GET("/backends/:type/health", aiHandlers.HealthCheckHandler()) + ai.POST("/backends/register", aiHandlers.RegisterBackendHandler()) + + // Unified model and chat endpoints (route to active backend) + ai.GET("/models", aiHandlers.ListModelsHandler()) + ai.POST("/chat", aiHandlers.ChatHandler()) + } + } + // Ollama API routes (using official client) if ollamaService != nil { ollama := v1.Group("/ollama") diff --git a/backend/internal/backends/interface.go b/backend/internal/backends/interface.go new file mode 100644 index 0000000..70da0f0 --- /dev/null +++ b/backend/internal/backends/interface.go @@ -0,0 +1,98 @@ +package backends + +import ( + "context" +) + +// LLMBackend defines the interface for LLM backend implementations. +// All backends (Ollama, llama.cpp, LM Studio) must implement this interface. +type LLMBackend interface { + // Type returns the backend type identifier + Type() BackendType + + // Config returns the backend configuration + Config() BackendConfig + + // HealthCheck verifies the backend is reachable and operational + HealthCheck(ctx context.Context) error + + // ListModels returns all models available from this backend + ListModels(ctx context.Context) ([]Model, error) + + // StreamChat sends a chat request and returns a channel for streaming responses. + // The channel is closed when the stream completes or an error occurs. + // Callers should check ChatChunk.Error for stream errors. + StreamChat(ctx context.Context, req *ChatRequest) (<-chan ChatChunk, error) + + // Chat sends a non-streaming chat request and returns the final response + Chat(ctx context.Context, req *ChatRequest) (*ChatChunk, error) + + // Capabilities returns what features this backend supports + Capabilities() BackendCapabilities + + // Info returns detailed information about the backend including status + Info(ctx context.Context) BackendInfo +} + +// ModelManager extends LLMBackend with model management capabilities. +// Only Ollama implements this interface. +type ModelManager interface { + LLMBackend + + // PullModel downloads a model from the registry. + // Returns a channel for progress updates. 
+ PullModel(ctx context.Context, name string) (<-chan PullProgress, error) + + // DeleteModel removes a model from local storage + DeleteModel(ctx context.Context, name string) error + + // CreateModel creates a custom model with the given Modelfile content + CreateModel(ctx context.Context, name string, modelfile string) (<-chan CreateProgress, error) + + // CopyModel creates a copy of an existing model + CopyModel(ctx context.Context, source, destination string) error + + // ShowModel returns detailed information about a specific model + ShowModel(ctx context.Context, name string) (*ModelDetails, error) +} + +// EmbeddingProvider extends LLMBackend with embedding capabilities. +type EmbeddingProvider interface { + LLMBackend + + // Embed generates embeddings for the given input + Embed(ctx context.Context, model string, input []string) ([][]float64, error) +} + +// PullProgress represents progress during model download +type PullProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int64 `json:"total,omitempty"` + Completed int64 `json:"completed,omitempty"` + Error string `json:"error,omitempty"` +} + +// CreateProgress represents progress during model creation +type CreateProgress struct { + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// ModelDetails contains detailed information about a model +type ModelDetails struct { + Name string `json:"name"` + ModifiedAt string `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Format string `json:"format"` + Family string `json:"family"` + Families []string `json:"families"` + ParamSize string `json:"parameter_size"` + QuantLevel string `json:"quantization_level"` + Template string `json:"template"` + System string `json:"system"` + License string `json:"license"` + Modelfile string `json:"modelfile"` + Parameters map[string]string `json:"parameters"` +} diff --git a/backend/internal/backends/ollama/adapter.go b/backend/internal/backends/ollama/adapter.go new file mode 100644 index 0000000..b46eb00 --- /dev/null +++ b/backend/internal/backends/ollama/adapter.go @@ -0,0 +1,624 @@ +package ollama + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "vessel-backend/internal/backends" +) + +// Adapter implements the LLMBackend interface for Ollama. +// It also implements ModelManager and EmbeddingProvider. 
+type Adapter struct { + config backends.BackendConfig + httpClient *http.Client + baseURL *url.URL +} + +// Ensure Adapter implements all required interfaces +var ( + _ backends.LLMBackend = (*Adapter)(nil) + _ backends.ModelManager = (*Adapter)(nil) + _ backends.EmbeddingProvider = (*Adapter)(nil) +) + +// NewAdapter creates a new Ollama backend adapter +func NewAdapter(config backends.BackendConfig) (*Adapter, error) { + if config.Type != backends.BackendTypeOllama { + return nil, fmt.Errorf("invalid backend type: expected %s, got %s", backends.BackendTypeOllama, config.Type) + } + + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + baseURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + return &Adapter{ + config: config, + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + }, nil +} + +// Type returns the backend type +func (a *Adapter) Type() backends.BackendType { + return backends.BackendTypeOllama +} + +// Config returns the backend configuration +func (a *Adapter) Config() backends.BackendConfig { + return a.config +} + +// Capabilities returns what features this backend supports +func (a *Adapter) Capabilities() backends.BackendCapabilities { + return backends.OllamaCapabilities() +} + +// HealthCheck verifies the backend is reachable +func (a *Adapter) HealthCheck(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, "GET", a.baseURL.String()+"/api/version", nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to reach Ollama: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Ollama returned status %d", resp.StatusCode) + } + + return nil +} + +// ollamaListResponse represents the response from /api/tags +type ollamaListResponse struct { + Models []ollamaModel `json:"models"` +} + +type ollamaModel struct { + Name string `json:"name"` + Size int64 `json:"size"` + ModifiedAt string `json:"modified_at"` + Details ollamaModelDetails `json:"details"` +} + +type ollamaModelDetails struct { + Family string `json:"family"` + QuantLevel string `json:"quantization_level"` + ParamSize string `json:"parameter_size"` +} + +// ListModels returns all models available from Ollama +func (a *Adapter) ListModels(ctx context.Context) ([]backends.Model, error) { + req, err := http.NewRequestWithContext(ctx, "GET", a.baseURL.String()+"/api/tags", nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + defer resp.Body.Close() + + var listResp ollamaListResponse + if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + models := make([]backends.Model, len(listResp.Models)) + for i, m := range listResp.Models { + models[i] = backends.Model{ + ID: m.Name, + Name: m.Name, + Size: m.Size, + ModifiedAt: m.ModifiedAt, + Family: m.Details.Family, + QuantLevel: m.Details.QuantLevel, + } + } + + return models, nil +} + +// Chat sends a non-streaming chat request +func (a *Adapter) Chat(ctx context.Context, req *backends.ChatRequest) (*backends.ChatChunk, error) { + if err := req.Validate(); err != nil { + return nil, fmt.Errorf("invalid request: 
%w", err) + } + + // Convert to Ollama format + ollamaReq := a.convertChatRequest(req) + ollamaReq["stream"] = false + + body, err := json.Marshal(ollamaReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("chat request failed: %w", err) + } + defer resp.Body.Close() + + var ollamaResp ollamaChatResponse + if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return a.convertChatResponse(&ollamaResp), nil +} + +// StreamChat sends a streaming chat request +func (a *Adapter) StreamChat(ctx context.Context, req *backends.ChatRequest) (<-chan backends.ChatChunk, error) { + if err := req.Validate(); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Convert to Ollama format + ollamaReq := a.convertChatRequest(req) + ollamaReq["stream"] = true + + body, err := json.Marshal(ollamaReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request without timeout for streaming + httpReq, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + // Use a client without timeout for streaming + client := &http.Client{} + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("chat request failed: %w", err) + } + + chunkCh := make(chan backends.ChatChunk) + + go func() { + defer close(chunkCh) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var ollamaResp ollamaChatResponse + if err := json.Unmarshal(line, &ollamaResp); err != nil { + chunkCh <- backends.ChatChunk{Error: fmt.Sprintf("failed to parse response: %v", err)} + return + } + + chunkCh <- *a.convertChatResponse(&ollamaResp) + + if ollamaResp.Done { + return + } + } + + if err := scanner.Err(); err != nil && ctx.Err() == nil { + chunkCh <- backends.ChatChunk{Error: fmt.Sprintf("stream error: %v", err)} + } + }() + + return chunkCh, nil +} + +// Info returns detailed information about the backend +func (a *Adapter) Info(ctx context.Context) backends.BackendInfo { + info := backends.BackendInfo{ + Type: backends.BackendTypeOllama, + BaseURL: a.config.BaseURL, + Capabilities: a.Capabilities(), + } + + // Try to get version + req, err := http.NewRequestWithContext(ctx, "GET", a.baseURL.String()+"/api/version", nil) + if err != nil { + info.Status = backends.BackendStatusDisconnected + info.Error = err.Error() + return info + } + + resp, err := a.httpClient.Do(req) + if err != nil { + info.Status = backends.BackendStatusDisconnected + info.Error = err.Error() + return info + } + defer resp.Body.Close() + + var versionResp struct { + Version string `json:"version"` + } + if err := json.NewDecoder(resp.Body).Decode(&versionResp); err != nil { + info.Status = backends.BackendStatusDisconnected + info.Error = err.Error() + return info + } + + 
info.Status = backends.BackendStatusConnected + info.Version = versionResp.Version + return info +} + +// ShowModel returns detailed information about a specific model +func (a *Adapter) ShowModel(ctx context.Context, name string) (*backends.ModelDetails, error) { + body, err := json.Marshal(map[string]string{"name": name}) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/show", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to show model: %w", err) + } + defer resp.Body.Close() + + var showResp struct { + Modelfile string `json:"modelfile"` + Template string `json:"template"` + System string `json:"system"` + Details struct { + Family string `json:"family"` + ParamSize string `json:"parameter_size"` + QuantLevel string `json:"quantization_level"` + } `json:"details"` + } + if err := json.NewDecoder(resp.Body).Decode(&showResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &backends.ModelDetails{ + Name: name, + Family: showResp.Details.Family, + ParamSize: showResp.Details.ParamSize, + QuantLevel: showResp.Details.QuantLevel, + Template: showResp.Template, + System: showResp.System, + Modelfile: showResp.Modelfile, + }, nil +} + +// PullModel downloads a model from the registry +func (a *Adapter) PullModel(ctx context.Context, name string) (<-chan backends.PullProgress, error) { + body, err := json.Marshal(map[string]interface{}{"name": name, "stream": true}) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/pull", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to pull model: %w", err) + } + + progressCh := make(chan backends.PullProgress) + + go func() { + defer close(progressCh) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + + var progress struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int64 `json:"total"` + Completed int64 `json:"completed"` + } + if err := json.Unmarshal(scanner.Bytes(), &progress); err != nil { + progressCh <- backends.PullProgress{Error: err.Error()} + return + } + + progressCh <- backends.PullProgress{ + Status: progress.Status, + Digest: progress.Digest, + Total: progress.Total, + Completed: progress.Completed, + } + } + + if err := scanner.Err(); err != nil && ctx.Err() == nil { + progressCh <- backends.PullProgress{Error: err.Error()} + } + }() + + return progressCh, nil +} + +// DeleteModel removes a model from local storage +func (a *Adapter) DeleteModel(ctx context.Context, name string) error { + body, err := json.Marshal(map[string]string{"name": name}) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "DELETE", a.baseURL.String()+"/api/delete", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + 
req.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to delete model: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("delete failed: %s", string(bodyBytes)) + } + + return nil +} + +// CreateModel creates a custom model with the given Modelfile content +func (a *Adapter) CreateModel(ctx context.Context, name string, modelfile string) (<-chan backends.CreateProgress, error) { + body, err := json.Marshal(map[string]interface{}{ + "name": name, + "modelfile": modelfile, + "stream": true, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/create", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to create model: %w", err) + } + + progressCh := make(chan backends.CreateProgress) + + go func() { + defer close(progressCh) + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + + var progress struct { + Status string `json:"status"` + } + if err := json.Unmarshal(scanner.Bytes(), &progress); err != nil { + progressCh <- backends.CreateProgress{Error: err.Error()} + return + } + + progressCh <- backends.CreateProgress{Status: progress.Status} + } + + if err := scanner.Err(); err != nil && ctx.Err() == nil { + progressCh <- backends.CreateProgress{Error: err.Error()} + } + }() + + return progressCh, nil +} + +// CopyModel creates a copy of an existing model +func (a *Adapter) CopyModel(ctx context.Context, source, destination string) error { + body, err := json.Marshal(map[string]string{ + "source": source, + "destination": destination, + }) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/copy", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to copy model: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("copy failed: %s", string(bodyBytes)) + } + + return nil +} + +// Embed generates embeddings for the given input +func (a *Adapter) Embed(ctx context.Context, model string, input []string) ([][]float64, error) { + body, err := json.Marshal(map[string]interface{}{ + "model": model, + "input": input, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("embed request failed: %w", err) + } + defer resp.Body.Close() + + var embedResp struct { + Embeddings [][]float64 `json:"embeddings"` + } + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err 
!= nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return embedResp.Embeddings, nil +} + +// ollamaChatResponse represents the response from /api/chat +type ollamaChatResponse struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Message ollamaChatMessage `json:"message"` + Done bool `json:"done"` + DoneReason string `json:"done_reason,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` +} + +type ollamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images,omitempty"` + ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"` +} + +type ollamaToolCall struct { + Function struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` +} + +// convertChatRequest converts a backends.ChatRequest to Ollama format +func (a *Adapter) convertChatRequest(req *backends.ChatRequest) map[string]interface{} { + messages := make([]map[string]interface{}, len(req.Messages)) + for i, msg := range req.Messages { + m := map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + if len(msg.Images) > 0 { + m["images"] = msg.Images + } + messages[i] = m + } + + ollamaReq := map[string]interface{}{ + "model": req.Model, + "messages": messages, + } + + // Add optional parameters + if req.Options != nil { + ollamaReq["options"] = req.Options + } + if len(req.Tools) > 0 { + ollamaReq["tools"] = req.Tools + } + + return ollamaReq +} + +// convertChatResponse converts an Ollama response to backends.ChatChunk +func (a *Adapter) convertChatResponse(resp *ollamaChatResponse) *backends.ChatChunk { + chunk := &backends.ChatChunk{ + Model: resp.Model, + CreatedAt: resp.CreatedAt, + Done: resp.Done, + DoneReason: resp.DoneReason, + PromptEvalCount: resp.PromptEvalCount, + EvalCount: resp.EvalCount, + } + + if resp.Message.Role != "" || resp.Message.Content != "" { + msg := &backends.ChatMessage{ + Role: resp.Message.Role, + Content: resp.Message.Content, + Images: resp.Message.Images, + } + + // Convert tool calls + for _, tc := range resp.Message.ToolCalls { + msg.ToolCalls = append(msg.ToolCalls, backends.ToolCall{ + Type: "function", + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: tc.Function.Name, + Arguments: string(tc.Function.Arguments), + }, + }) + } + + chunk.Message = msg + } + + return chunk +} diff --git a/backend/internal/backends/ollama/adapter_test.go b/backend/internal/backends/ollama/adapter_test.go new file mode 100644 index 0000000..3b6041f --- /dev/null +++ b/backend/internal/backends/ollama/adapter_test.go @@ -0,0 +1,574 @@ +package ollama + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "vessel-backend/internal/backends" +) + +func TestAdapter_Type(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }) + + if adapter.Type() != backends.BackendTypeOllama { + t.Errorf("Type() = %v, want %v", adapter.Type(), backends.BackendTypeOllama) + } +} + +func TestAdapter_Config(t *testing.T) { + cfg := backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Enabled: true, + } + + adapter, _ := NewAdapter(cfg) + got := adapter.Config() + + if got.Type != cfg.Type { + t.Errorf("Config().Type = %v, want %v", got.Type, cfg.Type) + } + if 
got.BaseURL != cfg.BaseURL { + t.Errorf("Config().BaseURL = %v, want %v", got.BaseURL, cfg.BaseURL) + } +} + +func TestAdapter_Capabilities(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }) + + caps := adapter.Capabilities() + + if !caps.CanListModels { + t.Error("Ollama adapter should support listing models") + } + if !caps.CanPullModels { + t.Error("Ollama adapter should support pulling models") + } + if !caps.CanDeleteModels { + t.Error("Ollama adapter should support deleting models") + } + if !caps.CanCreateModels { + t.Error("Ollama adapter should support creating models") + } + if !caps.CanStreamChat { + t.Error("Ollama adapter should support streaming chat") + } + if !caps.CanEmbed { + t.Error("Ollama adapter should support embeddings") + } +} + +func TestAdapter_HealthCheck(t *testing.T) { + t.Run("healthy server", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" || r.URL.Path == "/api/version" { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"version": "0.1.0"}) + } + })) + defer server.Close() + + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("Failed to create adapter: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := adapter.HealthCheck(ctx); err != nil { + t.Errorf("HealthCheck() error = %v, want nil", err) + } + }) + + t.Run("unreachable server", func(t *testing.T) { + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:19999", // unlikely to be running + }) + if err != nil { + t.Fatalf("Failed to create adapter: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if err := adapter.HealthCheck(ctx); err == nil { + t.Error("HealthCheck() expected error for unreachable server") + } + }) +} + +func TestAdapter_ListModels(t *testing.T) { + t.Run("returns model list", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/tags" { + resp := map[string]interface{}{ + "models": []map[string]interface{}{ + { + "name": "llama3.2:8b", + "size": int64(4700000000), + "modified_at": "2024-01-15T10:30:00Z", + "details": map[string]interface{}{ + "family": "llama", + "quantization_level": "Q4_K_M", + }, + }, + { + "name": "mistral:7b", + "size": int64(4100000000), + "modified_at": "2024-01-14T08:00:00Z", + "details": map[string]interface{}{ + "family": "mistral", + "quantization_level": "Q4_0", + }, + }, + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + ctx := context.Background() + models, err := adapter.ListModels(ctx) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + + if len(models) != 2 { + t.Errorf("ListModels() returned %d models, want 2", len(models)) + } + + if models[0].Name != "llama3.2:8b" { + t.Errorf("First model name = %q, want %q", models[0].Name, "llama3.2:8b") + } + + if models[0].Family != "llama" { + t.Errorf("First model family = %q, want %q", models[0].Family, "llama") + } + }) + + t.Run("handles empty model list", func(t 
*testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/tags" { + resp := map[string]interface{}{ + "models": []map[string]interface{}{}, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + models, err := adapter.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + + if len(models) != 0 { + t.Errorf("ListModels() returned %d models, want 0", len(models)) + } + }) +} + +func TestAdapter_Chat(t *testing.T) { + t.Run("non-streaming chat", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/chat" && r.Method == "POST" { + var req map[string]interface{} + json.NewDecoder(r.Body).Decode(&req) + + // Check stream is false + if stream, ok := req["stream"].(bool); !ok || stream { + t.Error("Expected stream=false for non-streaming chat") + } + + resp := map[string]interface{}{ + "model": "llama3.2:8b", + "message": map[string]interface{}{"role": "assistant", "content": "Hello! How can I help you?"}, + "done": true, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + } + + resp, err := adapter.Chat(context.Background(), req) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if !resp.Done { + t.Error("Chat() response.Done = false, want true") + } + + if resp.Message == nil || resp.Message.Content != "Hello! How can I help you?" 
{ + t.Errorf("Chat() response content unexpected: %+v", resp.Message) + } + }) +} + +func TestAdapter_StreamChat(t *testing.T) { + t.Run("streaming chat", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/chat" && r.Method == "POST" { + var req map[string]interface{} + json.NewDecoder(r.Body).Decode(&req) + + // Check stream is true + if stream, ok := req["stream"].(bool); ok && !stream { + t.Error("Expected stream=true for streaming chat") + } + + w.Header().Set("Content-Type", "application/x-ndjson") + flusher := w.(http.Flusher) + + // Send streaming chunks + chunks := []map[string]interface{}{ + {"model": "llama3.2:8b", "message": map[string]interface{}{"role": "assistant", "content": "Hello"}, "done": false}, + {"model": "llama3.2:8b", "message": map[string]interface{}{"role": "assistant", "content": "!"}, "done": false}, + {"model": "llama3.2:8b", "message": map[string]interface{}{"role": "assistant", "content": ""}, "done": true}, + } + + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + w.Write(append(data, '\n')) + flusher.Flush() + } + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + streaming := true + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + Stream: &streaming, + } + + chunkCh, err := adapter.StreamChat(context.Background(), req) + if err != nil { + t.Fatalf("StreamChat() error = %v", err) + } + + var chunks []backends.ChatChunk + for chunk := range chunkCh { + chunks = append(chunks, chunk) + } + + if len(chunks) != 3 { + t.Errorf("StreamChat() received %d chunks, want 3", len(chunks)) + } + + // Last chunk should be done + if !chunks[len(chunks)-1].Done { + t.Error("Last chunk should have Done=true") + } + }) + + t.Run("handles context cancellation", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/chat" { + w.Header().Set("Content-Type", "application/x-ndjson") + flusher := w.(http.Flusher) + + // Send first chunk then wait + chunk := map[string]interface{}{"model": "llama3.2:8b", "message": map[string]interface{}{"role": "assistant", "content": "Starting..."}, "done": false} + data, _ := json.Marshal(chunk) + w.Write(append(data, '\n')) + flusher.Flush() + + // Wait long enough for context to be cancelled + time.Sleep(2 * time.Second) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + streaming := true + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + Stream: &streaming, + } + + chunkCh, err := adapter.StreamChat(ctx, req) + if err != nil { + t.Fatalf("StreamChat() error = %v", err) + } + + // Should receive at least one chunk before timeout + receivedChunks := 0 + for range chunkCh { + receivedChunks++ + } + + if receivedChunks == 0 { + t.Error("Expected to receive at least one chunk before cancellation") + } + }) +} + +func TestAdapter_Info(t *testing.T) { + t.Run("connected server", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if 
r.URL.Path == "/" || r.URL.Path == "/api/version" { + json.NewEncoder(w).Encode(map[string]string{"version": "0.3.0"}) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + info := adapter.Info(context.Background()) + + if info.Type != backends.BackendTypeOllama { + t.Errorf("Info().Type = %v, want %v", info.Type, backends.BackendTypeOllama) + } + + if info.Status != backends.BackendStatusConnected { + t.Errorf("Info().Status = %v, want %v", info.Status, backends.BackendStatusConnected) + } + + if info.Version != "0.3.0" { + t.Errorf("Info().Version = %v, want %v", info.Version, "0.3.0") + } + }) + + t.Run("disconnected server", func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:19999", + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + info := adapter.Info(ctx) + + if info.Status != backends.BackendStatusDisconnected { + t.Errorf("Info().Status = %v, want %v", info.Status, backends.BackendStatusDisconnected) + } + + if info.Error == "" { + t.Error("Info().Error should be set for disconnected server") + } + }) +} + +func TestAdapter_ShowModel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" && r.Method == "POST" { + var req map[string]string + json.NewDecoder(r.Body).Decode(&req) + + resp := map[string]interface{}{ + "modelfile": "FROM llama3.2:8b\nSYSTEM You are helpful.", + "template": "{{ .Prompt }}", + "system": "You are helpful.", + "details": map[string]interface{}{ + "family": "llama", + "parameter_size": "8B", + "quantization_level": "Q4_K_M", + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + details, err := adapter.ShowModel(context.Background(), "llama3.2:8b") + if err != nil { + t.Fatalf("ShowModel() error = %v", err) + } + + if details.Family != "llama" { + t.Errorf("ShowModel().Family = %q, want %q", details.Family, "llama") + } + + if details.System != "You are helpful." 
{ + t.Errorf("ShowModel().System = %q, want %q", details.System, "You are helpful.") + } +} + +func TestAdapter_DeleteModel(t *testing.T) { + deleted := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/delete" && r.Method == "DELETE" { + deleted = true + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + err := adapter.DeleteModel(context.Background(), "test-model") + if err != nil { + t.Fatalf("DeleteModel() error = %v", err) + } + + if !deleted { + t.Error("DeleteModel() did not call the delete endpoint") + } +} + +func TestAdapter_CopyModel(t *testing.T) { + copied := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/copy" && r.Method == "POST" { + var req map[string]string + json.NewDecoder(r.Body).Decode(&req) + + if req["source"] == "source-model" && req["destination"] == "dest-model" { + copied = true + } + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + err := adapter.CopyModel(context.Background(), "source-model", "dest-model") + if err != nil { + t.Fatalf("CopyModel() error = %v", err) + } + + if !copied { + t.Error("CopyModel() did not call the copy endpoint with correct params") + } +} + +func TestAdapter_Embed(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/embed" && r.Method == "POST" { + resp := map[string]interface{}{ + "embeddings": [][]float64{ + {0.1, 0.2, 0.3}, + {0.4, 0.5, 0.6}, + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: server.URL, + }) + + embeddings, err := adapter.Embed(context.Background(), "nomic-embed-text", []string{"hello", "world"}) + if err != nil { + t.Fatalf("Embed() error = %v", err) + } + + if len(embeddings) != 2 { + t.Errorf("Embed() returned %d embeddings, want 2", len(embeddings)) + } + + if len(embeddings[0]) != 3 { + t.Errorf("First embedding has %d dimensions, want 3", len(embeddings[0])) + } +} + +func TestNewAdapter_Validation(t *testing.T) { + t.Run("invalid URL", func(t *testing.T) { + _, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "not-a-url", + }) + if err == nil { + t.Error("NewAdapter() should fail with invalid URL") + } + }) + + t.Run("wrong backend type", func(t *testing.T) { + _, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:11434", + }) + if err == nil { + t.Error("NewAdapter() should fail with wrong backend type") + } + }) + + t.Run("valid config", func(t *testing.T) { + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:11434", + }) + if err != nil { + t.Errorf("NewAdapter() error = %v", err) + } + if adapter == nil { + t.Error("NewAdapter() returned nil adapter") + } + }) +} diff --git a/backend/internal/backends/openai/adapter.go b/backend/internal/backends/openai/adapter.go new file mode 100644 index 0000000..a2908eb --- /dev/null +++ b/backend/internal/backends/openai/adapter.go @@ -0,0 +1,503 @@ +package openai + +import ( + 
"bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "vessel-backend/internal/backends" +) + +// Adapter implements the LLMBackend interface for OpenAI-compatible APIs. +// This includes llama.cpp server and LM Studio. +type Adapter struct { + config backends.BackendConfig + httpClient *http.Client + baseURL *url.URL +} + +// Ensure Adapter implements required interfaces +var ( + _ backends.LLMBackend = (*Adapter)(nil) + _ backends.EmbeddingProvider = (*Adapter)(nil) +) + +// NewAdapter creates a new OpenAI-compatible backend adapter +func NewAdapter(config backends.BackendConfig) (*Adapter, error) { + if config.Type != backends.BackendTypeLlamaCpp && config.Type != backends.BackendTypeLMStudio { + return nil, fmt.Errorf("invalid backend type: expected %s or %s, got %s", + backends.BackendTypeLlamaCpp, backends.BackendTypeLMStudio, config.Type) + } + + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + baseURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + return &Adapter{ + config: config, + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + }, nil +} + +// Type returns the backend type +func (a *Adapter) Type() backends.BackendType { + return a.config.Type +} + +// Config returns the backend configuration +func (a *Adapter) Config() backends.BackendConfig { + return a.config +} + +// Capabilities returns what features this backend supports +func (a *Adapter) Capabilities() backends.BackendCapabilities { + if a.config.Type == backends.BackendTypeLlamaCpp { + return backends.LlamaCppCapabilities() + } + return backends.LMStudioCapabilities() +} + +// HealthCheck verifies the backend is reachable +func (a *Adapter) HealthCheck(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, "GET", a.baseURL.String()+"/v1/models", nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to reach backend: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("backend returned status %d", resp.StatusCode) + } + + return nil +} + +// openaiModelsResponse represents the response from /v1/models +type openaiModelsResponse struct { + Data []openaiModel `json:"data"` +} + +type openaiModel struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Created int64 `json:"created"` +} + +// ListModels returns all models available from this backend +func (a *Adapter) ListModels(ctx context.Context) ([]backends.Model, error) { + req, err := http.NewRequestWithContext(ctx, "GET", a.baseURL.String()+"/v1/models", nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + defer resp.Body.Close() + + var listResp openaiModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + models := make([]backends.Model, len(listResp.Data)) + for i, m := range listResp.Data { + models[i] = backends.Model{ + ID: m.ID, + Name: m.ID, + } + } + + return models, nil +} + +// Chat sends a non-streaming chat request +func (a *Adapter) Chat(ctx context.Context, req 
*backends.ChatRequest) (*backends.ChatChunk, error) { + if err := req.Validate(); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + openaiReq := a.convertChatRequest(req) + openaiReq["stream"] = false + + body, err := json.Marshal(openaiReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/v1/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("chat request failed: %w", err) + } + defer resp.Body.Close() + + var openaiResp openaiChatResponse + if err := json.NewDecoder(resp.Body).Decode(&openaiResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return a.convertChatResponse(&openaiResp), nil +} + +// StreamChat sends a streaming chat request +func (a *Adapter) StreamChat(ctx context.Context, req *backends.ChatRequest) (<-chan backends.ChatChunk, error) { + if err := req.Validate(); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + openaiReq := a.convertChatRequest(req) + openaiReq["stream"] = true + + body, err := json.Marshal(openaiReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/v1/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + + // Use a client without timeout for streaming + client := &http.Client{} + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("chat request failed: %w", err) + } + + chunkCh := make(chan backends.ChatChunk) + + go func() { + defer close(chunkCh) + defer resp.Body.Close() + + a.parseSSEStream(ctx, resp.Body, chunkCh) + }() + + return chunkCh, nil +} + +// parseSSEStream parses Server-Sent Events and emits ChatChunks +func (a *Adapter) parseSSEStream(ctx context.Context, body io.Reader, chunkCh chan<- backends.ChatChunk) { + scanner := bufio.NewScanner(body) + + // Track accumulated tool call arguments + toolCallArgs := make(map[int]string) + + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE data line + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + // Check for stream end + if data == "[DONE]" { + chunkCh <- backends.ChatChunk{Done: true} + return + } + + var streamResp openaiStreamResponse + if err := json.Unmarshal([]byte(data), &streamResp); err != nil { + chunkCh <- backends.ChatChunk{Error: fmt.Sprintf("failed to parse SSE data: %v", err)} + continue + } + + chunk := a.convertStreamResponse(&streamResp, toolCallArgs) + chunkCh <- chunk + + if chunk.Done { + return + } + } + + if err := scanner.Err(); err != nil && ctx.Err() == nil { + chunkCh <- backends.ChatChunk{Error: fmt.Sprintf("stream error: %v", err)} + } +} + +// Info returns detailed information about the backend +func (a *Adapter) Info(ctx context.Context) backends.BackendInfo { + info := backends.BackendInfo{ + Type: 
a.config.Type, + BaseURL: a.config.BaseURL, + Capabilities: a.Capabilities(), + } + + // Try to reach the models endpoint + if err := a.HealthCheck(ctx); err != nil { + info.Status = backends.BackendStatusDisconnected + info.Error = err.Error() + return info + } + + info.Status = backends.BackendStatusConnected + return info +} + +// Embed generates embeddings for the given input +func (a *Adapter) Embed(ctx context.Context, model string, input []string) ([][]float64, error) { + body, err := json.Marshal(map[string]interface{}{ + "model": model, + "input": input, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL.String()+"/v1/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("embed request failed: %w", err) + } + defer resp.Body.Close() + + var embedResp struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + embeddings := make([][]float64, len(embedResp.Data)) + for _, d := range embedResp.Data { + embeddings[d.Index] = d.Embedding + } + + return embeddings, nil +} + +// OpenAI API response types + +type openaiChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []openaiChoice `json:"choices"` + Usage *openaiUsage `json:"usage,omitempty"` +} + +type openaiChoice struct { + Index int `json:"index"` + Message *openaiMessage `json:"message,omitempty"` + Delta *openaiMessage `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type openaiMessage struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []openaiToolCall `json:"tool_calls,omitempty"` +} + +type openaiToolCall struct { + ID string `json:"id,omitempty"` + Index int `json:"index,omitempty"` + Type string `json:"type,omitempty"` + Function struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + } `json:"function"` +} + +type openaiUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type openaiStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []openaiChoice `json:"choices"` +} + +// convertChatRequest converts a backends.ChatRequest to OpenAI format +func (a *Adapter) convertChatRequest(req *backends.ChatRequest) map[string]interface{} { + messages := make([]map[string]interface{}, len(req.Messages)) + for i, msg := range req.Messages { + m := map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + if msg.Name != "" { + m["name"] = msg.Name + } + if msg.ToolCallID != "" { + m["tool_call_id"] = msg.ToolCallID + } + messages[i] = m + } + + openaiReq := map[string]interface{}{ + "model": req.Model, + "messages": messages, + } + + // Add optional parameters + if req.Temperature != nil { + openaiReq["temperature"] = *req.Temperature + } + if req.TopP != nil { + openaiReq["top_p"] = *req.TopP + } + if 
req.MaxTokens != nil { + openaiReq["max_tokens"] = *req.MaxTokens + } + if len(req.Tools) > 0 { + openaiReq["tools"] = req.Tools + } + + return openaiReq +} + +// convertChatResponse converts an OpenAI response to backends.ChatChunk +func (a *Adapter) convertChatResponse(resp *openaiChatResponse) *backends.ChatChunk { + chunk := &backends.ChatChunk{ + Model: resp.Model, + Done: true, + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + if choice.Message != nil { + msg := &backends.ChatMessage{ + Role: choice.Message.Role, + Content: choice.Message.Content, + } + + // Convert tool calls + for _, tc := range choice.Message.ToolCalls { + msg.ToolCalls = append(msg.ToolCalls, backends.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + chunk.Message = msg + } + + if choice.FinishReason != "" { + chunk.DoneReason = choice.FinishReason + } + } + + if resp.Usage != nil { + chunk.PromptEvalCount = resp.Usage.PromptTokens + chunk.EvalCount = resp.Usage.CompletionTokens + } + + return chunk +} + +// convertStreamResponse converts an OpenAI stream response to backends.ChatChunk +func (a *Adapter) convertStreamResponse(resp *openaiStreamResponse, toolCallArgs map[int]string) backends.ChatChunk { + chunk := backends.ChatChunk{ + Model: resp.Model, + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + + if choice.FinishReason != "" { + chunk.Done = true + chunk.DoneReason = choice.FinishReason + } + + if choice.Delta != nil { + msg := &backends.ChatMessage{ + Role: choice.Delta.Role, + Content: choice.Delta.Content, + } + + // Handle streaming tool calls + for _, tc := range choice.Delta.ToolCalls { + // Accumulate arguments + if tc.Function.Arguments != "" { + toolCallArgs[tc.Index] += tc.Function.Arguments + } + + // Only add tool call when we have the initial info + if tc.ID != "" || tc.Function.Name != "" { + msg.ToolCalls = append(msg.ToolCalls, backends.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: tc.Function.Name, + Arguments: toolCallArgs[tc.Index], + }, + }) + } + } + + chunk.Message = msg + } + } + + return chunk +} diff --git a/backend/internal/backends/openai/adapter_test.go b/backend/internal/backends/openai/adapter_test.go new file mode 100644 index 0000000..4b70dc3 --- /dev/null +++ b/backend/internal/backends/openai/adapter_test.go @@ -0,0 +1,594 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "vessel-backend/internal/backends" +) + +func TestAdapter_Type(t *testing.T) { + tests := []struct { + name string + backendType backends.BackendType + expectedType backends.BackendType + }{ + {"llamacpp type", backends.BackendTypeLlamaCpp, backends.BackendTypeLlamaCpp}, + {"lmstudio type", backends.BackendTypeLMStudio, backends.BackendTypeLMStudio}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: tt.backendType, + BaseURL: "http://localhost:8081", + }) + + if adapter.Type() != tt.expectedType { + t.Errorf("Type() = %v, want %v", adapter.Type(), tt.expectedType) + } + }) + } +} + +func TestAdapter_Config(t *testing.T) { + cfg := backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:8081", + Enabled: true, + } + + adapter, _ := 
NewAdapter(cfg) + got := adapter.Config() + + if got.Type != cfg.Type { + t.Errorf("Config().Type = %v, want %v", got.Type, cfg.Type) + } + if got.BaseURL != cfg.BaseURL { + t.Errorf("Config().BaseURL = %v, want %v", got.BaseURL, cfg.BaseURL) + } +} + +func TestAdapter_Capabilities(t *testing.T) { + t.Run("llamacpp capabilities", func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:8081", + }) + + caps := adapter.Capabilities() + + if !caps.CanListModels { + t.Error("llama.cpp adapter should support listing models") + } + if caps.CanPullModels { + t.Error("llama.cpp adapter should NOT support pulling models") + } + if caps.CanDeleteModels { + t.Error("llama.cpp adapter should NOT support deleting models") + } + if caps.CanCreateModels { + t.Error("llama.cpp adapter should NOT support creating models") + } + if !caps.CanStreamChat { + t.Error("llama.cpp adapter should support streaming chat") + } + if !caps.CanEmbed { + t.Error("llama.cpp adapter should support embeddings") + } + }) + + t.Run("lmstudio capabilities", func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLMStudio, + BaseURL: "http://localhost:1234", + }) + + caps := adapter.Capabilities() + + if !caps.CanListModels { + t.Error("LM Studio adapter should support listing models") + } + if caps.CanPullModels { + t.Error("LM Studio adapter should NOT support pulling models") + } + }) +} + +func TestAdapter_HealthCheck(t *testing.T) { + t.Run("healthy server", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/models" { + json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]string{{"id": "llama3.2:8b"}}, + }) + } + })) + defer server.Close() + + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("Failed to create adapter: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := adapter.HealthCheck(ctx); err != nil { + t.Errorf("HealthCheck() error = %v, want nil", err) + } + }) + + t.Run("unreachable server", func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:19999", + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if err := adapter.HealthCheck(ctx); err == nil { + t.Error("HealthCheck() expected error for unreachable server") + } + }) +} + +func TestAdapter_ListModels(t *testing.T) { + t.Run("returns model list", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/models" { + resp := map[string]interface{}{ + "data": []map[string]interface{}{ + { + "id": "llama3.2-8b-instruct", + "object": "model", + "owned_by": "local", + "created": 1700000000, + }, + { + "id": "mistral-7b-v0.2", + "object": "model", + "owned_by": "local", + "created": 1700000001, + }, + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + ctx := context.Background() + models, err := adapter.ListModels(ctx) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + + if len(models) != 2 { + 
t.Errorf("ListModels() returned %d models, want 2", len(models)) + } + + if models[0].ID != "llama3.2-8b-instruct" { + t.Errorf("First model ID = %q, want %q", models[0].ID, "llama3.2-8b-instruct") + } + }) + + t.Run("handles empty model list", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/models" { + resp := map[string]interface{}{ + "data": []map[string]interface{}{}, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + models, err := adapter.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + + if len(models) != 0 { + t.Errorf("ListModels() returned %d models, want 0", len(models)) + } + }) +} + +func TestAdapter_Chat(t *testing.T) { + t.Run("non-streaming chat", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/chat/completions" && r.Method == "POST" { + var req map[string]interface{} + json.NewDecoder(r.Body).Decode(&req) + + // Check stream is false + if stream, ok := req["stream"].(bool); ok && stream { + t.Error("Expected stream=false for non-streaming chat") + } + + resp := map[string]interface{}{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1700000000, + "model": "llama3.2:8b", + "choices": []map[string]interface{}{ + { + "index": 0, + "message": map[string]interface{}{ + "role": "assistant", + "content": "Hello! How can I help you?", + }, + "finish_reason": "stop", + }, + }, + "usage": map[string]int{ + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + } + + resp, err := adapter.Chat(context.Background(), req) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if !resp.Done { + t.Error("Chat() response.Done = false, want true") + } + + if resp.Message == nil || resp.Message.Content != "Hello! How can I help you?" 
{ + t.Errorf("Chat() response content unexpected: %+v", resp.Message) + } + }) +} + +func TestAdapter_StreamChat(t *testing.T) { + t.Run("streaming chat with SSE", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/chat/completions" && r.Method == "POST" { + var req map[string]interface{} + json.NewDecoder(r.Body).Decode(&req) + + // Check stream is true + if stream, ok := req["stream"].(bool); !ok || !stream { + t.Error("Expected stream=true for streaming chat") + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + flusher := w.(http.Flusher) + + // Send SSE chunks + chunks := []string{ + `{"id":"chatcmpl-1","choices":[{"delta":{"role":"assistant","content":"Hello"}}]}`, + `{"id":"chatcmpl-1","choices":[{"delta":{"content":"!"}}]}`, + `{"id":"chatcmpl-1","choices":[{"delta":{},"finish_reason":"stop"}]}`, + } + + for _, chunk := range chunks { + fmt.Fprintf(w, "data: %s\n\n", chunk) + flusher.Flush() + } + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + streaming := true + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + Stream: &streaming, + } + + chunkCh, err := adapter.StreamChat(context.Background(), req) + if err != nil { + t.Fatalf("StreamChat() error = %v", err) + } + + var chunks []backends.ChatChunk + for chunk := range chunkCh { + chunks = append(chunks, chunk) + } + + if len(chunks) < 2 { + t.Errorf("StreamChat() received %d chunks, want at least 2", len(chunks)) + } + + // Last chunk should be done + if !chunks[len(chunks)-1].Done { + t.Error("Last chunk should have Done=true") + } + }) + + t.Run("handles context cancellation", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/chat/completions" { + w.Header().Set("Content-Type", "text/event-stream") + flusher := w.(http.Flusher) + + // Send first chunk then wait + fmt.Fprintf(w, "data: %s\n\n", `{"id":"chatcmpl-1","choices":[{"delta":{"role":"assistant","content":"Starting..."}}]}`) + flusher.Flush() + + // Wait long enough for context to be cancelled + time.Sleep(2 * time.Second) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + streaming := true + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + Stream: &streaming, + } + + chunkCh, err := adapter.StreamChat(ctx, req) + if err != nil { + t.Fatalf("StreamChat() error = %v", err) + } + + // Should receive at least one chunk before timeout + receivedChunks := 0 + for range chunkCh { + receivedChunks++ + } + + if receivedChunks == 0 { + t.Error("Expected to receive at least one chunk before cancellation") + } + }) +} + +func TestAdapter_Info(t *testing.T) { + t.Run("connected server", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/models" { + json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]string{{"id": 
"llama3.2:8b"}}, + }) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + info := adapter.Info(context.Background()) + + if info.Type != backends.BackendTypeLlamaCpp { + t.Errorf("Info().Type = %v, want %v", info.Type, backends.BackendTypeLlamaCpp) + } + + if info.Status != backends.BackendStatusConnected { + t.Errorf("Info().Status = %v, want %v", info.Status, backends.BackendStatusConnected) + } + }) + + t.Run("disconnected server", func(t *testing.T) { + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:19999", + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + info := adapter.Info(ctx) + + if info.Status != backends.BackendStatusDisconnected { + t.Errorf("Info().Status = %v, want %v", info.Status, backends.BackendStatusDisconnected) + } + + if info.Error == "" { + t.Error("Info().Error should be set for disconnected server") + } + }) +} + +func TestAdapter_Embed(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/embeddings" && r.Method == "POST" { + resp := map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{0.1, 0.2, 0.3}, "index": 0}, + {"embedding": []float64{0.4, 0.5, 0.6}, "index": 1}, + }, + } + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + embeddings, err := adapter.Embed(context.Background(), "nomic-embed-text", []string{"hello", "world"}) + if err != nil { + t.Fatalf("Embed() error = %v", err) + } + + if len(embeddings) != 2 { + t.Errorf("Embed() returned %d embeddings, want 2", len(embeddings)) + } + + if len(embeddings[0]) != 3 { + t.Errorf("First embedding has %d dimensions, want 3", len(embeddings[0])) + } +} + +func TestNewAdapter_Validation(t *testing.T) { + t.Run("invalid URL", func(t *testing.T) { + _, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "not-a-url", + }) + if err == nil { + t.Error("NewAdapter() should fail with invalid URL") + } + }) + + t.Run("wrong backend type", func(t *testing.T) { + _, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeOllama, + BaseURL: "http://localhost:8081", + }) + if err == nil { + t.Error("NewAdapter() should fail with Ollama backend type") + } + }) + + t.Run("valid llamacpp config", func(t *testing.T) { + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: "http://localhost:8081", + }) + if err != nil { + t.Errorf("NewAdapter() error = %v", err) + } + if adapter == nil { + t.Error("NewAdapter() returned nil adapter") + } + }) + + t.Run("valid lmstudio config", func(t *testing.T) { + adapter, err := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLMStudio, + BaseURL: "http://localhost:1234", + }) + if err != nil { + t.Errorf("NewAdapter() error = %v", err) + } + if adapter == nil { + t.Error("NewAdapter() returned nil adapter") + } + }) +} + +func TestAdapter_ToolCalls(t *testing.T) { + t.Run("streaming with tool calls", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/chat/completions" { + w.Header().Set("Content-Type", "text/event-stream") + 
flusher := w.(http.Flusher) + + // Send tool call chunks + chunks := []string{ + `{"id":"chatcmpl-1","choices":[{"delta":{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]}}]}`, + `{"id":"chatcmpl-1","choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"location\":"}}]}}]}`, + `{"id":"chatcmpl-1","choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Tokyo\"}"}}]}}]}`, + `{"id":"chatcmpl-1","choices":[{"delta":{},"finish_reason":"tool_calls"}]}`, + } + + for _, chunk := range chunks { + fmt.Fprintf(w, "data: %s\n\n", chunk) + flusher.Flush() + } + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + } + })) + defer server.Close() + + adapter, _ := NewAdapter(backends.BackendConfig{ + Type: backends.BackendTypeLlamaCpp, + BaseURL: server.URL, + }) + + streaming := true + req := &backends.ChatRequest{ + Model: "llama3.2:8b", + Messages: []backends.ChatMessage{ + {Role: "user", Content: "What's the weather in Tokyo?"}, + }, + Stream: &streaming, + Tools: []backends.Tool{ + { + Type: "function", + Function: struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` + }{ + Name: "get_weather", + Description: "Get weather for a location", + }, + }, + }, + } + + chunkCh, err := adapter.StreamChat(context.Background(), req) + if err != nil { + t.Fatalf("StreamChat() error = %v", err) + } + + var lastChunk backends.ChatChunk + for chunk := range chunkCh { + lastChunk = chunk + } + + if !lastChunk.Done { + t.Error("Last chunk should have Done=true") + } + }) +} diff --git a/backend/internal/backends/registry.go b/backend/internal/backends/registry.go new file mode 100644 index 0000000..65bce06 --- /dev/null +++ b/backend/internal/backends/registry.go @@ -0,0 +1,242 @@ +package backends + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" +) + +// Registry manages multiple LLM backend instances +type Registry struct { + mu sync.RWMutex + backends map[BackendType]LLMBackend + active BackendType +} + +// NewRegistry creates a new backend registry +func NewRegistry() *Registry { + return &Registry{ + backends: make(map[BackendType]LLMBackend), + } +} + +// Register adds a backend to the registry +func (r *Registry) Register(backend LLMBackend) error { + r.mu.Lock() + defer r.mu.Unlock() + + bt := backend.Type() + if _, exists := r.backends[bt]; exists { + return fmt.Errorf("backend %q already registered", bt) + } + + r.backends[bt] = backend + return nil +} + +// Unregister removes a backend from the registry +func (r *Registry) Unregister(backendType BackendType) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.backends[backendType]; !exists { + return fmt.Errorf("backend %q not registered", backendType) + } + + delete(r.backends, backendType) + + // Clear active if it was the unregistered backend + if r.active == backendType { + r.active = "" + } + + return nil +} + +// Get retrieves a backend by type +func (r *Registry) Get(backendType BackendType) (LLMBackend, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + backend, ok := r.backends[backendType] + return backend, ok +} + +// SetActive sets the active backend +func (r *Registry) SetActive(backendType BackendType) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.backends[backendType]; !exists { + return fmt.Errorf("backend %q not registered", backendType) + } + + r.active = backendType + return nil +} + +// Active returns the 
currently active backend +func (r *Registry) Active() LLMBackend { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.active == "" { + return nil + } + + return r.backends[r.active] +} + +// ActiveType returns the type of the currently active backend +func (r *Registry) ActiveType() BackendType { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.active +} + +// Backends returns all registered backend types +func (r *Registry) Backends() []BackendType { + r.mu.RLock() + defer r.mu.RUnlock() + + types := make([]BackendType, 0, len(r.backends)) + for bt := range r.backends { + types = append(types, bt) + } + return types +} + +// AllInfo returns information about all registered backends +func (r *Registry) AllInfo(ctx context.Context) []BackendInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + infos := make([]BackendInfo, 0, len(r.backends)) + for _, backend := range r.backends { + infos = append(infos, backend.Info(ctx)) + } + return infos +} + +// DiscoveryEndpoint represents a potential backend endpoint to probe +type DiscoveryEndpoint struct { + Type BackendType + BaseURL string +} + +// DiscoveryResult represents the result of probing an endpoint +type DiscoveryResult struct { + Type BackendType `json:"type"` + BaseURL string `json:"baseUrl"` + Available bool `json:"available"` + Version string `json:"version,omitempty"` + Error string `json:"error,omitempty"` +} + +// Discover probes the given endpoints to find available backends +func (r *Registry) Discover(ctx context.Context, endpoints []DiscoveryEndpoint) []DiscoveryResult { + results := make([]DiscoveryResult, len(endpoints)) + var wg sync.WaitGroup + + for i, endpoint := range endpoints { + wg.Add(1) + go func(idx int, ep DiscoveryEndpoint) { + defer wg.Done() + results[idx] = probeEndpoint(ctx, ep) + }(i, endpoint) + } + + wg.Wait() + return results +} + +// probeEndpoint checks if a backend is available at the given endpoint +func probeEndpoint(ctx context.Context, endpoint DiscoveryEndpoint) DiscoveryResult { + result := DiscoveryResult{ + Type: endpoint.Type, + BaseURL: endpoint.BaseURL, + } + + client := &http.Client{ + Timeout: 3 * time.Second, + } + + // Determine probe path based on backend type + var probePath string + switch endpoint.Type { + case BackendTypeOllama: + probePath = "/api/version" + case BackendTypeLlamaCpp, BackendTypeLMStudio: + probePath = "/v1/models" + default: + probePath = "/health" + } + + req, err := http.NewRequestWithContext(ctx, "GET", endpoint.BaseURL+probePath, nil) + if err != nil { + result.Error = err.Error() + return result + } + + resp, err := client.Do(req) + if err != nil { + result.Error = err.Error() + return result + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + result.Available = true + } else { + result.Error = fmt.Sprintf("HTTP %d", resp.StatusCode) + } + + return result +} + +// DefaultDiscoveryEndpoints returns the default endpoints to probe +func DefaultDiscoveryEndpoints() []DiscoveryEndpoint { + return []DiscoveryEndpoint{ + {Type: BackendTypeOllama, BaseURL: "http://localhost:11434"}, + {Type: BackendTypeLlamaCpp, BaseURL: "http://localhost:8081"}, + {Type: BackendTypeLlamaCpp, BaseURL: "http://localhost:8080"}, + {Type: BackendTypeLMStudio, BaseURL: "http://localhost:1234"}, + } +} + +// DiscoverAndRegister probes endpoints and registers available backends +func (r *Registry) DiscoverAndRegister(ctx context.Context, endpoints []DiscoveryEndpoint, adapterFactory AdapterFactory) []DiscoveryResult { + results := r.Discover(ctx, endpoints) + + for _, result 
:= range results { + if !result.Available { + continue + } + + // Skip if already registered + if _, exists := r.Get(result.Type); exists { + continue + } + + config := BackendConfig{ + Type: result.Type, + BaseURL: result.BaseURL, + Enabled: true, + } + + adapter, err := adapterFactory(config) + if err != nil { + continue + } + + r.Register(adapter) + } + + return results +} + +// AdapterFactory creates an LLMBackend from a config +type AdapterFactory func(config BackendConfig) (LLMBackend, error) diff --git a/backend/internal/backends/registry_test.go b/backend/internal/backends/registry_test.go new file mode 100644 index 0000000..f23ebde --- /dev/null +++ b/backend/internal/backends/registry_test.go @@ -0,0 +1,352 @@ +package backends + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewRegistry(t *testing.T) { + registry := NewRegistry() + + if registry == nil { + t.Fatal("NewRegistry() returned nil") + } + + if len(registry.Backends()) != 0 { + t.Errorf("New registry should have no backends, got %d", len(registry.Backends())) + } + + if registry.Active() != nil { + t.Error("New registry should have no active backend") + } +} + +func TestRegistry_Register(t *testing.T) { + registry := NewRegistry() + + // Create a mock backend + mock := &mockBackend{ + backendType: BackendTypeOllama, + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + } + + err := registry.Register(mock) + if err != nil { + t.Fatalf("Register() error = %v", err) + } + + if len(registry.Backends()) != 1 { + t.Errorf("Registry should have 1 backend, got %d", len(registry.Backends())) + } + + // Should not allow duplicate registration + err = registry.Register(mock) + if err == nil { + t.Error("Register() should fail for duplicate backend type") + } +} + +func TestRegistry_Get(t *testing.T) { + registry := NewRegistry() + + mock := &mockBackend{ + backendType: BackendTypeOllama, + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + } + registry.Register(mock) + + t.Run("existing backend", func(t *testing.T) { + backend, ok := registry.Get(BackendTypeOllama) + if !ok { + t.Error("Get() should return ok=true for registered backend") + } + if backend != mock { + t.Error("Get() returned wrong backend") + } + }) + + t.Run("non-existing backend", func(t *testing.T) { + _, ok := registry.Get(BackendTypeLlamaCpp) + if ok { + t.Error("Get() should return ok=false for unregistered backend") + } + }) +} + +func TestRegistry_SetActive(t *testing.T) { + registry := NewRegistry() + + mock := &mockBackend{ + backendType: BackendTypeOllama, + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + } + registry.Register(mock) + + t.Run("set registered backend as active", func(t *testing.T) { + err := registry.SetActive(BackendTypeOllama) + if err != nil { + t.Errorf("SetActive() error = %v", err) + } + + active := registry.Active() + if active == nil { + t.Fatal("Active() returned nil after SetActive()") + } + if active.Type() != BackendTypeOllama { + t.Errorf("Active().Type() = %v, want %v", active.Type(), BackendTypeOllama) + } + }) + + t.Run("set unregistered backend as active", func(t *testing.T) { + err := registry.SetActive(BackendTypeLlamaCpp) + if err == nil { + t.Error("SetActive() should fail for unregistered backend") + } + }) +} + +func TestRegistry_ActiveType(t *testing.T) { + registry := NewRegistry() + + t.Run("no active backend", func(t 
*testing.T) { + activeType := registry.ActiveType() + if activeType != "" { + t.Errorf("ActiveType() = %q, want empty string", activeType) + } + }) + + t.Run("with active backend", func(t *testing.T) { + mock := &mockBackend{backendType: BackendTypeOllama} + registry.Register(mock) + registry.SetActive(BackendTypeOllama) + + activeType := registry.ActiveType() + if activeType != BackendTypeOllama { + t.Errorf("ActiveType() = %v, want %v", activeType, BackendTypeOllama) + } + }) +} + +func TestRegistry_Unregister(t *testing.T) { + registry := NewRegistry() + + mock := &mockBackend{backendType: BackendTypeOllama} + registry.Register(mock) + registry.SetActive(BackendTypeOllama) + + err := registry.Unregister(BackendTypeOllama) + if err != nil { + t.Errorf("Unregister() error = %v", err) + } + + if len(registry.Backends()) != 0 { + t.Error("Registry should have no backends after unregister") + } + + if registry.Active() != nil { + t.Error("Active backend should be nil after unregistering it") + } +} + +func TestRegistry_AllInfo(t *testing.T) { + registry := NewRegistry() + + mock1 := &mockBackend{ + backendType: BackendTypeOllama, + config: BackendConfig{Type: BackendTypeOllama, BaseURL: "http://localhost:11434"}, + info: BackendInfo{ + Type: BackendTypeOllama, + Status: BackendStatusConnected, + Version: "0.1.0", + }, + } + mock2 := &mockBackend{ + backendType: BackendTypeLlamaCpp, + config: BackendConfig{Type: BackendTypeLlamaCpp, BaseURL: "http://localhost:8081"}, + info: BackendInfo{ + Type: BackendTypeLlamaCpp, + Status: BackendStatusDisconnected, + }, + } + + registry.Register(mock1) + registry.Register(mock2) + registry.SetActive(BackendTypeOllama) + + infos := registry.AllInfo(context.Background()) + + if len(infos) != 2 { + t.Errorf("AllInfo() returned %d infos, want 2", len(infos)) + } + + // Find the active one + var foundActive bool + for _, info := range infos { + if info.Type == BackendTypeOllama { + foundActive = true + } + } + if !foundActive { + t.Error("AllInfo() did not include ollama backend info") + } +} + +func TestRegistry_Discover(t *testing.T) { + // Create test servers for each backend type + ollamaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/version" || r.URL.Path == "/" { + json.NewEncoder(w).Encode(map[string]string{"version": "0.3.0"}) + } + })) + defer ollamaServer.Close() + + llamacppServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1/models" { + json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]string{{"id": "llama3.2:8b"}}, + }) + } + if r.URL.Path == "/health" { + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + } + })) + defer llamacppServer.Close() + + registry := NewRegistry() + + // Configure discovery endpoints + endpoints := []DiscoveryEndpoint{ + {Type: BackendTypeOllama, BaseURL: ollamaServer.URL}, + {Type: BackendTypeLlamaCpp, BaseURL: llamacppServer.URL}, + {Type: BackendTypeLMStudio, BaseURL: "http://localhost:19999"}, // Not running + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + results := registry.Discover(ctx, endpoints) + + if len(results) != 3 { + t.Errorf("Discover() returned %d results, want 3", len(results)) + } + + // Check Ollama was discovered + var ollamaResult *DiscoveryResult + for i := range results { + if results[i].Type == BackendTypeOllama { + ollamaResult = &results[i] + break + } + } + + if ollamaResult 
== nil { + t.Fatal("Ollama not found in discovery results") + } + if !ollamaResult.Available { + t.Errorf("Ollama should be available, error: %s", ollamaResult.Error) + } + + // Check LM Studio was not discovered + var lmstudioResult *DiscoveryResult + for i := range results { + if results[i].Type == BackendTypeLMStudio { + lmstudioResult = &results[i] + break + } + } + + if lmstudioResult == nil { + t.Fatal("LM Studio not found in discovery results") + } + if lmstudioResult.Available { + t.Error("LM Studio should NOT be available") + } +} + +func TestRegistry_DefaultEndpoints(t *testing.T) { + endpoints := DefaultDiscoveryEndpoints() + + if len(endpoints) < 3 { + t.Errorf("DefaultDiscoveryEndpoints() returned %d endpoints, want at least 3", len(endpoints)) + } + + // Check that all expected types are present + types := make(map[BackendType]bool) + for _, e := range endpoints { + types[e.Type] = true + } + + if !types[BackendTypeOllama] { + t.Error("DefaultDiscoveryEndpoints() missing Ollama") + } + if !types[BackendTypeLlamaCpp] { + t.Error("DefaultDiscoveryEndpoints() missing llama.cpp") + } + if !types[BackendTypeLMStudio] { + t.Error("DefaultDiscoveryEndpoints() missing LM Studio") + } +} + +// mockBackend implements LLMBackend for testing +type mockBackend struct { + backendType BackendType + config BackendConfig + info BackendInfo + healthErr error + models []Model +} + +func (m *mockBackend) Type() BackendType { + return m.backendType +} + +func (m *mockBackend) Config() BackendConfig { + return m.config +} + +func (m *mockBackend) HealthCheck(ctx context.Context) error { + return m.healthErr +} + +func (m *mockBackend) ListModels(ctx context.Context) ([]Model, error) { + return m.models, nil +} + +func (m *mockBackend) StreamChat(ctx context.Context, req *ChatRequest) (<-chan ChatChunk, error) { + ch := make(chan ChatChunk) + close(ch) + return ch, nil +} + +func (m *mockBackend) Chat(ctx context.Context, req *ChatRequest) (*ChatChunk, error) { + return &ChatChunk{Done: true}, nil +} + +func (m *mockBackend) Capabilities() BackendCapabilities { + return OllamaCapabilities() +} + +func (m *mockBackend) Info(ctx context.Context) BackendInfo { + if m.info.Type != "" { + return m.info + } + return BackendInfo{ + Type: m.backendType, + BaseURL: m.config.BaseURL, + Status: BackendStatusConnected, + Capabilities: m.Capabilities(), + } +} diff --git a/backend/internal/backends/types.go b/backend/internal/backends/types.go new file mode 100644 index 0000000..f7b7fd4 --- /dev/null +++ b/backend/internal/backends/types.go @@ -0,0 +1,245 @@ +package backends + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +// BackendType identifies the type of LLM backend +type BackendType string + +const ( + BackendTypeOllama BackendType = "ollama" + BackendTypeLlamaCpp BackendType = "llamacpp" + BackendTypeLMStudio BackendType = "lmstudio" +) + +// String returns the string representation of the backend type +func (bt BackendType) String() string { + return string(bt) +} + +// ParseBackendType parses a string into a BackendType +func ParseBackendType(s string) (BackendType, error) { + switch strings.ToLower(s) { + case "ollama": + return BackendTypeOllama, nil + case "llamacpp", "llama.cpp", "llama-cpp": + return BackendTypeLlamaCpp, nil + case "lmstudio", "lm-studio", "lm_studio": + return BackendTypeLMStudio, nil + default: + return "", fmt.Errorf("unknown backend type: %q", s) + } +} + +// BackendCapabilities describes what features a backend supports +type BackendCapabilities struct { + 
CanListModels bool `json:"canListModels"` + CanPullModels bool `json:"canPullModels"` + CanDeleteModels bool `json:"canDeleteModels"` + CanCreateModels bool `json:"canCreateModels"` + CanStreamChat bool `json:"canStreamChat"` + CanEmbed bool `json:"canEmbed"` +} + +// OllamaCapabilities returns the capabilities for Ollama backend +func OllamaCapabilities() BackendCapabilities { + return BackendCapabilities{ + CanListModels: true, + CanPullModels: true, + CanDeleteModels: true, + CanCreateModels: true, + CanStreamChat: true, + CanEmbed: true, + } +} + +// LlamaCppCapabilities returns the capabilities for llama.cpp backend +func LlamaCppCapabilities() BackendCapabilities { + return BackendCapabilities{ + CanListModels: true, + CanPullModels: false, + CanDeleteModels: false, + CanCreateModels: false, + CanStreamChat: true, + CanEmbed: true, + } +} + +// LMStudioCapabilities returns the capabilities for LM Studio backend +func LMStudioCapabilities() BackendCapabilities { + return BackendCapabilities{ + CanListModels: true, + CanPullModels: false, + CanDeleteModels: false, + CanCreateModels: false, + CanStreamChat: true, + CanEmbed: true, + } +} + +// BackendStatus represents the connection status of a backend +type BackendStatus string + +const ( + BackendStatusConnected BackendStatus = "connected" + BackendStatusDisconnected BackendStatus = "disconnected" + BackendStatusUnknown BackendStatus = "unknown" +) + +// BackendConfig holds configuration for a backend +type BackendConfig struct { + Type BackendType `json:"type"` + BaseURL string `json:"baseUrl"` + Enabled bool `json:"enabled"` +} + +// Validate checks if the backend config is valid +func (c BackendConfig) Validate() error { + if c.BaseURL == "" { + return errors.New("base URL is required") + } + + u, err := url.Parse(c.BaseURL) + if err != nil { + return fmt.Errorf("invalid base URL: %w", err) + } + + if u.Scheme == "" || u.Host == "" { + return errors.New("invalid URL: missing scheme or host") + } + + return nil +} + +// BackendInfo describes a configured backend and its current state +type BackendInfo struct { + Type BackendType `json:"type"` + BaseURL string `json:"baseUrl"` + Status BackendStatus `json:"status"` + Capabilities BackendCapabilities `json:"capabilities"` + Version string `json:"version,omitempty"` + Error string `json:"error,omitempty"` +} + +// IsConnected returns true if the backend is connected +func (bi BackendInfo) IsConnected() bool { + return bi.Status == BackendStatusConnected +} + +// Model represents an LLM model available from a backend +type Model struct { + ID string `json:"id"` + Name string `json:"name"` + Size int64 `json:"size,omitempty"` + ModifiedAt string `json:"modifiedAt,omitempty"` + Family string `json:"family,omitempty"` + QuantLevel string `json:"quantLevel,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// HasCapability checks if the model has a specific capability +func (m Model) HasCapability(cap string) bool { + for _, c := range m.Capabilities { + if c == cap { + return true + } + } + return false +} + +// ChatMessage represents a message in a chat conversation +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +var validRoles = map[string]bool{ + "user": true, + "assistant": 
true, + "system": true, + "tool": true, +} + +// Validate checks if the chat message is valid +func (m ChatMessage) Validate() error { + if m.Role == "" { + return errors.New("role is required") + } + if !validRoles[m.Role] { + return fmt.Errorf("invalid role: %q", m.Role) + } + return nil +} + +// ToolCall represents a tool invocation +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + +// Tool represents a tool definition +type Tool struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` + } `json:"function"` +} + +// ChatRequest represents a chat completion request +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Stream *bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Options map[string]any `json:"options,omitempty"` +} + +// Validate checks if the chat request is valid +func (r ChatRequest) Validate() error { + if r.Model == "" { + return errors.New("model is required") + } + if len(r.Messages) == 0 { + return errors.New("at least one message is required") + } + for i, msg := range r.Messages { + if err := msg.Validate(); err != nil { + return fmt.Errorf("message %d: %w", i, err) + } + } + return nil +} + +// ChatChunk represents a streaming chat response chunk +type ChatChunk struct { + Model string `json:"model"` + CreatedAt string `json:"created_at,omitempty"` + Message *ChatMessage `json:"message,omitempty"` + Done bool `json:"done"` + DoneReason string `json:"done_reason,omitempty"` + + // Token counts (final chunk only) + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + + // Error information + Error string `json:"error,omitempty"` +} diff --git a/backend/internal/backends/types_test.go b/backend/internal/backends/types_test.go new file mode 100644 index 0000000..bd37c2d --- /dev/null +++ b/backend/internal/backends/types_test.go @@ -0,0 +1,323 @@ +package backends + +import ( + "testing" +) + +func TestBackendType_String(t *testing.T) { + tests := []struct { + name string + bt BackendType + expected string + }{ + {"ollama type", BackendTypeOllama, "ollama"}, + {"llamacpp type", BackendTypeLlamaCpp, "llamacpp"}, + {"lmstudio type", BackendTypeLMStudio, "lmstudio"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.bt.String(); got != tt.expected { + t.Errorf("BackendType.String() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestParseBackendType(t *testing.T) { + tests := []struct { + name string + input string + expected BackendType + expectErr bool + }{ + {"parse ollama", "ollama", BackendTypeOllama, false}, + {"parse llamacpp", "llamacpp", BackendTypeLlamaCpp, false}, + {"parse lmstudio", "lmstudio", BackendTypeLMStudio, false}, + {"parse llama.cpp alias", "llama.cpp", BackendTypeLlamaCpp, false}, + {"parse llama-cpp alias", "llama-cpp", BackendTypeLlamaCpp, false}, + {"parse unknown", "unknown", "", true}, + {"parse empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseBackendType(tt.input) + if (err != nil) != tt.expectErr { + 
t.Errorf("ParseBackendType() error = %v, expectErr %v", err, tt.expectErr) + return + } + if got != tt.expected { + t.Errorf("ParseBackendType() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestBackendCapabilities(t *testing.T) { + t.Run("ollama capabilities", func(t *testing.T) { + caps := OllamaCapabilities() + + if !caps.CanListModels { + t.Error("Ollama should be able to list models") + } + if !caps.CanPullModels { + t.Error("Ollama should be able to pull models") + } + if !caps.CanDeleteModels { + t.Error("Ollama should be able to delete models") + } + if !caps.CanCreateModels { + t.Error("Ollama should be able to create models") + } + if !caps.CanStreamChat { + t.Error("Ollama should be able to stream chat") + } + if !caps.CanEmbed { + t.Error("Ollama should be able to embed") + } + }) + + t.Run("llamacpp capabilities", func(t *testing.T) { + caps := LlamaCppCapabilities() + + if !caps.CanListModels { + t.Error("llama.cpp should be able to list models") + } + if caps.CanPullModels { + t.Error("llama.cpp should NOT be able to pull models") + } + if caps.CanDeleteModels { + t.Error("llama.cpp should NOT be able to delete models") + } + if caps.CanCreateModels { + t.Error("llama.cpp should NOT be able to create models") + } + if !caps.CanStreamChat { + t.Error("llama.cpp should be able to stream chat") + } + if !caps.CanEmbed { + t.Error("llama.cpp should be able to embed") + } + }) + + t.Run("lmstudio capabilities", func(t *testing.T) { + caps := LMStudioCapabilities() + + if !caps.CanListModels { + t.Error("LM Studio should be able to list models") + } + if caps.CanPullModels { + t.Error("LM Studio should NOT be able to pull models") + } + if caps.CanDeleteModels { + t.Error("LM Studio should NOT be able to delete models") + } + if caps.CanCreateModels { + t.Error("LM Studio should NOT be able to create models") + } + if !caps.CanStreamChat { + t.Error("LM Studio should be able to stream chat") + } + if !caps.CanEmbed { + t.Error("LM Studio should be able to embed") + } + }) +} + +func TestBackendConfig_Validate(t *testing.T) { + tests := []struct { + name string + config BackendConfig + expectErr bool + }{ + { + name: "valid ollama config", + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "http://localhost:11434", + }, + expectErr: false, + }, + { + name: "valid llamacpp config", + config: BackendConfig{ + Type: BackendTypeLlamaCpp, + BaseURL: "http://localhost:8081", + }, + expectErr: false, + }, + { + name: "empty base URL", + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "", + }, + expectErr: true, + }, + { + name: "invalid URL", + config: BackendConfig{ + Type: BackendTypeOllama, + BaseURL: "not-a-url", + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.expectErr { + t.Errorf("BackendConfig.Validate() error = %v, expectErr %v", err, tt.expectErr) + } + }) + } +} + +func TestModel_HasCapability(t *testing.T) { + model := Model{ + ID: "llama3.2:8b", + Name: "llama3.2:8b", + Capabilities: []string{"chat", "vision", "tools"}, + } + + tests := []struct { + name string + capability string + expected bool + }{ + {"has chat", "chat", true}, + {"has vision", "vision", true}, + {"has tools", "tools", true}, + {"no thinking", "thinking", false}, + {"no code", "code", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := model.HasCapability(tt.capability); got != tt.expected { + 
+				t.Errorf("Model.HasCapability(%q) = %v, want %v", tt.capability, got, tt.expected)
+			}
+		})
+	}
+}
+
+func TestChatMessage_Validation(t *testing.T) {
+	tests := []struct {
+		name      string
+		msg       ChatMessage
+		expectErr bool
+	}{
+		{
+			name:      "valid user message",
+			msg:       ChatMessage{Role: "user", Content: "Hello"},
+			expectErr: false,
+		},
+		{
+			name:      "valid assistant message",
+			msg:       ChatMessage{Role: "assistant", Content: "Hi there"},
+			expectErr: false,
+		},
+		{
+			name:      "valid system message",
+			msg:       ChatMessage{Role: "system", Content: "You are helpful"},
+			expectErr: false,
+		},
+		{
+			name:      "invalid role",
+			msg:       ChatMessage{Role: "invalid", Content: "Hello"},
+			expectErr: true,
+		},
+		{
+			name:      "empty role",
+			msg:       ChatMessage{Role: "", Content: "Hello"},
+			expectErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tt.msg.Validate()
+			if (err != nil) != tt.expectErr {
+				t.Errorf("ChatMessage.Validate() error = %v, expectErr %v", err, tt.expectErr)
+			}
+		})
+	}
+}
+
+func TestChatRequest_Validation(t *testing.T) {
+	streaming := true
+
+	tests := []struct {
+		name      string
+		req       ChatRequest
+		expectErr bool
+	}{
+		{
+			name: "valid request",
+			req: ChatRequest{
+				Model: "llama3.2:8b",
+				Messages: []ChatMessage{
+					{Role: "user", Content: "Hello"},
+				},
+				Stream: &streaming,
+			},
+			expectErr: false,
+		},
+		{
+			name: "empty model",
+			req: ChatRequest{
+				Model: "",
+				Messages: []ChatMessage{
+					{Role: "user", Content: "Hello"},
+				},
+			},
+			expectErr: true,
+		},
+		{
+			name: "empty messages",
+			req: ChatRequest{
+				Model:    "llama3.2:8b",
+				Messages: []ChatMessage{},
+			},
+			expectErr: true,
+		},
+		{
+			name: "nil messages",
+			req: ChatRequest{
+				Model:    "llama3.2:8b",
+				Messages: nil,
+			},
+			expectErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tt.req.Validate()
+			if (err != nil) != tt.expectErr {
+				t.Errorf("ChatRequest.Validate() error = %v, expectErr %v", err, tt.expectErr)
+			}
+		})
+	}
+}
+
+func TestBackendInfo(t *testing.T) {
+	info := BackendInfo{
+		Type:         BackendTypeOllama,
+		BaseURL:      "http://localhost:11434",
+		Status:       BackendStatusConnected,
+		Capabilities: OllamaCapabilities(),
+		Version:      "0.1.0",
+	}
+
+	if !info.IsConnected() {
+		t.Error("BackendInfo.IsConnected() should be true when status is connected")
+	}
+
+	info.Status = BackendStatusDisconnected
+	if info.IsConnected() {
+		t.Error("BackendInfo.IsConnected() should be false when status is disconnected")
+	}
+}
diff --git a/frontend/src/lib/components/chat/BranchNavigator.svelte b/frontend/src/lib/components/chat/BranchNavigator.svelte
index 291e83d..148e368 100644
--- a/frontend/src/lib/components/chat/BranchNavigator.svelte
+++ b/frontend/src/lib/components/chat/BranchNavigator.svelte
@@ -2,7 +2,6 @@
 	/**
 	 * BranchNavigator - Navigate between message branches
 	 * Shows "< 1/3 >" style navigation for sibling messages
-	 * Supports keyboard navigation with arrow keys when focused
 	 */
 
 	import type { BranchInfo } from '$lib/types';
@@ -15,7 +14,7 @@
 	const { branchInfo, onSwitch }: Props = $props();
 
 	// Reference to the navigator container for focus management
-	let navigatorRef: HTMLDivElement | null = $state(null);
+	let navigatorRef: HTMLElement | null = $state(null);
 
 	// Track transition state for smooth animations
 	let isTransitioning = $state(false);
@@ -52,7 +51,7 @@
 	}
 
 	/**
-	 * Handle keyboard navigation when the component is focused
+	 * Handle keyboard navigation with arrow keys
 	 */
 	function handleKeydown(event: KeyboardEvent): void {
 		if (event.key === 'ArrowLeft' && canGoPrev) {
@@ -65,11 +64,10 @@
 	}
-
+
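
Usage sketch (illustrative only, not part of the patch): given the ParseBackendType, ChatRequest, and ChatMessage definitions introduced in types.go above, a caller might validate input before dispatching to a backend roughly as follows. The import path and package location are assumptions; because the package lives under internal/, this snippet would have to sit inside the same Go module as backend/.

package main

import (
	"fmt"
	"log"

	// Import path assumed; adjust to the backend module's real name.
	"vessel/internal/backends"
)

func main() {
	// "llama.cpp" and "llama-cpp" are accepted aliases per TestParseBackendType.
	bt, err := backends.ParseBackendType("llama.cpp")
	if err != nil {
		log.Fatal(err)
	}

	req := backends.ChatRequest{
		Model: "llama3.2:8b",
		Messages: []backends.ChatMessage{
			{Role: "user", Content: "Hello"},
		},
	}
	// Validate rejects an empty model, an empty message list, and unknown roles.
	if err := req.Validate(); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("dispatching request to %s backend\n", bt.String())
}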