Compare commits
93 commits: 33d11ae223...dev
| SHA1 |
|---|
| d86888704f | |||
| de6b6e20a5 | |||
| 1e8a5e08ed | |||
| 218ebbf32f | |||
| c49e7f4b22 | |||
| 9588c8c562 | |||
| 1948ac1284 | |||
| 3f92b7d963 | |||
| 5553e61dbf | |||
| 7f987737f9 | |||
| 5182f86133 | |||
| a50099ad74 | |||
| 20ba5523ee | |||
| 0b2b3701dc | |||
| 438b05b8a3 | |||
| e2a31b192f | |||
| b827d3d047 | |||
| 9c0cf274a3 | |||
| 85ae319690 | |||
| 449f133a1f | |||
| 2f6b03ef65 | |||
| d4030dc598 | |||
| 3271697f6b | |||
| cbfef5a5df | |||
| 52efd5f341 | |||
| 200cdbc4bd | |||
| 8525819ab4 | |||
| bcd52d526c | |||
| 7effade1d3 | |||
| dc0fee2ee3 | |||
| ea04a25ed6 | |||
| 282dcdce88 | |||
| b49f58bc16 | |||
| cdc425ae93 | |||
| 3525cb3949 | |||
| 9d85420bf6 | |||
| 641c95131f | |||
| 708c626176 | |||
| 5210e196f2 | |||
| 30c375b6c5 | |||
| baf49b1e69 | |||
| 96e0436d43 | |||
| 498e6e61b6 | |||
| 99064b6c41 | |||
| ee58b0ac32 | |||
| 990f93d467 | |||
| 44a00619b5 | |||
| 6923ee439f | |||
| c997b19b53 | |||
| c9daf68fea | |||
| ba9d083088 | |||
| 825dfc0722 | |||
| 3e4eacd1d3 | |||
| 23253219a3 | |||
| cc2b85a86d | |||
| 58dd6f3efa | |||
| c81d0f1593 | |||
| ae0dd3fc51 | |||
| 80dffa9f41 | |||
| ab0ae4fe04 | |||
| d31e068277 | |||
| 690f5c7056 | |||
| 0da8a3f193 | |||
| 15f81d9728 | |||
| b80db89391 | |||
| f413a63c5a | |||
| 33ad3797a1 | |||
| 55e6b0583d | |||
| ae9c3af096 | |||
| 0bd560b408 | |||
| 083b621b7d | |||
| d2a193e5c1 | |||
| acbfe47a4b | |||
| 60c859b3ab | |||
| 82078afd6d | |||
| 7851af14a9 | |||
| c2f5ccea3b | |||
| fab63d224b | |||
| 15e5c1206b | |||
| 38aba1a6bb | |||
| d0d3079df5 | |||
| 56de1170ee | |||
| 952e4819fe | |||
| 5ac0d152cb | |||
| 40c44470e8 | |||
| 5c37df1b22 | |||
| 5e81185df3 | |||
| 7534c9ef8d | |||
| 9545a4b3ad | |||
| e94df2c48a | |||
| cdf95002fc | |||
| 4c066bf2da | |||
| e57844e742 |
34 .github/workflows/macos-check.yml (vendored, new file)
@@ -0,0 +1,34 @@
```yaml
name: macos-check

on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  build:
    name: cargo check (macOS)
    runs-on: macos-latest
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable

      - name: Cache Cargo registry
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
            target
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Cargo check
        run: cargo check --workspace --all-features
```
.pre-commit-config.yaml

```diff
@@ -9,6 +9,7 @@ repos:
       - id: trailing-whitespace
       - id: end-of-file-fixer
       - id: check-yaml
+        args: ['--allow-multiple-documents']
      - id: check-toml
      - id: check-merge-conflict
      - id: check-added-large-files
```
.woodpecker.yml

```diff
@@ -1,3 +1,61 @@
+---
+kind: pipeline
+name: pr-checks
+
+when:
+  event:
+    - push
+    - pull_request
+
+steps:
+  - name: fmt-clippy-test
+    image: rust:1.83
+    commands:
+      - rustup component add rustfmt clippy
+      - cargo fmt --all -- --check
+      - cargo clippy --workspace --all-features -- -D warnings
+      - cargo test --workspace --all-features
+
+---
+kind: pipeline
+name: security-audit
+
+when:
+  event:
+    - push
+    - cron
+  branch:
+    - dev
+  cron: weekly-security
+
+steps:
+  - name: cargo-audit
+    image: rust:1.83
+    commands:
+      - cargo install cargo-audit --locked
+      - cargo audit
+
+---
+kind: pipeline
+name: release-tests
+
+when:
+  event: tag
+  tag: v*
+
+steps:
+  - name: workspace-tests
+    image: rust:1.83
+    commands:
+      - rustup component add llvm-tools-preview
+      - cargo install cargo-llvm-cov --locked
+      - cargo llvm-cov --workspace --all-features --summary-only
+      - cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run
+
+---
+kind: pipeline
+name: release
+
 when:
   event: tag
   tag: v*
@@ -5,6 +63,9 @@ when:
 variables:
   - &rust_image 'rust:1.83'
 
+depends_on:
+  - release-tests
+
 matrix:
   include:
     # Linux
@@ -116,6 +177,11 @@ steps:
       sha256sum ${ARTIFACT}.tar.gz > ${ARTIFACT}.tar.gz.sha256
      fi
 
+  - name: release-notes
+    image: *rust_image
+    commands:
+      - scripts/release-notes.sh "${CI_COMMIT_TAG}" release-notes.md
+
   - name: release
     image: plugins/gitea-release
     settings:
@@ -128,4 +194,4 @@ steps:
         - ${ARTIFACT}.zip
         - ${ARTIFACT}.zip.sha256
       title: Release ${CI_COMMIT_TAG}
-      note: "Release ${CI_COMMIT_TAG}"
+      note_file: release-notes.md
```
798 AGENTS.md (new file)
@@ -0,0 +1,798 @@
# AGENTS.md - AI Agent Instructions for Owlen Development

This document provides comprehensive context and guidelines for AI agents (Claude, GPT-4, etc.) working on the Owlen codebase.

## Project Overview

**Owlen** is a local-first, terminal-based AI assistant built in Rust using the Ratatui TUI framework. It implements a Model Context Protocol (MCP) architecture for modular tool execution and supports both local (Ollama) and cloud LLM providers.

**Core Philosophy:**
- **Local-first**: Prioritize local LLMs (Ollama) with cloud as fallback
- **Privacy-focused**: No telemetry, user data stays on device
- **MCP-native**: All operations through MCP servers for modularity
- **Terminal-native**: Vim-style modal interaction in a beautiful TUI

**Current Status:** v1.0 - MCP-only architecture (Phase 10 complete)

## Architecture

### Project Structure

```
owlen/
├── crates/
│   ├── owlen-core/              # Core types, config, provider traits
│   ├── owlen-tui/               # Ratatui-based terminal interface
│   ├── owlen-cli/               # Command-line interface
│   ├── owlen-ollama/            # Ollama provider implementation
│   ├── owlen-mcp-llm-server/    # LLM inference as MCP server
│   ├── owlen-mcp-client/        # MCP client library
│   ├── owlen-mcp-server/        # Base MCP server framework
│   ├── owlen-mcp-code-server/   # Code execution in Docker
│   └── owlen-mcp-prompt-server/ # Prompt management server
├── docs/                        # Documentation
├── themes/                      # TUI color themes
└── .agents/                     # Agent development plans
```

### Key Technologies

- **Language**: Rust 1.83+
- **TUI**: Ratatui with Crossterm backend
- **Async Runtime**: Tokio
- **Config**: TOML (serde)
- **HTTP Client**: reqwest
- **LLM Providers**: Ollama (primary), with extensibility for OpenAI/Anthropic
- **Protocol**: JSON-RPC 2.0 over STDIO/HTTP/WebSocket

## Current Features (v1.0)

### Core Capabilities

1. **MCP Architecture** (Phase 3-10 complete)
   - All LLM interactions via MCP servers
   - Local and remote MCP client support
   - STDIO, HTTP, WebSocket transports
   - Automatic failover with health checks

2. **Provider System**
   - Ollama (local and cloud)
   - Configurable per-provider settings
   - API key management with env variable expansion
   - Model switching via TUI (`:m` command)

3. **Agentic Loop** (ReAct pattern)
   - THOUGHT → ACTION → OBSERVATION cycle
   - Tool discovery and execution
   - Configurable iteration limits
   - Emergency stop (Ctrl+C)

4. **Mode System**
   - Chat mode: Limited tool availability
   - Code mode: Full tool access
   - Tool filtering by mode
   - Runtime mode switching

5. **Session Management**
   - Auto-save conversations
   - Session persistence with encryption
   - Description generation
   - Session timeout management

6. **Security**
   - Docker sandboxing for code execution
   - Tool whitelisting
   - Permission prompts for dangerous operations
   - Network isolation options

### TUI Features

- Vim-style modal editing (Normal, Insert, Visual, Command modes)
- Multi-panel layout (conversation, status, input)
- Syntax highlighting for code blocks
- Theme system (10+ built-in themes)
- Scrollback history (configurable limit)
- Word wrap and visual selection

## Development Guidelines

### Code Style

1. **Rust Best Practices**
   - Use `rustfmt` (pre-commit hook enforced)
   - Run `cargo clippy` before commits
   - Prefer `Result` over `panic!` for errors
   - Document public APIs with `///` comments

2. **Error Handling**
   - Use `owlen_core::Error` enum for all errors
   - Chain errors with context (`.map_err(|e| Error::X(format!(...)))`) (see the sketch after this list)
   - Never unwrap in library code (tests OK)

3. **Async Patterns**
   - All I/O operations must be async
   - Use `tokio::spawn` for background tasks
   - Prefer `tokio::sync::mpsc` for channels
   - Always set timeouts for network operations

4. **Testing**
   - Unit tests in same file (`#[cfg(test)] mod tests`)
   - Use mock implementations from `test_utils` modules
   - Integration tests in `crates/*/tests/`
   - All public APIs must have tests
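The error-handling and timeout bullets above tend to show up together in provider code. Below is a minimal, self-contained sketch of both patterns; the `Error::Provider` variant is a stand-in for whatever `owlen_core::Error` actually exposes, and `/api/tags` is simply Ollama's model-listing endpoint used as a concrete target (assumes `tokio`, `reqwest`, and `thiserror` as dependencies):

```rust
// Sketch of the two guidelines above: chain errors with context instead of
// unwrapping, and bound every network call with a timeout. `Error::Provider`
// is a stand-in variant; the real `owlen_core::Error` enum may differ.
use std::time::Duration;

#[derive(Debug, thiserror::Error)]
enum Error {
    #[error("provider error: {0}")]
    Provider(String),
}

async fn fetch_models(base_url: &str) -> Result<String, Error> {
    // A hung Ollama instance should surface as an error, not block the TUI forever.
    let response = tokio::time::timeout(
        Duration::from_secs(30),
        reqwest::get(format!("{base_url}/api/tags")),
    )
    .await
    .map_err(|_| Error::Provider("timed out listing models".into()))?
    .map_err(|e| Error::Provider(format!("request failed: {e}")))?;

    response
        .text()
        .await
        .map_err(|e| Error::Provider(format!("could not read response body: {e}")))
}

#[tokio::main]
async fn main() {
    match fetch_models("http://localhost:11434").await {
        Ok(body) => println!("models: {body}"),
        Err(err) => eprintln!("{err}"),
    }
}
```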
### File Organization

**When editing existing files:**
1. Read the entire file first (use `Read` tool)
2. Preserve existing code style and formatting
3. Update related tests in the same commit
4. Keep changes atomic and focused

**When creating new files:**
1. Check `crates/owlen-core/src/` for similar modules
2. Follow existing module structure
3. Add to `lib.rs` with appropriate visibility
4. Document module purpose with `//!` header

### Configuration

**Config file**: `~/.config/owlen/config.toml`

Example structure:
```toml
[general]
default_provider = "ollama"
default_model = "llama3.2:latest"
enable_streaming = true

[mcp]
# MCP is always enabled in v1.0+

[providers.ollama]
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama-cloud]
provider_type = "ollama-cloud"
base_url = "https://ollama.com"
api_key = "$OLLAMA_API_KEY"

[ui]
theme = "default_dark"
word_wrap = true

[security]
enable_sandboxing = true
allowed_tools = ["web_search", "code_exec"]
```

### Common Tasks

#### Adding a New Provider

1. Create `crates/owlen-{provider}/` crate
2. Implement `owlen_core::provider::Provider` trait (see the sketch below)
3. Add to `owlen_core::router::ProviderRouter`
4. Update config schema in `owlen_core::config`
5. Add tests with `MockProvider` pattern
6. Document in `docs/provider-implementation.md`
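Step 2 is the bulk of the work. The real `Provider` trait lives in `owlen-core` and its exact signature is not reproduced here; the sketch below only illustrates the likely overall shape (a name plus an async chat method), so treat the trait definition, request/response types, and the `async-trait`/`anyhow` dependencies as assumptions rather than the actual API:

```rust
// Hypothetical shape of "implement the Provider trait". The real
// `owlen_core::provider::Provider` trait and types will differ.
use async_trait::async_trait;

#[derive(Debug)]
pub struct ChatRequest {
    pub model: String,
    pub prompt: String,
}

#[derive(Debug)]
pub struct ChatResponse {
    pub content: String,
}

#[async_trait]
pub trait Provider: Send + Sync {
    fn name(&self) -> &str;
    async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse>;
}

/// Example provider backed by a fictional HTTP API.
pub struct ExampleProvider {
    base_url: String,
    api_key: Option<String>,
}

#[async_trait]
impl Provider for ExampleProvider {
    fn name(&self) -> &str {
        "example"
    }

    async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse> {
        // A real implementation would POST to `self.base_url` with `reqwest`,
        // attach `self.api_key`, honour streaming settings, and map transport
        // failures into `owlen_core::Error`.
        let _ = (&self.base_url, &self.api_key);
        Ok(ChatResponse {
            content: format!("[{}] echoing: {}", request.model, request.prompt),
        })
    }
}
```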
#### Adding a New MCP Server

1. Create `crates/owlen-mcp-{name}-server/` crate
2. Implement JSON-RPC 2.0 protocol handlers
3. Define tool descriptors with JSON schemas
4. Add sandboxing/security checks
5. Register in `mcp_servers` config array
6. Document tool capabilities

#### Adding a TUI Feature

1. Modify `crates/owlen-tui/src/chat_app.rs`
2. Update keybinding handlers
3. Extend UI rendering in `draw()` method
4. Add to help screen (`?` command)
5. Test with different terminal sizes
6. Ensure theme compatibility

## Feature Parity Roadmap

Based on analysis of OpenAI Codex and Claude Code, here are prioritized features to implement:

### Phase 11: MCP Client Enhancement (HIGHEST PRIORITY)

**Goal**: Full MCP client capabilities to access ecosystem tools

**Features:**
1. **MCP Server Management**
   - `owlen mcp add/list/remove` commands
   - Three config scopes: local, project (`.mcp.json`), user
   - Environment variable expansion in config
   - OAuth 2.0 authentication for remote servers

2. **MCP Resource References**
   - `@github:issue://123` syntax
   - `@postgres:schema://users` syntax
   - Auto-completion for resources

3. **MCP Prompts as Slash Commands**
   - `/mcp__github__list_prs`
   - Dynamic command registration

**Implementation:**
- Extend `owlen-mcp-client` crate
- Add `.mcp.json` parsing to `owlen-core::config`
- Update TUI command parser for `@` and `/mcp__` syntax (see the sketch below)
- Add OAuth flow to TUI
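A possible way the command parser could recognise the `@server:kind://path` references from the feature list above; the struct layout and function name are illustrative, not the final `owlen-tui` API:

```rust
// Parses tokens such as `@github:issue://123` or `@postgres:schema://users`
// into their server, resource kind, and path components.
#[derive(Debug, PartialEq)]
struct ResourceRef {
    server: String, // "github"
    kind: String,   // "issue"
    path: String,   // "123"
}

fn parse_resource_ref(token: &str) -> Option<ResourceRef> {
    let rest = token.strip_prefix('@')?;
    let (server, uri) = rest.split_once(':')?;
    let (kind, path) = uri.split_once("://")?;
    if server.is_empty() || kind.is_empty() || path.is_empty() {
        return None;
    }
    Some(ResourceRef {
        server: server.to_string(),
        kind: kind.to_string(),
        path: path.to_string(),
    })
}

fn main() {
    let parsed = parse_resource_ref("@github:issue://123");
    assert_eq!(
        parsed,
        Some(ResourceRef {
            server: "github".into(),
            kind: "issue".into(),
            path: "123".into(),
        })
    );
    println!("{parsed:?}");
}
```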
**Files to modify:**
- `crates/owlen-mcp-client/src/lib.rs`
- `crates/owlen-core/src/config.rs`
- `crates/owlen-tui/src/command_parser.rs`

### Phase 12: Approval & Sandbox System (HIGHEST PRIORITY)

**Goal**: Safe agentic behavior with user control

**Features:**
1. **Three-tier Approval Modes** (see the sketch after this list)
   - `suggest`: Approve ALL file writes and shell commands (default)
   - `auto-edit`: Auto-approve file changes, prompt for shell
   - `full-auto`: Auto-approve everything (requires Git repo)

2. **Platform-specific Sandboxing**
   - Linux: Docker with network isolation
   - macOS: Apple Seatbelt (`sandbox-exec`)
   - Windows: AppContainer or Job Objects

3. **Permission Management**
   - `/permissions` command in TUI
   - Tool allowlist (e.g., `Edit`, `Bash(git commit:*)`)
   - Stored in `.owlen/settings.json` (project) or `~/.owlen.json` (user)
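A minimal way the three tiers could be encoded in the planned `owlen-core::approval` module; the enum and the prompting rules below are assumptions drawn from the list above, not the final design:

```rust
// Possible model for the three approval tiers described above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalMode {
    Suggest,  // prompt for every file write and shell command (default)
    AutoEdit, // auto-approve file edits, still prompt for shell commands
    FullAuto, // auto-approve everything (only allowed inside a Git repo)
}

#[derive(Debug, Clone, Copy)]
pub enum Action {
    FileWrite,
    ShellCommand,
}

/// Returns true when the user must be prompted before `action` runs.
pub fn requires_prompt(mode: ApprovalMode, action: Action) -> bool {
    match (mode, action) {
        (ApprovalMode::Suggest, _) => true,
        (ApprovalMode::AutoEdit, Action::FileWrite) => false,
        (ApprovalMode::AutoEdit, Action::ShellCommand) => true,
        (ApprovalMode::FullAuto, _) => false,
    }
}

fn main() {
    assert!(requires_prompt(ApprovalMode::Suggest, Action::FileWrite));
    assert!(!requires_prompt(ApprovalMode::AutoEdit, Action::FileWrite));
    assert!(requires_prompt(ApprovalMode::AutoEdit, Action::ShellCommand));
    assert!(!requires_prompt(ApprovalMode::FullAuto, Action::ShellCommand));
}
```

The `full-auto` tier's Git-repository requirement from the list above would be enforced before the mode is accepted, so a decision helper like `requires_prompt` can stay this small.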
**Implementation:**
- New `owlen-core::approval` module
- Extend `owlen-core::sandbox` with platform detection
- Update `owlen-mcp-code-server` to use new sandbox
- Add permission storage to config system

**Files to create:**
- `crates/owlen-core/src/approval.rs`
- `crates/owlen-core/src/sandbox/linux.rs`
- `crates/owlen-core/src/sandbox/macos.rs`
- `crates/owlen-core/src/sandbox/windows.rs`

### Phase 13: Project Documentation System (HIGH PRIORITY)

**Goal**: Massive usability improvement with project context

**Features:**
1. **OWLEN.md System**
   - `OWLEN.md` at repo root (checked into git)
   - `OWLEN.local.md` (gitignored, personal)
   - `~/.config/owlen/OWLEN.md` (global)
   - Support nested OWLEN.md in monorepos

2. **Auto-generation**
   - `/init` command to generate project-specific OWLEN.md
   - Analyze codebase structure
   - Detect build system, test framework
   - Suggest common commands

3. **Live Updates**
   - `#` command to add instructions to OWLEN.md
   - Context-aware insertion (relevant section)

**Contents of OWLEN.md:**
- Common bash commands
- Code style guidelines
- Testing instructions
- Core files and utilities
- Known quirks/warnings

**Implementation:**
- New `owlen-core::project_doc` module
- File discovery algorithm (walk up directory tree; see the sketch below)
- Markdown parser for sections
- TUI commands: `/init`, `#`
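The discovery step is simple enough to sketch in full: walk from the working directory up to the filesystem root and collect every `OWLEN.md`/`OWLEN.local.md` along the way, closest first. The function name is illustrative; the real code would live in the planned `project_doc` module:

```rust
// Collects project documentation files from `start` up to the filesystem root.
// Files closer to the working directory come first so they can take precedence.
use std::path::{Path, PathBuf};

fn discover_project_docs(start: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    let mut dir = Some(start);
    while let Some(current) = dir {
        for name in ["OWLEN.local.md", "OWLEN.md"] {
            let candidate = current.join(name);
            if candidate.is_file() {
                found.push(candidate);
            }
        }
        dir = current.parent();
    }
    found
}

fn main() {
    let cwd = std::env::current_dir().expect("cwd");
    for doc in discover_project_docs(&cwd) {
        println!("{}", doc.display());
    }
}
```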
**Files to create:**
- `crates/owlen-core/src/project_doc.rs`
- `crates/owlen-tui/src/commands/init.rs`

### Phase 14: Non-Interactive Mode (HIGH PRIORITY)

**Goal**: Enable CI/CD integration and automation

**Features:**
1. **Headless Execution**
   ```bash
   owlen exec "fix linting errors" --approval-mode auto-edit
   owlen --quiet "update CHANGELOG" --json
   ```

2. **Environment Variables**
   - `OWLEN_QUIET_MODE=1`
   - `OWLEN_DISABLE_PROJECT_DOC=1`
   - `OWLEN_APPROVAL_MODE=full-auto`

3. **JSON Output**
   - Structured output for parsing
   - Exit codes for success/failure
   - Progress events on stderr

**Implementation:**
- New `owlen-cli` subcommand: `exec`
- Extend `owlen-core::session` with non-interactive mode
- Add JSON serialization for results
- Environment variable parsing in config

**Files to modify:**
- `crates/owlen-cli/src/main.rs`
- `crates/owlen-core/src/session.rs`

### Phase 15: Multi-Provider Expansion (HIGH PRIORITY)

**Goal**: Support cloud providers while maintaining local-first

**Providers to add:**
1. OpenAI (GPT-4, o1, o4-mini)
2. Anthropic (Claude 3.5 Sonnet, Opus)
3. Google (Gemini Ultra, Pro)
4. Mistral AI

**Configuration:**
```toml
[providers.openai]
api_key = "${OPENAI_API_KEY}"
model = "o4-mini"
enabled = true

[providers.anthropic]
api_key = "${ANTHROPIC_API_KEY}"
model = "claude-3-5-sonnet"
enabled = true
```

**Runtime Switching:**
```
:model ollama/starcoder
:model openai/o4-mini
:model anthropic/claude-3-5-sonnet
```

**Implementation:**
- Create `owlen-openai`, `owlen-anthropic`, `owlen-google` crates
- Implement `Provider` trait for each
- Add runtime model switching to TUI
- Maintain Ollama as default

**Files to create:**
- `crates/owlen-openai/src/lib.rs`
- `crates/owlen-anthropic/src/lib.rs`
- `crates/owlen-google/src/lib.rs`

### Phase 16: Custom Slash Commands (MEDIUM PRIORITY)

**Goal**: User and team-defined workflows

**Features:**
1. **Command Directories**
   - `~/.owlen/commands/` (user, available everywhere)
   - `.owlen/commands/` (project, checked into git)
   - Support `$ARGUMENTS` keyword

2. **Example Structure**
   ```markdown
   # .owlen/commands/fix-github-issue.md
   Please analyze and fix GitHub issue: $ARGUMENTS.
   1. Use `gh issue view` to get details
   2. Implement changes
   3. Write and run tests
   4. Create PR
   ```

3. **TUI Integration**
   - Auto-complete for custom commands
   - Help text from command files
   - Parameter validation

**Implementation:**
- New `owlen-core::commands` module
- Command discovery and parsing
- Template expansion (see the sketch below)
- TUI command registration
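A minimal sketch of the `$ARGUMENTS` expansion step, using the command file from the example above; discovery, help text, and validation are left out, and the argument string is purely illustrative:

```rust
// Reads a custom command's Markdown body and substitutes the user-supplied
// arguments for the `$ARGUMENTS` placeholder described in the plan above.
use std::fs;
use std::path::Path;

fn expand_command(template_path: &Path, arguments: &str) -> std::io::Result<String> {
    let template = fs::read_to_string(template_path)?;
    Ok(template.replace("$ARGUMENTS", arguments))
}

fn main() -> std::io::Result<()> {
    let prompt = expand_command(
        Path::new(".owlen/commands/fix-github-issue.md"),
        "owlen-tui panics on terminal resize",
    )?;
    println!("{prompt}");
    Ok(())
}
```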
**Files to create:**
- `crates/owlen-core/src/commands.rs`
- `crates/owlen-tui/src/commands/custom.rs`

### Phase 17: Plugin System (MEDIUM PRIORITY)

**Goal**: One-command installation of tool collections

**Features:**
1. **Plugin Structure**
   ```json
   {
     "name": "github-workflow",
     "version": "1.0.0",
     "commands": [
       {"name": "pr", "file": "commands/pr.md"}
     ],
     "mcp_servers": [
       {
         "name": "github",
         "command": "${OWLEN_PLUGIN_ROOT}/bin/github-mcp"
       }
     ]
   }
   ```

2. **Installation**
   ```bash
   owlen plugin install github-workflow
   owlen plugin list
   owlen plugin remove github-workflow
   ```

3. **Discovery**
   - `~/.owlen/plugins/` directory
   - Git repository URLs
   - Plugin registry (future)

**Implementation:**
- New `owlen-core::plugins` module
- Plugin manifest parser (see the sketch below)
- Installation/removal logic
- Sandboxing for plugin code
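One possible serde model for the manifest format shown above; the field names mirror the JSON example, while the crate layout and error handling are illustrative (assumes `serde` with the derive feature and `serde_json`):

```rust
// Deserializes the plugin manifest format from the Plugin Structure example.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct PluginManifest {
    name: String,
    version: String,
    #[serde(default)]
    commands: Vec<CommandEntry>,
    #[serde(default)]
    mcp_servers: Vec<McpServerEntry>,
}

#[derive(Debug, Deserialize)]
struct CommandEntry {
    name: String,
    file: String,
}

#[derive(Debug, Deserialize)]
struct McpServerEntry {
    name: String,
    command: String,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{
        "name": "github-workflow",
        "version": "1.0.0",
        "commands": [{"name": "pr", "file": "commands/pr.md"}],
        "mcp_servers": [{"name": "github", "command": "${OWLEN_PLUGIN_ROOT}/bin/github-mcp"}]
    }"#;
    let manifest: PluginManifest = serde_json::from_str(raw)?;
    println!(
        "{} v{} ({} commands)",
        manifest.name,
        manifest.version,
        manifest.commands.len()
    );
    Ok(())
}
```

With a model like this, `owlen plugin install` would mostly need to deserialize the manifest and copy the referenced files into `~/.owlen/plugins/`.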
**Files to create:**
- `crates/owlen-core/src/plugins.rs`
- `crates/owlen-cli/src/commands/plugin.rs`

### Phase 18: Extended Thinking Modes (MEDIUM PRIORITY)

**Goal**: Progressive computation budgets for complex tasks

**Modes:**
- `think` - basic extended thinking
- `think hard` - increased computation
- `think harder` - more computation
- `ultrathink` - maximum budget

**Implementation:**
- Extend `owlen-core::types::ChatParameters` (see the sketch below)
- Add thinking mode to TUI commands
- Configure per-provider max tokens
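A sketch of how the four modes could extend `ChatParameters`; the enum name and the token budgets are placeholders, and per-provider limits would still be applied on top:

```rust
// Illustrative mapping from thinking mode to a reasoning-token budget.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ThinkingMode {
    #[default]
    Think,
    ThinkHard,
    ThinkHarder,
    Ultrathink,
}

impl ThinkingMode {
    /// Rough budget for each tier (illustrative values only).
    pub fn token_budget(self) -> u32 {
        match self {
            ThinkingMode::Think => 1_024,
            ThinkingMode::ThinkHard => 4_096,
            ThinkingMode::ThinkHarder => 16_384,
            ThinkingMode::Ultrathink => 65_536,
        }
    }
}

fn main() {
    for mode in [
        ThinkingMode::Think,
        ThinkingMode::ThinkHard,
        ThinkingMode::ThinkHarder,
        ThinkingMode::Ultrathink,
    ] {
        println!("{mode:?}: up to {} reasoning tokens", mode.token_budget());
    }
}
```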
**Files to modify:**
- `crates/owlen-core/src/types.rs`
- `crates/owlen-tui/src/command_parser.rs`

### Phase 19: Git Workflow Automation (MEDIUM PRIORITY)

**Goal**: Streamline common Git operations

**Features:**
1. Auto-commit message generation
2. PR creation via `gh` CLI
3. Rebase conflict resolution
4. File revert operations
5. Git history analysis

**Implementation:**
- New `owlen-mcp-git-server` crate
- Tools: `commit`, `create_pr`, `rebase`, `revert`, `history`
- Integration with TUI commands

**Files to create:**
- `crates/owlen-mcp-git-server/src/lib.rs`

### Phase 20: Enterprise Features (LOW PRIORITY)

**Goal**: Team and enterprise deployment support

**Features:**
1. **Managed Configuration**
   - `/etc/owlen/managed-mcp.json` (Linux)
   - Restrict user additions with `useEnterpriseMcpConfigOnly`

2. **Audit Logging**
   - Log all file writes and shell commands
   - Structured JSON logs
   - Tamper-proof storage

3. **Team Collaboration**
   - Shared OWLEN.md across team
   - Project-scoped MCP servers in `.mcp.json`
   - Approval policy enforcement

**Implementation:**
- Extend `owlen-core::config` with managed settings
- New `owlen-core::audit` module
- Enterprise deployment documentation

## Testing Requirements

### Test Coverage Goals

- **Unit tests**: 80%+ coverage for `owlen-core`
- **Integration tests**: All MCP servers, providers
- **TUI tests**: Key workflows (not pixel-perfect)

### Test Organization

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::test_utils::MockProvider;
    use crate::mcp::test_utils::MockMcpClient;

    #[tokio::test]
    async fn test_feature() {
        // Setup
        let provider = MockProvider::new();

        // Execute
        let result = provider.chat(request).await;

        // Assert
        assert!(result.is_ok());
    }
}
```

### Running Tests

```bash
cargo test --all                 # All tests
cargo test --lib -p owlen-core   # Core library tests
cargo test --test integration    # Integration tests
```

## Documentation Standards

### Code Documentation

1. **Module-level** (`//!` at top of file):
   ```rust
   //! Brief module description
   //!
   //! Detailed explanation of module purpose,
   //! key types, and usage examples.
   ```

2. **Public APIs** (`///` above items):
   ```rust
   /// Brief description
   ///
   /// # Arguments
   /// * `arg1` - Description
   ///
   /// # Returns
   /// Description of return value
   ///
   /// # Errors
   /// When this function returns an error
   ///
   /// # Example
   /// ```
   /// let result = function(arg);
   /// ```
   pub fn function(arg: Type) -> Result<Output> {
       // implementation
   }
   ```

3. **Private items**: Optional, use for complex logic

### User Documentation

Location: `docs/` directory

Files to maintain:
- `architecture.md` - System design
- `configuration.md` - Config reference
- `migration-guide.md` - Version upgrades
- `troubleshooting.md` - Common issues
- `provider-implementation.md` - Adding providers
- `faq.md` - Frequently asked questions

## Git Workflow

### Branch Strategy

- `main` - stable releases only
- `dev` - active development (default)
- `feature/*` - new features
- `fix/*` - bug fixes
- `docs/*` - documentation only

### Commit Messages

Follow conventional commits:

```
type(scope): brief description

Detailed explanation of changes.

Breaking changes, if any.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
```

Types: `feat`, `fix`, `docs`, `refactor`, `test`, `chore`

### Pre-commit Hooks

Automatically run:
- `cargo fmt` (formatting)
- `cargo check` (compilation)
- `cargo clippy` (linting)
- YAML/TOML validation
- Trailing whitespace removal

## Performance Guidelines

### Optimization Priorities

1. **Startup time**: < 500ms cold start
2. **First token latency**: < 2s for local models
3. **Memory usage**: < 100MB base, < 500MB with conversation
4. **Responsiveness**: TUI redraws < 16ms (60 FPS)

### Profiling

```bash
cargo build --release --features profiling
valgrind --tool=callgrind target/release/owlen
kcachegrind callgrind.out.*
```

### Async Performance

- Avoid blocking in async contexts
- Use `tokio::spawn` for CPU-intensive work
- Set timeouts on all network operations
- Cancel tasks on shutdown

## Security Considerations

### Threat Model

**Trusted:**
- User's local machine
- User-installed Ollama models
- User configuration files

**Untrusted:**
- MCP server responses
- Web search results
- Code execution output
- Cloud LLM responses

### Security Measures

1. **Input Validation**
   - Sanitize all MCP tool arguments
   - Validate JSON schemas strictly
   - Escape shell commands

2. **Sandboxing**
   - Docker for code execution
   - Network isolation
   - Filesystem restrictions

3. **Secrets Management**
   - Never log API keys
   - Use environment variables
   - Encrypt sensitive config fields

4. **Dependency Auditing**
   ```bash
   cargo audit
   cargo deny check
   ```

## Debugging Tips

### Enable Debug Logging

```bash
OWLEN_DEBUG_OLLAMA=1 owlen   # Ollama requests
RUST_LOG=debug owlen         # All debug logs
RUST_BACKTRACE=1 owlen       # Stack traces
```

### Common Issues

1. **Timeout on Ollama**
   - Check `ollama ps` for loaded models
   - Increase timeout in config
   - Restart Ollama service

2. **MCP Server Not Found**
   - Verify `mcp_servers` config
   - Check server binary exists
   - Test server manually with STDIO

3. **TUI Rendering Issues**
   - Test in different terminals
   - Check terminal size (`tput cols; tput lines`)
   - Verify theme compatibility

## Contributing

### Before Submitting PR

1. Run full test suite: `cargo test --all`
2. Check formatting: `cargo fmt -- --check`
3. Run linter: `cargo clippy -- -D warnings`
4. Update documentation if API changed
5. Add tests for new features
6. Update CHANGELOG.md

### PR Description Template

```markdown
## Summary
Brief description of changes

## Type of Change
- [ ] Bug fix
- [ ] New feature
- [ ] Breaking change
- [ ] Documentation update

## Testing
Describe tests performed

## Checklist
- [ ] Tests added/updated
- [ ] Documentation updated
- [ ] CHANGELOG.md updated
- [ ] No clippy warnings
```

## Resources

### External Documentation

- [Ratatui Docs](https://ratatui.rs/)
- [Tokio Tutorial](https://tokio.rs/tokio/tutorial)
- [MCP Specification](https://modelcontextprotocol.io/)
- [Ollama API](https://github.com/ollama/ollama/blob/main/docs/api.md)

### Internal Documentation

- `.agents/new_phases.md` - 10-phase migration plan (completed)
- `docs/phase5-mode-system.md` - Mode system design
- `docs/migration-guide.md` - v0.x → v1.0 migration

### Community

- GitHub Issues: Bug reports and feature requests
- GitHub Discussions: Questions and ideas
- AUR Package: `owlen-git` (Arch Linux)

## Version History

- **v1.0.0** (current) - MCP-only architecture, Phase 10 complete
- **v0.2.0** - Added web search, code execution servers
- **v0.1.0** - Initial release with Ollama support

## License

Owlen is open source software. See LICENSE file for details.

---

**Last Updated**: 2025-10-11
**Maintained By**: Owlen Development Team
**For AI Agents**: Follow these guidelines when modifying the Owlen codebase. Prioritize MCP client enhancement (Phase 11) and approval system (Phase 12) for feature parity with Codex/Claude Code while maintaining local-first philosophy.
30 CHANGELOG.md

```diff
@@ -11,15 +11,45 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Comprehensive documentation suite including guides for architecture, configuration, testing, and more.
 - Rustdoc examples for core components like `Provider` and `SessionController`.
 - Module-level documentation for `owlen-tui`.
+- Provider integration tests (`crates/owlen-providers/tests`) covering registration, routing, and health status handling for the new `ProviderManager`.
+- TUI message and generation tests that exercise the non-blocking event loop, background worker, and message dispatch.
 - Ollama integration can now talk to Ollama Cloud when an API key is configured.
 - Ollama provider will also read `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables when no key is stored in the config.
+- `owlen config doctor`, `owlen config path`, and `owlen upgrade` CLI commands to automate migrations and surface manual update steps.
+- Startup provider health check with actionable hints when Ollama or remote MCP servers are unavailable.
+- `dev/check-windows.sh` helper script for on-demand Windows cross-checks.
+- Global F1 keybinding for the in-app help overlay and a clearer status hint on launch.
+- Automatic fallback to the new `ansi_basic` theme when the active terminal only advertises 16-color support.
+- Offline provider shim that keeps the TUI usable while primary providers are unreachable and communicates recovery steps inline.
+- `owlen cloud` subcommands (`setup`, `status`, `models`, `logout`) for managing Ollama Cloud credentials without hand-editing config files.
+- Tabbed model selector that separates local and cloud providers, including cloud indicators in the UI.
+- Footer status line includes provider connectivity/credential summaries (e.g., cloud auth failures, missing API keys).
+- Secure credential vault integration for Ollama Cloud API keys when `privacy.encrypt_local_data = true`.
+- Input panel respects a new `ui.input_max_rows` setting so long prompts expand predictably before scrolling kicks in.
+- Command palette offers fuzzy `:model` filtering and `:provider` completions for fast switching.
+- Message rendering caches wrapped lines and throttles streaming redraws to keep the TUI responsive on long sessions.
+- Model picker badges now inspect provider capabilities so vision/audio/thinking models surface the correct icons even when descriptions are sparse.
+- Chat history honors `ui.scrollback_lines`, trimming older rows to keep the TUI responsive and surfacing a "↓ New messages" badge whenever updates land off-screen.
 
 ### Changed
 - The main `README.md` has been updated to be more concise and link to the new documentation.
 - Default configuration now pre-populates both `providers.ollama` and `providers.ollama-cloud` entries so switching between local and cloud backends is a single setting change.
+- `McpMode` support was restored with explicit validation; `remote_only`, `remote_preferred`, and `local_only` now behave predictably.
+- Configuration loading performs structural validation and fails fast on missing default providers or invalid MCP definitions.
+- Ollama provider error handling now distinguishes timeouts, missing models, and authentication failures.
+- `owlen` warns when the active terminal likely lacks 256-color support.
+- `config.toml` now carries a schema version (`1.2.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
+- Model selector navigation (Tab/Shift-Tab) now switches between local and cloud tabs while preserving selection state.
+- Header displays the active model together with its provider (e.g., `Model (Provider)`), improving clarity when swapping backends.
+- Documentation refreshed to cover the message handler architecture, the background health worker, multi-provider configuration, and the new provider onboarding checklist.
 
 ---
 
+## [0.1.11] - 2025-10-18
+
+### Changed
+- Bump workspace packages and distribution metadata to version `0.1.11`.
+
 ## [0.1.10] - 2025-10-03
 
 ### Added
```
CONTRIBUTING.md

```diff
@@ -10,6 +10,10 @@ This project and everyone participating in it is governed by the [Owlen Code of
 ## How Can I Contribute?
 
+### Repository map
+
+Need a quick orientation before diving in? Start with the curated [repo map](docs/repo-map.md) for a two-level directory overview. If you move folders around, regenerate it with `scripts/gen-repo-map.sh`.
+
 ### Reporting Bugs
 
 This is one of the most helpful ways you can contribute. Before creating a bug report, please check a few things:
@@ -40,6 +44,7 @@ The process for submitting a pull request is as follows:
 6. **Add a clear, concise commit message.** We follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
 7. **Push to your fork** and submit a pull request to Owlen's `main` branch.
 8. **Include a clear description** of the problem and solution. Include the relevant issue number if applicable.
+9. **Declare AI assistance.** If any part of the patch was generated with an AI tool (e.g., ChatGPT, Claude Code), call that out in the PR description. A human maintainer must review and approve AI-assisted changes before merge.
 
 ## Development Setup
 
```
22 Cargo.toml

```diff
@@ -4,16 +4,20 @@ members = [
     "crates/owlen-core",
     "crates/owlen-tui",
     "crates/owlen-cli",
-    "crates/owlen-ollama",
-    "crates/owlen-mcp-server",
-    "crates/owlen-mcp-llm-server",
-    "crates/owlen-mcp-client",
+    "crates/owlen-providers",
+    "crates/mcp/server",
+    "crates/mcp/llm-server",
+    "crates/mcp/client",
+    "crates/mcp/code-server",
+    "crates/mcp/prompt-server",
+    "crates/owlen-markdown",
+    "xtask",
 ]
 exclude = []
 
 [workspace.package]
-version = "0.1.9"
-edition = "2021"
+version = "0.1.11"
+edition = "2024"
 authors = ["Owlibou"]
 license = "AGPL-3.0"
 repository = "https://somegit.dev/Owlibou/owlen"
@@ -42,7 +46,7 @@ serde_json = { version = "1.0" }
 # Utilities
 uuid = { version = "1.0", features = ["v4", "serde"] }
 anyhow = "1.0"
-thiserror = "1.0"
+thiserror = "2.0"
 nix = "0.29"
 which = "6.0"
 tempfile = "3.8"
@@ -55,6 +59,10 @@ urlencoding = "2.1"
 regex = "1.10"
 rpassword = "7.3"
 sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio-rustls", "sqlite", "macros", "uuid", "chrono", "migrate"] }
+log = "0.4"
+dirs = "5.0"
+serde_yaml = "0.9"
+handlebars = "6.0"
 
 # Configuration
 toml = "0.8"
```
2 PKGBUILD

```diff
@@ -1,6 +1,6 @@
 # Maintainer: vikingowl <christian@nachtigall.dev>
 pkgname=owlen
-pkgver=0.1.9
+pkgver=0.1.11
 pkgrel=1
 pkgdesc="Terminal User Interface LLM client for Ollama with chat and code assistance features"
 arch=('x86_64')
```
109
README.md
109
README.md
@@ -3,16 +3,17 @@
|
|||||||
> Terminal-native assistant for running local language models with a comfortable TUI.
|
> Terminal-native assistant for running local language models with a comfortable TUI.
|
||||||
|
|
||||||

|

|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
|
|
||||||
## What Is OWLEN?
|
## What Is OWLEN?
|
||||||
|
|
||||||
OWLEN is a Rust-powered, terminal-first interface for interacting with local large
|
OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud
|
||||||
language models. It provides a responsive chat workflow that runs against
|
language models. It provides a responsive chat workflow that now routes through a
|
||||||
[Ollama](https://ollama.com/) with a focus on developer productivity, vim-style navigation,
|
multi-provider manager—handling local Ollama, Ollama Cloud, and future MCP-backed providers—
|
||||||
and seamless session management—all without leaving your terminal.
|
with a focus on developer productivity, vim-style navigation, and seamless session
|
||||||
|
management—all without leaving your terminal.
|
||||||
|
|
||||||
## Alpha Status
|
## Alpha Status
|
||||||
|
|
||||||
@@ -30,8 +31,22 @@ The OWLEN interface features a clean, multi-panel layout with vim-inspired navig
|
|||||||
- **Streaming Responses**: Real-time token streaming from Ollama.
|
- **Streaming Responses**: Real-time token streaming from Ollama.
|
||||||
- **Advanced Text Editing**: Multi-line input, history, and clipboard support.
|
- **Advanced Text Editing**: Multi-line input, history, and clipboard support.
|
||||||
- **Session Management**: Save, load, and manage conversations.
|
- **Session Management**: Save, load, and manage conversations.
|
||||||
|
- **Code Side Panel**: Switch to code mode (`:mode code`) and open files inline with `:open <path>` for LLM-assisted coding.
|
||||||
- **Theming System**: 10 built-in themes and support for custom themes.
|
- **Theming System**: 10 built-in themes and support for custom themes.
|
||||||
- **Modular Architecture**: Extensible provider system (currently Ollama).
|
- **Modular Architecture**: Extensible provider system orchestrated by the new `ProviderManager`, ready for additional MCP-backed providers.
|
||||||
|
- **Dual-Source Model Picker**: Merge local and cloud catalogues with real-time availability badges powered by the background health worker.
|
||||||
|
- **Non-Blocking UI Loop**: Asynchronous generation tasks and provider health checks run off-thread, keeping the TUI responsive even while streaming long replies.
|
||||||
|
- **Guided Setup**: `owlen config doctor` upgrades legacy configs and verifies your environment in seconds.
|
||||||
|
|
||||||
|
## Security & Privacy
|
||||||
|
|
||||||
|
Owlen is designed to keep data local by default while still allowing controlled access to remote tooling.
|
||||||
|
|
||||||
|
- **Local-first execution**: All LLM calls flow through the bundled MCP LLM server which talks to a local Ollama instance. If the server is unreachable, Owlen stays usable in “offline mode” and surfaces clear recovery instructions.
|
||||||
|
- **Sandboxed tooling**: Code execution runs in Docker according to the MCP Code Server settings, and future releases will extend this to other OS-level sandboxes (`sandbox-exec` on macOS, Windows job objects).
|
||||||
|
- **Session storage**: Conversations are stored under the platform data directory and can be encrypted at rest. Set `privacy.encrypt_local_data = true` in `config.toml` to enable AES-GCM storage protected by a user-supplied passphrase.
|
||||||
|
- **Network access**: No telemetry is sent. The only outbound requests occur when you explicitly enable remote tooling (e.g., web search) or configure a cloud LLM provider. Each tool is opt-in via `privacy` and `tools` configuration sections.
|
||||||
|
- **Config migrations**: Every saved `config.toml` carries a schema version and is upgraded automatically; deprecated keys trigger warnings so security-related settings are not silently ignored.
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
@@ -42,18 +57,28 @@ The OWLEN interface features a clean, multi-panel layout with vim-inspired navig
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
#### Linux & macOS
|
Pick the option that matches your platform and appetite for source builds:
|
||||||
The recommended way to install on Linux and macOS is to clone the repository and install using `cargo`.
|
|
||||||
|
| Platform | Package / Command | Notes |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| Arch Linux | `yay -S owlen-git` | Builds from the latest `dev` branch via AUR. |
|
||||||
|
| Other Linux | `cargo install --path crates/owlen-cli --locked --force` | Requires Rust 1.75+ and a running Ollama daemon. |
|
||||||
|
| macOS | `cargo install --path crates/owlen-cli --locked --force` | macOS 12+ tested. Install Ollama separately (`brew install ollama`). The binary links against the system OpenSSL – ensure Command Line Tools are installed. |
|
||||||
|
| Windows (experimental) | `cargo install --path crates/owlen-cli --locked --force` | Enable the GNU toolchain (`rustup target add x86_64-pc-windows-gnu`) and install Ollama for Windows preview builds. Some optional tools (e.g., Docker-based code execution) are currently disabled. |
|
||||||
|
|
||||||
|
If you prefer containerised builds, use the provided `Dockerfile` as a base image and copy out `target/release/owlen`.
|
||||||
|
|
||||||
|
Run the helper scripts to sanity-check platform coverage:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/Owlibou/owlen.git
|
# Windows compatibility smoke test (GNU toolchain)
|
||||||
cd owlen
|
scripts/check-windows.sh
|
||||||
cargo install --path crates/owlen-cli
|
|
||||||
```
|
|
||||||
**Note for macOS**: While this method works, official binary releases for macOS are planned for the future.
|
|
||||||
|
|
||||||
#### Windows
|
# Reproduce CI packaging locally (choose a target from .woodpecker.yml)
|
||||||
The Windows build has not been thoroughly tested yet. Installation is possible via the same `cargo install` method, but it is considered experimental at this time.
|
dev/local_build.sh x86_64-unknown-linux-gnu
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Tip (macOS):** On the first launch macOS Gatekeeper may quarantine the binary. Clear the attribute (`xattr -d com.apple.quarantine $(which owlen)`) or build from source locally to avoid notarisation prompts.
|
||||||
|
|
||||||
### Running OWLEN
|
### Running OWLEN
|
||||||
|
|
||||||
@@ -66,13 +91,26 @@ If you built from source without installing, you can run it with:
|
|||||||
./target/release/owlen
|
./target/release/owlen
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Updating
|
||||||
|
|
||||||
|
Owlen does not auto-update. Run `owlen upgrade` at any time to print the recommended manual steps (pull the repository and reinstall with `cargo install --path crates/owlen-cli --force`). Arch Linux users can update via the `owlen-git` AUR package.
|
||||||
|
|
||||||
## Using the TUI
|
## Using the TUI
|
||||||
|
|
||||||
OWLEN uses a modal, vim-inspired interface. Press `?` in Normal mode to view the help screen with all keybindings.
|
OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode) or `?` in Normal mode to view the help screen with all keybindings.
|
||||||
|
|
||||||
- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
|
- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
|
||||||
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
|
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
|
||||||
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`.
|
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:w`, `:session save`, `:theme`.
|
||||||
|
- **Quick Exit**: Press `Ctrl+C` twice in Normal mode to quit quickly (first press still cancels active generations).
|
||||||
|
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
|
||||||
|
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.
|
||||||
|
|
||||||
|
Model discovery commands worth remembering:
|
||||||
|
|
||||||
|
- `:models --local` or `:models --cloud` jump directly to the corresponding section in the picker.
|
||||||
|
- `:cloud setup [--force-cloud-base-url]` stores your cloud API key without clobbering an existing local base URL (unless you opt in with the flag).
|
||||||
|
When a catalogue is unreachable, Owlen now tags the picker with `Local unavailable` / `Cloud unavailable` so you can recover without guessing.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
@@ -82,17 +120,47 @@ For more detailed information, please refer to the following documents:
|
|||||||
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
|
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
|
||||||
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
|
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
|
||||||
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
|
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
|
||||||
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: A guide for adding new providers.
|
- **[docs/repo-map.md](docs/repo-map.md)**: Snapshot of the workspace layout and key crates.
|
||||||
|
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: Trait-level details for implementing providers.
|
||||||
|
- **[docs/adding-providers.md](docs/adding-providers.md)**: Step-by-step checklist for wiring a provider into the multi-provider architecture and test suite.
|
||||||
|
- **Experimental providers staging area**: [crates/providers/experimental/README.md](crates/providers/experimental/README.md) records the placeholder crates (OpenAI, Anthropic, Gemini) and their current status.
|
||||||
|
- **[docs/platform-support.md](docs/platform-support.md)**: Current OS support matrix and cross-check instructions.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
OWLEN stores its configuration in `~/.config/owlen/config.toml`. This file is created on the first run and can be customized. You can also add custom themes in `~/.config/owlen/themes/`.
|
OWLEN stores its configuration in the standard platform-specific config directory:
|
||||||
|
|
||||||
|
| Platform | Location |
|
||||||
|
|----------|----------|
|
||||||
|
| Linux | `~/.config/owlen/config.toml` |
|
||||||
|
| macOS | `~/Library/Application Support/owlen/config.toml` |
|
||||||
|
| Windows | `%APPDATA%\owlen\config.toml` |
|
||||||
|
|
||||||
|
Use `owlen config path` to print the exact location on your machine and `owlen config doctor` to migrate a legacy config automatically.
|
||||||
|
You can also add custom themes alongside the config directory (e.g., `~/.config/owlen/themes/`).
|
||||||
|
|
||||||
See the [themes/README.md](themes/README.md) for more details on theming.
|
See the [themes/README.md](themes/README.md) for more details on theming.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Owlen uses standard Rust tooling for verification. Run the full test suite with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
Unit tests cover the command palette state machine, agent response parsing, and key MCP abstractions. Formatting and lint checks can be run with `cargo fmt --all` and `cargo clippy` respectively.
|
||||||
|
|
||||||
 ## Roadmap

-We are actively working on enhancing the code client, adding more providers (OpenAI, Anthropic), and improving the overall user experience. See the [Roadmap section in the old README](https://github.com/Owlibou/owlen/blob/main/README.md?plain=1#L295) for more details.
+Upcoming milestones focus on feature parity with modern code assistants while keeping Owlen local-first:
+
+1. **Phase 11 – MCP client enhancements**: `owlen mcp add/list/remove`, resource references (`@github:issue://123`), and MCP prompt slash commands.
+2. **Phase 12 – Approval & sandboxing**: Three-tier approval modes plus platform-specific sandboxes (Docker, `sandbox-exec`, Windows job objects).
+3. **Phase 13 – Project documentation system**: Automatic `OWLEN.md` generation, contextual updates, and nested project support.
+4. **Phase 15 – Provider expansion**: OpenAI, Anthropic, and other cloud providers layered onto the existing Ollama-first architecture.
+
+See `AGENTS.md` for the long-form roadmap and design notes.

 ## Contributing

@@ -101,3 +169,4 @@ Contributions are highly welcome! Please see our **[Contributing Guide](CONTRIBU
 ## License

 This project is licensed under the GNU Affero General Public License v3.0. See the [LICENSE](LICENSE) file for details.
+For commercial or proprietary integrations that cannot adopt AGPL, please reach out to the maintainers to discuss alternative licensing arrangements.
21 SECURITY.md
@@ -17,3 +17,24 @@ To report a security vulnerability, please email the project lead at [security@o
 You will receive a response from us within 48 hours. If the issue is confirmed, we will release a patch as soon as possible, depending on the complexity of the issue.

 Please do not report security vulnerabilities through public GitHub issues.

+## Design Overview
+
+Owlen ships with a local-first architecture:
+
+- **Process isolation** – The TUI speaks to language models through a separate MCP LLM server. Tool execution (code, web, filesystem) occurs in dedicated MCP processes so a crash or hang cannot take down the UI.
+- **Sandboxing** – The MCP Code Server executes snippets in Docker containers. Upcoming releases will extend this to platform sandboxes (`sandbox-exec` on macOS, Windows job objects) as described in our roadmap.
+- **Network posture** – No telemetry is emitted. The application only reaches the network when a user explicitly enables remote tools (web search, remote MCP servers) or configures cloud providers. All tools require allow-listing in `config.toml`.
+
+## Data Handling
+
+- **Sessions** – Conversations are stored in the user’s data directory (`~/.local/share/owlen` on Linux, equivalent paths on macOS/Windows). Enable `privacy.encrypt_local_data = true` to wrap the session store in AES-GCM encryption protected by a passphrase (`OWLEN_MASTER_PASSWORD` or an interactive prompt).
+- **Credentials** – API tokens are resolved from the config file or environment variables at runtime and are never written to logs.
+- **Remote calls** – When remote search or cloud LLM tooling is on, only the minimum payload (prompt, tool arguments) is sent. All outbound requests go through the MCP servers so they can be audited or disabled centrally.
+
+## Supply-Chain Safeguards
+
+- The repository includes a git `pre-commit` configuration that runs `cargo fmt`, `cargo check`, and `cargo clippy -- -D warnings` on every commit.
+- Pull requests generated with the assistance of AI tooling must receive manual maintainer review before merging. Contributors are asked to declare AI involvement in their PR description so maintainers can double-check the changes.
+
+Additional recommendations for operators (e.g., running Owlen on shared systems) are maintained in `docs/security.md` (planned) and the issue tracker.
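As a rough sketch of the passphrase resolution described under *Data Handling* above — environment variable first, interactive prompt as a fallback — the snippet below shows the shape of that lookup. It is not the project's implementation: the real TUI would hide the input and feed the passphrase into the AES-GCM session-store encryption.

```rust
// Sketch only: resolve the session-encryption passphrase as described above.
use std::io::{self, Write};

fn resolve_master_password() -> io::Result<String> {
    // Preferred path: the documented environment variable.
    if let Ok(password) = std::env::var("OWLEN_MASTER_PASSWORD") {
        return Ok(password);
    }
    // Fallback: ask the user. A real UI would hide the input.
    print!("Enter master password: ");
    io::stdout().flush()?;
    let mut buf = String::new();
    io::stdin().read_line(&mut buf)?;
    Ok(buf.trim_end().to_string())
}

fn main() -> io::Result<()> {
    let password = resolve_master_password()?;
    println!("passphrase resolved ({} characters)", password.len());
    Ok(())
}
```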
29 config.toml Normal file
@@ -0,0 +1,29 @@
[general]
default_provider = "ollama_local"
default_model = "llama3.2:latest"

[privacy]
encrypt_local_data = true

[providers.ollama_local]
enabled = true
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama_cloud]
enabled = false
provider_type = "ollama_cloud"
base_url = "https://ollama.com"
api_key_env = "OLLAMA_CLOUD_API_KEY"

[providers.openai]
enabled = false
provider_type = "openai"
base_url = "https://api.openai.com/v1"
api_key_env = "OPENAI_API_KEY"

[providers.anthropic]
enabled = false
provider_type = "anthropic"
base_url = "https://api.anthropic.com/v1"
api_key_env = "ANTHROPIC_API_KEY"
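To make the shape of this file concrete, here is a hedged sketch that deserializes it with `serde` and the `toml` crate. The struct names are illustrative; the authoritative types live in `owlen-core`'s config module, and the `[privacy]` table is simply ignored here. Note that `api_key_env` is optional, matching the `ollama_local` entry above.

```rust
// Illustrative structs only -- the authoritative types live in owlen-core.
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Debug, Deserialize)]
struct General {
    default_provider: String,
    default_model: String,
}

#[derive(Debug, Deserialize)]
struct ProviderEntry {
    enabled: bool,
    provider_type: String,
    base_url: String,
    api_key_env: Option<String>, // absent for the local Ollama provider
}

#[derive(Debug, Deserialize)]
struct ConfigSketch {
    general: General,
    providers: HashMap<String, ProviderEntry>,
}

fn main() -> anyhow::Result<()> {
    let raw = std::fs::read_to_string("config.toml")?;
    let cfg: ConfigSketch = toml::from_str(&raw)?;
    println!(
        "default provider {:?} ({} providers configured)",
        cfg.general.default_provider,
        cfg.providers.len()
    );
    Ok(())
}
```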
@@ -1,12 +1,12 @@
 [package]
 name = "owlen-mcp-client"
 version = "0.1.0"
-edition = "2021"
+edition.workspace = true
 description = "Dedicated MCP client library for Owlen, exposing remote MCP server communication"
 license = "AGPL-3.0"

 [dependencies]
-owlen-core = { path = "../owlen-core" }
+owlen-core = { path = "../../owlen-core" }

 [features]
 default = []
@@ -5,14 +5,12 @@
 //! crates can depend only on `owlen-mcp-client` without pulling in the entire
 //! core crate internals.

+pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
 pub use owlen_core::mcp::remote_client::RemoteMcpClient;
 pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};

-// Re‑export the Provider implementation so the client can also be used as an
-// LLM provider when the remote MCP server hosts a language‑model tool (e.g.
-// `generate_text`).
 // Re‑export the core Provider trait so that the MCP client can also be used as an LLM provider.
-pub use owlen_core::provider::Provider as McpProvider;
+pub use owlen_core::Provider as McpProvider;

 // Note: The `RemoteMcpClient` type provides its own `new` constructor in the core
 // crate. Users can call `RemoteMcpClient::new()` directly. No additional wrapper
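A minimal consumer-side sketch of the facade above. The exact signature of `RemoteMcpClient::new()` is defined in `owlen-core`, so construction is shown without error handling, and the helper function is purely illustrative.

```rust
// Sketch: a downstream crate that only needs the client facade re-exported above.
use owlen_mcp_client::{McpToolDescriptor, RemoteMcpClient};

// Illustrative helper: print whatever tools a remote server advertised.
fn describe_tools(tools: &[McpToolDescriptor]) {
    for tool in tools {
        println!("{}: {}", tool.name, tool.description);
    }
}

fn main() {
    // Per the note above, the constructor comes straight from the core crate.
    let _client = RemoteMcpClient::new();
    describe_tools(&[]);
}
```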
22 crates/mcp/code-server/Cargo.toml Normal file
@@ -0,0 +1,22 @@
[package]
name = "owlen-mcp-code-server"
version = "0.1.0"
edition.workspace = true
description = "MCP server exposing safe code execution tools for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
bollard = "0.17"
tempfile = { workspace = true }
uuid = { workspace = true }
futures = { workspace = true }

[lib]
name = "owlen_mcp_code_server"
path = "src/lib.rs"
186 crates/mcp/code-server/src/lib.rs Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
//! MCP server exposing code execution tools with Docker sandboxing.
|
||||||
|
//!
|
||||||
|
//! This server provides:
|
||||||
|
//! - compile_project: Build projects (Rust, Node.js, Python)
|
||||||
|
//! - run_tests: Execute test suites
|
||||||
|
//! - format_code: Run code formatters
|
||||||
|
//! - lint_code: Run linters
|
||||||
|
|
||||||
|
pub mod sandbox;
|
||||||
|
pub mod tools;
|
||||||
|
|
||||||
|
use owlen_core::mcp::protocol::{
|
||||||
|
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||||
|
RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
|
||||||
|
};
|
||||||
|
use owlen_core::tools::{Tool, ToolResult};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
use tools::{CompileProjectTool, FormatCodeTool, LintCodeTool, RunTestsTool};
|
||||||
|
|
||||||
|
/// Tool registry for the code server
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct ToolRegistry {
|
||||||
|
tools: HashMap<String, Box<dyn Tool + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
impl ToolRegistry {
|
||||||
|
fn new() -> Self {
|
||||||
|
let mut tools: HashMap<String, Box<dyn Tool + Send + Sync>> = HashMap::new();
|
||||||
|
tools.insert(
|
||||||
|
"compile_project".to_string(),
|
||||||
|
Box::new(CompileProjectTool::new()),
|
||||||
|
);
|
||||||
|
tools.insert("run_tests".to_string(), Box::new(RunTestsTool::new()));
|
||||||
|
tools.insert("format_code".to_string(), Box::new(FormatCodeTool::new()));
|
||||||
|
tools.insert("lint_code".to_string(), Box::new(LintCodeTool::new()));
|
||||||
|
Self { tools }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_tools(&self) -> Vec<owlen_core::mcp::McpToolDescriptor> {
|
||||||
|
self.tools
|
||||||
|
.values()
|
||||||
|
.map(|tool| owlen_core::mcp::McpToolDescriptor {
|
||||||
|
name: tool.name().to_string(),
|
||||||
|
description: tool.description().to_string(),
|
||||||
|
input_schema: tool.schema(),
|
||||||
|
requires_network: tool.requires_network(),
|
||||||
|
requires_filesystem: tool.requires_filesystem(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, name: &str, args: Value) -> Result<ToolResult, String> {
|
||||||
|
self.tools
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| format!("Tool not found: {}", name))?
|
||||||
|
.execute(args)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let mut stdin = io::BufReader::new(io::stdin());
|
||||||
|
let mut stdout = io::stdout();
|
||||||
|
|
||||||
|
let registry = Arc::new(ToolRegistry::new());
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut line = String::new();
|
||||||
|
match stdin.read_line(&mut line).await {
|
||||||
|
Ok(0) => break, // EOF
|
||||||
|
Ok(_) => {
|
||||||
|
let req: RpcRequest = match serde_json::from_str(&line) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let err = RpcErrorResponse::new(
|
||||||
|
RequestId::Number(0),
|
||||||
|
RpcError::parse_error(format!("Parse error: {}", e)),
|
||||||
|
);
|
||||||
|
let s = serde_json::to_string(&err)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = handle_request(req.clone(), registry.clone()).await;
|
||||||
|
match resp {
|
||||||
|
Ok(r) => {
|
||||||
|
let s = serde_json::to_string(&r)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let err = RpcErrorResponse::new(req.id.clone(), e);
|
||||||
|
let s = serde_json::to_string(&err)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Error reading stdin: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
async fn handle_request(
|
||||||
|
req: RpcRequest,
|
||||||
|
registry: Arc<ToolRegistry>,
|
||||||
|
) -> Result<RpcResponse, RpcError> {
|
||||||
|
match req.method.as_str() {
|
||||||
|
methods::INITIALIZE => {
|
||||||
|
let params: InitializeParams =
|
||||||
|
serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
|
||||||
|
.map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
|
||||||
|
if !params.protocol_version.eq(PROTOCOL_VERSION) {
|
||||||
|
return Err(RpcError::new(
|
||||||
|
ErrorCode::INVALID_REQUEST,
|
||||||
|
format!(
|
||||||
|
"Incompatible protocol version. Client: {}, Server: {}",
|
||||||
|
params.protocol_version, PROTOCOL_VERSION
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let result = InitializeResult {
|
||||||
|
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||||
|
server_info: ServerInfo {
|
||||||
|
name: "owlen-mcp-code-server".to_string(),
|
||||||
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||||
|
},
|
||||||
|
capabilities: ServerCapabilities {
|
||||||
|
supports_tools: Some(true),
|
||||||
|
supports_resources: Some(false),
|
||||||
|
supports_streaming: Some(false),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let payload = serde_json::to_value(result).map_err(|e| {
|
||||||
|
RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
|
||||||
|
})?;
|
||||||
|
Ok(RpcResponse::new(req.id, payload))
|
||||||
|
}
|
||||||
|
methods::TOOLS_LIST => {
|
||||||
|
let tools = registry.list_tools();
|
||||||
|
Ok(RpcResponse::new(req.id, json!(tools)))
|
||||||
|
}
|
||||||
|
methods::TOOLS_CALL => {
|
||||||
|
let call = serde_json::from_value::<owlen_core::mcp::McpToolCall>(
|
||||||
|
req.params.unwrap_or_else(|| json!({})),
|
||||||
|
)
|
||||||
|
.map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;
|
||||||
|
|
||||||
|
let result: ToolResult = registry
|
||||||
|
.execute(&call.name, call.arguments)
|
||||||
|
.await
|
||||||
|
.map_err(|e| RpcError::internal_error(format!("Tool execution failed: {}", e)))?;
|
||||||
|
|
||||||
|
let resp = owlen_core::mcp::McpToolResponse {
|
||||||
|
name: call.name,
|
||||||
|
success: result.success,
|
||||||
|
output: result.output,
|
||||||
|
metadata: result.metadata,
|
||||||
|
duration_ms: result.duration.as_millis() as u128,
|
||||||
|
};
|
||||||
|
let payload = serde_json::to_value(resp).map_err(|e| {
|
||||||
|
RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
|
||||||
|
})?;
|
||||||
|
Ok(RpcResponse::new(req.id, payload))
|
||||||
|
}
|
||||||
|
_ => Err(RpcError::method_not_found(&req.method)),
|
||||||
|
}
|
||||||
|
}
|
||||||
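To show what the `read_line` loop above actually consumes, the sketch below builds one `tools/call` request as a single JSON line. Field names mirror `RpcRequest { id, method, params }` and `McpToolCall { name, arguments }` as used in the server; the literal method string and any extra envelope fields (such as a JSON-RPC version) are assumptions taken from the MCP convention, not from this diff.

```rust
// Sketch: the newline-delimited JSON a client would write to the code server's stdin.
// The exact wire envelope is defined by owlen_core::mcp::protocol; treat this shape
// (including the literal value of methods::TOOLS_CALL) as an assumption.
use serde_json::json;

fn main() {
    let request = json!({
        "id": 1,
        "method": "tools/call",
        "params": {
            "name": "run_tests",
            "arguments": { "project_path": "/workspace/myproject" }
        }
    });
    // One request per line, as the server's read_line() loop expects.
    println!("{}", request);
}
```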
250 crates/mcp/code-server/src/sandbox.rs Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
//! Docker-based sandboxing for secure code execution
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use bollard::Docker;
|
||||||
|
use bollard::container::{
|
||||||
|
Config, CreateContainerOptions, RemoveContainerOptions, StartContainerOptions,
|
||||||
|
WaitContainerOptions,
|
||||||
|
};
|
||||||
|
use bollard::models::{HostConfig, Mount, MountTypeEnum};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Result of executing code in a sandbox
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ExecutionResult {
|
||||||
|
pub stdout: String,
|
||||||
|
pub stderr: String,
|
||||||
|
pub exit_code: i64,
|
||||||
|
pub timed_out: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Docker-based sandbox executor
|
||||||
|
pub struct Sandbox {
|
||||||
|
docker: Docker,
|
||||||
|
memory_limit: i64,
|
||||||
|
cpu_quota: i64,
|
||||||
|
timeout_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sandbox {
|
||||||
|
/// Create a new sandbox with default resource limits
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let docker =
|
||||||
|
Docker::connect_with_local_defaults().context("Failed to connect to Docker daemon")?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
docker,
|
||||||
|
memory_limit: 512 * 1024 * 1024, // 512MB
|
||||||
|
cpu_quota: 50000, // 50% of one core
|
||||||
|
timeout_secs: 30,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a command in a sandboxed container
|
||||||
|
pub async fn execute(
|
||||||
|
&self,
|
||||||
|
image: &str,
|
||||||
|
cmd: &[&str],
|
||||||
|
workspace: Option<&Path>,
|
||||||
|
env: HashMap<String, String>,
|
||||||
|
) -> Result<ExecutionResult> {
|
||||||
|
let container_name = format!("owlen-sandbox-{}", uuid::Uuid::new_v4());
|
||||||
|
|
||||||
|
// Prepare volume mount if workspace provided
|
||||||
|
let mounts = if let Some(ws) = workspace {
|
||||||
|
vec![Mount {
|
||||||
|
target: Some("/workspace".to_string()),
|
||||||
|
source: Some(ws.to_string_lossy().to_string()),
|
||||||
|
typ: Some(MountTypeEnum::BIND),
|
||||||
|
read_only: Some(false),
|
||||||
|
..Default::default()
|
||||||
|
}]
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create container config
|
||||||
|
let host_config = HostConfig {
|
||||||
|
memory: Some(self.memory_limit),
|
||||||
|
cpu_quota: Some(self.cpu_quota),
|
||||||
|
network_mode: Some("none".to_string()), // No network access
|
||||||
|
mounts: Some(mounts),
|
||||||
|
auto_remove: Some(true),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = Config {
|
||||||
|
image: Some(image.to_string()),
|
||||||
|
cmd: Some(cmd.iter().map(|s| s.to_string()).collect()),
|
||||||
|
working_dir: Some("/workspace".to_string()),
|
||||||
|
env: Some(env.iter().map(|(k, v)| format!("{}={}", k, v)).collect()),
|
||||||
|
host_config: Some(host_config),
|
||||||
|
attach_stdout: Some(true),
|
||||||
|
attach_stderr: Some(true),
|
||||||
|
tty: Some(false),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create container
|
||||||
|
let container = self
|
||||||
|
.docker
|
||||||
|
.create_container(
|
||||||
|
Some(CreateContainerOptions {
|
||||||
|
name: container_name.clone(),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.context("Failed to create container")?;
|
||||||
|
|
||||||
|
// Start container
|
||||||
|
self.docker
|
||||||
|
.start_container(&container.id, None::<StartContainerOptions<String>>)
|
||||||
|
.await
|
||||||
|
.context("Failed to start container")?;
|
||||||
|
|
||||||
|
// Wait for container with timeout
|
||||||
|
let wait_result =
|
||||||
|
tokio::time::timeout(std::time::Duration::from_secs(self.timeout_secs), async {
|
||||||
|
let mut wait_stream = self
|
||||||
|
.docker
|
||||||
|
.wait_container(&container.id, None::<WaitContainerOptions<String>>);
|
||||||
|
|
||||||
|
use futures::StreamExt;
|
||||||
|
if let Some(result) = wait_stream.next().await {
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
Err(bollard::errors::Error::IOError {
|
||||||
|
err: std::io::Error::other("Container wait stream ended unexpectedly"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (exit_code, timed_out) = match wait_result {
|
||||||
|
Ok(Ok(result)) => (result.status_code, false),
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
eprintln!("Container wait error: {}", e);
|
||||||
|
(1, false)
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Timeout - kill the container
|
||||||
|
let _ = self
|
||||||
|
.docker
|
||||||
|
.kill_container(
|
||||||
|
&container.id,
|
||||||
|
None::<bollard::container::KillContainerOptions<String>>,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
(124, true)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get logs
|
||||||
|
let logs = self.docker.logs(
|
||||||
|
&container.id,
|
||||||
|
Some(bollard::container::LogsOptions::<String> {
|
||||||
|
stdout: true,
|
||||||
|
stderr: true,
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
use futures::StreamExt;
|
||||||
|
let mut stdout = String::new();
|
||||||
|
let mut stderr = String::new();
|
||||||
|
|
||||||
|
let log_result = tokio::time::timeout(std::time::Duration::from_secs(5), async {
|
||||||
|
let mut logs = logs;
|
||||||
|
while let Some(log) = logs.next().await {
|
||||||
|
match log {
|
||||||
|
Ok(bollard::container::LogOutput::StdOut { message }) => {
|
||||||
|
stdout.push_str(&String::from_utf8_lossy(&message));
|
||||||
|
}
|
||||||
|
Ok(bollard::container::LogOutput::StdErr { message }) => {
|
||||||
|
stderr.push_str(&String::from_utf8_lossy(&message));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if log_result.is_err() {
|
||||||
|
eprintln!("Timeout reading container logs");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove container (auto_remove should handle this, but be explicit)
|
||||||
|
let _ = self
|
||||||
|
.docker
|
||||||
|
.remove_container(
|
||||||
|
&container.id,
|
||||||
|
Some(RemoveContainerOptions {
|
||||||
|
force: true,
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(ExecutionResult {
|
||||||
|
stdout,
|
||||||
|
stderr,
|
||||||
|
exit_code,
|
||||||
|
timed_out,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute in a Rust environment
|
||||||
|
pub async fn execute_rust(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||||
|
self.execute("rust:1.75-slim", cmd, Some(workspace), HashMap::new())
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute in a Python environment
|
||||||
|
pub async fn execute_python(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||||
|
self.execute("python:3.11-slim", cmd, Some(workspace), HashMap::new())
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute in a Node.js environment
|
||||||
|
pub async fn execute_node(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||||
|
self.execute("node:20-slim", cmd, Some(workspace), HashMap::new())
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Sandbox {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new().expect("Failed to create default sandbox")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore] // Requires Docker daemon
|
||||||
|
async fn test_sandbox_rust_compile() {
|
||||||
|
let sandbox = Sandbox::new().unwrap();
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create a simple Rust project
|
||||||
|
std::fs::write(
|
||||||
|
temp_dir.path().join("main.rs"),
|
||||||
|
"fn main() { println!(\"Hello from sandbox!\"); }",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let result = sandbox
|
||||||
|
.execute_rust(temp_dir.path(), &["rustc", "main.rs"])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result.exit_code, 0);
|
||||||
|
assert!(!result.timed_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
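A short usage sketch of the `Sandbox` API above, calling `execute` directly with a custom image and environment map. Like the `#[ignore]`d test, it assumes a reachable Docker daemon; the image tag and workspace path are placeholders.

```rust
// Sketch: run a one-off command through the Docker sandbox defined above.
use std::collections::HashMap;
use std::path::Path;

use owlen_mcp_code_server::sandbox::Sandbox;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let sandbox = Sandbox::new()?; // connects to the local Docker daemon
    let mut env = HashMap::new();
    env.insert("CI".to_string(), "true".to_string());

    // No network, 512 MB / 50% CPU limits and a 30 s timeout apply by default.
    let result = sandbox
        .execute(
            "python:3.11-slim",
            &["python", "--version"],
            Some(Path::new("/tmp")),
            env,
        )
        .await?;

    println!("exit {}, timed out: {}", result.exit_code, result.timed_out);
    println!("stdout: {}", result.stdout.trim());
    Ok(())
}
```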
417 crates/mcp/code-server/src/tools.rs Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
//! Code execution tools using Docker sandboxing
|
||||||
|
|
||||||
|
use crate::sandbox::Sandbox;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use owlen_core::Result;
|
||||||
|
use owlen_core::tools::{Tool, ToolResult};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// Tool for compiling projects (Rust, Node.js, Python)
|
||||||
|
pub struct CompileProjectTool {
|
||||||
|
sandbox: Sandbox,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CompileProjectTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompileProjectTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sandbox: Sandbox::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CompileProjectTool {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"compile_project"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &'static str {
|
||||||
|
"Compile a project (Rust, Node.js, Python). Detects project type automatically."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"project_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the project root"
|
||||||
|
},
|
||||||
|
"project_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["rust", "node", "python"],
|
||||||
|
"description": "Project type (auto-detected if not specified)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["project_path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||||
|
let project_path = args
|
||||||
|
.get("project_path")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||||
|
|
||||||
|
let path = PathBuf::from(project_path);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(ToolResult::error("Project path does not exist"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect project type
|
||||||
|
let project_type = if let Some(pt) = args.get("project_type").and_then(|v| v.as_str()) {
|
||||||
|
pt.to_string()
|
||||||
|
} else if path.join("Cargo.toml").exists() {
|
||||||
|
"rust".to_string()
|
||||||
|
} else if path.join("package.json").exists() {
|
||||||
|
"node".to_string()
|
||||||
|
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||||
|
"python".to_string()
|
||||||
|
} else {
|
||||||
|
return Ok(ToolResult::error("Could not detect project type"));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute compilation
|
||||||
|
let result = match project_type.as_str() {
|
||||||
|
"rust" => self.sandbox.execute_rust(&path, &["cargo", "build"]).await,
|
||||||
|
"node" => {
|
||||||
|
self.sandbox
|
||||||
|
.execute_node(&path, &["npm", "run", "build"])
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
"python" => {
|
||||||
|
// Python typically doesn't need compilation, but we can check syntax
|
||||||
|
self.sandbox
|
||||||
|
.execute_python(&path, &["python", "-m", "compileall", "."])
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
_ => return Ok(ToolResult::error("Unsupported project type")),
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(exec_result) => {
|
||||||
|
if exec_result.timed_out {
|
||||||
|
Ok(ToolResult::error("Compilation timed out"))
|
||||||
|
} else if exec_result.exit_code == 0 {
|
||||||
|
Ok(ToolResult::success(json!({
|
||||||
|
"success": true,
|
||||||
|
"stdout": exec_result.stdout,
|
||||||
|
"stderr": exec_result.stderr,
|
||||||
|
"project_type": project_type
|
||||||
|
})))
|
||||||
|
} else {
|
||||||
|
Ok(ToolResult::success(json!({
|
||||||
|
"success": false,
|
||||||
|
"exit_code": exec_result.exit_code,
|
||||||
|
"stdout": exec_result.stdout,
|
||||||
|
"stderr": exec_result.stderr,
|
||||||
|
"project_type": project_type
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Ok(ToolResult::error(&format!("Compilation failed: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool for running test suites
|
||||||
|
pub struct RunTestsTool {
|
||||||
|
sandbox: Sandbox,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RunTestsTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RunTestsTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sandbox: Sandbox::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for RunTestsTool {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"run_tests"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &'static str {
|
||||||
|
"Run tests for a project (Rust, Node.js, Python)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"project_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the project root"
|
||||||
|
},
|
||||||
|
"test_filter": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional test filter/pattern"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["project_path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||||
|
let project_path = args
|
||||||
|
.get("project_path")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||||
|
|
||||||
|
let path = PathBuf::from(project_path);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(ToolResult::error("Project path does not exist"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let test_filter = args.get("test_filter").and_then(|v| v.as_str());
|
||||||
|
|
||||||
|
// Detect project type and run tests
|
||||||
|
let result = if path.join("Cargo.toml").exists() {
|
||||||
|
let cmd = if let Some(filter) = test_filter {
|
||||||
|
vec!["cargo", "test", filter]
|
||||||
|
} else {
|
||||||
|
vec!["cargo", "test"]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_rust(&path, &cmd).await
|
||||||
|
} else if path.join("package.json").exists() {
|
||||||
|
self.sandbox.execute_node(&path, &["npm", "test"]).await
|
||||||
|
} else if path.join("pytest.ini").exists()
|
||||||
|
|| path.join("setup.py").exists()
|
||||||
|
|| path.join("pyproject.toml").exists()
|
||||||
|
{
|
||||||
|
let cmd = if let Some(filter) = test_filter {
|
||||||
|
vec!["pytest", "-k", filter]
|
||||||
|
} else {
|
||||||
|
vec!["pytest"]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_python(&path, &cmd).await
|
||||||
|
} else {
|
||||||
|
return Ok(ToolResult::error("Could not detect test framework"));
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(exec_result) => Ok(ToolResult::success(json!({
|
||||||
|
"success": exec_result.exit_code == 0 && !exec_result.timed_out,
|
||||||
|
"exit_code": exec_result.exit_code,
|
||||||
|
"stdout": exec_result.stdout,
|
||||||
|
"stderr": exec_result.stderr,
|
||||||
|
"timed_out": exec_result.timed_out
|
||||||
|
}))),
|
||||||
|
Err(e) => Ok(ToolResult::error(&format!("Tests failed to run: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool for formatting code
|
||||||
|
pub struct FormatCodeTool {
|
||||||
|
sandbox: Sandbox,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FormatCodeTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FormatCodeTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sandbox: Sandbox::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for FormatCodeTool {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"format_code"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &'static str {
|
||||||
|
"Format code using project-appropriate formatter (rustfmt, prettier, black)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"project_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the project root"
|
||||||
|
},
|
||||||
|
"check_only": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Only check formatting without modifying files",
|
||||||
|
"default": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["project_path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||||
|
let project_path = args
|
||||||
|
.get("project_path")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||||
|
|
||||||
|
let path = PathBuf::from(project_path);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(ToolResult::error("Project path does not exist"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let check_only = args
|
||||||
|
.get("check_only")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Detect project type and run formatter
|
||||||
|
let result = if path.join("Cargo.toml").exists() {
|
||||||
|
let cmd = if check_only {
|
||||||
|
vec!["cargo", "fmt", "--", "--check"]
|
||||||
|
} else {
|
||||||
|
vec!["cargo", "fmt"]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_rust(&path, &cmd).await
|
||||||
|
} else if path.join("package.json").exists() {
|
||||||
|
let cmd = if check_only {
|
||||||
|
vec!["npx", "prettier", "--check", "."]
|
||||||
|
} else {
|
||||||
|
vec!["npx", "prettier", "--write", "."]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_node(&path, &cmd).await
|
||||||
|
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||||
|
let cmd = if check_only {
|
||||||
|
vec!["black", "--check", "."]
|
||||||
|
} else {
|
||||||
|
vec!["black", "."]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_python(&path, &cmd).await
|
||||||
|
} else {
|
||||||
|
return Ok(ToolResult::error("Could not detect project type"));
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(exec_result) => Ok(ToolResult::success(json!({
|
||||||
|
"success": exec_result.exit_code == 0,
|
||||||
|
"formatted": !check_only && exec_result.exit_code == 0,
|
||||||
|
"stdout": exec_result.stdout,
|
||||||
|
"stderr": exec_result.stderr
|
||||||
|
}))),
|
||||||
|
Err(e) => Ok(ToolResult::error(&format!("Formatting failed: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool for linting code
|
||||||
|
pub struct LintCodeTool {
|
||||||
|
sandbox: Sandbox,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LintCodeTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LintCodeTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sandbox: Sandbox::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for LintCodeTool {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"lint_code"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &'static str {
|
||||||
|
"Lint code using project-appropriate linter (clippy, eslint, pylint)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"project_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the project root"
|
||||||
|
},
|
||||||
|
"fix": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Automatically fix issues if possible",
|
||||||
|
"default": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["project_path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||||
|
let project_path = args
|
||||||
|
.get("project_path")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||||
|
|
||||||
|
let path = PathBuf::from(project_path);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(ToolResult::error("Project path does not exist"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let fix = args.get("fix").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||||
|
|
||||||
|
// Detect project type and run linter
|
||||||
|
let result = if path.join("Cargo.toml").exists() {
|
||||||
|
let cmd = if fix {
|
||||||
|
vec!["cargo", "clippy", "--fix", "--allow-dirty"]
|
||||||
|
} else {
|
||||||
|
vec!["cargo", "clippy"]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_rust(&path, &cmd).await
|
||||||
|
} else if path.join("package.json").exists() {
|
||||||
|
let cmd = if fix {
|
||||||
|
vec!["npx", "eslint", ".", "--fix"]
|
||||||
|
} else {
|
||||||
|
vec!["npx", "eslint", "."]
|
||||||
|
};
|
||||||
|
self.sandbox.execute_node(&path, &cmd).await
|
||||||
|
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||||
|
// pylint doesn't have auto-fix
|
||||||
|
self.sandbox.execute_python(&path, &["pylint", "."]).await
|
||||||
|
} else {
|
||||||
|
return Ok(ToolResult::error("Could not detect project type"));
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(exec_result) => {
|
||||||
|
let issues_found = exec_result.exit_code != 0;
|
||||||
|
Ok(ToolResult::success(json!({
|
||||||
|
"success": true,
|
||||||
|
"issues_found": issues_found,
|
||||||
|
"exit_code": exec_result.exit_code,
|
||||||
|
"stdout": exec_result.stdout,
|
||||||
|
"stderr": exec_result.stderr
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
Err(e) => Ok(ToolResult::error(&format!("Linting failed: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
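These tools can also be driven directly, without the RPC loop, by passing arguments in the shape their schemas declare. A sketch with `RunTestsTool` (hypothetical project path, Docker required, since every tool delegates to the sandbox):

```rust
// Sketch: call a code-server tool directly with schema-shaped arguments.
use owlen_core::tools::Tool;
use owlen_mcp_code_server::tools::RunTestsTool;
use serde_json::json;

#[tokio::main]
async fn main() -> owlen_core::Result<()> {
    let tool = RunTestsTool::new();
    let result = tool
        .execute(json!({
            "project_path": "/workspace/myproject", // placeholder path
            "test_filter": "parser"                 // optional, per the schema above
        }))
        .await?;

    println!("success: {}", result.success);
    println!("output: {}", result.output);
    Ok(())
}
```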
16 crates/mcp/llm-server/Cargo.toml Normal file
@@ -0,0 +1,16 @@
[package]
name = "owlen-mcp-llm-server"
version = "0.1.0"
edition.workspace = true

[dependencies]
owlen-core = { path = "../../owlen-core" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
anyhow = { workspace = true }
tokio-stream = { workspace = true }

[[bin]]
name = "owlen-mcp-llm-server"
path = "src/main.rs"
@@ -7,18 +7,22 @@
|
|||||||
clippy::empty_line_after_outer_attr
|
clippy::empty_line_after_outer_attr
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
use owlen_core::Provider;
|
||||||
|
use owlen_core::ProviderConfig;
|
||||||
|
use owlen_core::config::{Config as OwlenConfig, ensure_provider_config};
|
||||||
use owlen_core::mcp::protocol::{
|
use owlen_core::mcp::protocol::{
|
||||||
methods, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError, RpcErrorResponse,
|
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||||
RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
|
RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo,
|
||||||
|
methods,
|
||||||
};
|
};
|
||||||
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||||
|
use owlen_core::providers::OllamaProvider;
|
||||||
use owlen_core::types::{ChatParameters, ChatRequest, Message};
|
use owlen_core::types::{ChatParameters, ChatRequest, Message};
|
||||||
use owlen_core::Provider;
|
|
||||||
use owlen_ollama::OllamaProvider;
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
|
||||||
@@ -106,10 +110,60 @@ fn resources_list_descriptor() -> McpToolDescriptor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
|
||||||
|
let mut config = OwlenConfig::load(None).unwrap_or_default();
|
||||||
|
let requested_name =
|
||||||
|
env::var("OWLEN_PROVIDER").unwrap_or_else(|_| config.general.default_provider.clone());
|
||||||
|
let provider_key = canonical_provider_name(&requested_name);
|
||||||
|
if config.provider(&provider_key).is_none() {
|
||||||
|
ensure_provider_config(&mut config, &provider_key);
|
||||||
|
}
|
||||||
|
let provider_cfg: ProviderConfig =
|
||||||
|
config.provider(&provider_key).cloned().ok_or_else(|| {
|
||||||
|
RpcError::internal_error(format!(
|
||||||
|
"Provider '{provider_key}' not found in configuration"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match provider_cfg.provider_type.as_str() {
|
||||||
|
"ollama" | "ollama_cloud" => {
|
||||||
|
let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
||||||
|
.map_err(|e| {
|
||||||
|
RpcError::internal_error(format!(
|
||||||
|
"Failed to init Ollama provider from config: {e}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||||
|
}
|
||||||
|
other => Err(RpcError::internal_error(format!(
|
||||||
|
"Unsupported provider type '{other}' for MCP LLM server"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_provider() -> Result<Arc<dyn Provider>, RpcError> {
|
||||||
|
if let Ok(url) = env::var("OLLAMA_URL") {
|
||||||
|
let provider = OllamaProvider::new(&url).map_err(|e| {
|
||||||
|
RpcError::internal_error(format!("Failed to init Ollama provider: {e}"))
|
||||||
|
})?;
|
||||||
|
return Ok(Arc::new(provider) as Arc<dyn Provider>);
|
||||||
|
}
|
||||||
|
|
||||||
|
provider_from_config()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn canonical_provider_name(name: &str) -> String {
|
||||||
|
let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
|
||||||
|
match normalized.as_str() {
|
||||||
|
"" => "ollama_local".to_string(),
|
||||||
|
"ollama" | "ollama_local" => "ollama_local".to_string(),
|
||||||
|
"ollama_cloud" => "ollama_cloud".to_string(),
|
||||||
|
other => other.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError> {
|
async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError> {
|
||||||
// Create provider with default local Ollama URL
|
let provider = create_provider()?;
|
||||||
let provider = OllamaProvider::new("http://localhost:11434")
|
|
||||||
.map_err(|e| RpcError::internal_error(format!("Failed to init OllamaProvider: {}", e)))?;
|
|
||||||
|
|
||||||
let parameters = ChatParameters {
|
let parameters = ChatParameters {
|
||||||
temperature: args.temperature,
|
temperature: args.temperature,
|
||||||
@@ -127,7 +181,7 @@ async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError
|
|||||||
|
|
||||||
// Use streaming API and collect output
|
// Use streaming API and collect output
|
||||||
let mut stream = provider
|
let mut stream = provider
|
||||||
.chat_stream(request)
|
.stream_prompt(request)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| RpcError::internal_error(format!("Chat request failed: {}", e)))?;
|
.map_err(|e| RpcError::internal_error(format!("Chat request failed: {}", e)))?;
|
||||||
let mut content = String::new();
|
let mut content = String::new();
|
||||||
@@ -177,7 +231,9 @@ async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
|
|||||||
supports_streaming: Some(true),
|
supports_streaming: Some(true),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
Ok(serde_json::to_value(result).unwrap())
|
serde_json::to_value(result).map_err(|e| {
|
||||||
|
RpcError::internal_error(format!("Failed to serialize init result: {}", e))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
methods::TOOLS_LIST => {
|
methods::TOOLS_LIST => {
|
||||||
let tools = vec![
|
let tools = vec![
|
||||||
@@ -189,15 +245,14 @@ async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
|
|||||||
}
|
}
|
||||||
// New method to list available Ollama models via the provider.
|
// New method to list available Ollama models via the provider.
|
||||||
methods::MODELS_LIST => {
|
methods::MODELS_LIST => {
|
||||||
// Reuse the provider instance for model listing.
|
let provider = create_provider()?;
|
||||||
let provider = OllamaProvider::new("http://localhost:11434").map_err(|e| {
|
|
||||||
RpcError::internal_error(format!("Failed to init OllamaProvider: {}", e))
|
|
||||||
})?;
|
|
||||||
let models = provider
|
let models = provider
|
||||||
.list_models()
|
.list_models()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| RpcError::internal_error(format!("Failed to list models: {}", e)))?;
|
.map_err(|e| RpcError::internal_error(format!("Failed to list models: {}", e)))?;
|
||||||
Ok(serde_json::to_value(models).unwrap())
|
serde_json::to_value(models).map_err(|e| {
|
||||||
|
RpcError::internal_error(format!("Failed to serialize model list: {}", e))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
methods::TOOLS_CALL => {
|
methods::TOOLS_CALL => {
|
||||||
// For streaming we will send incremental notifications directly from here.
|
// For streaming we will send incremental notifications directly from here.
|
||||||
@@ -283,10 +338,24 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
metadata: HashMap::new(),
|
metadata: HashMap::new(),
|
||||||
duration_ms: 0,
|
duration_ms: 0,
|
||||||
};
|
};
|
||||||
let final_resp = RpcResponse::new(
|
let payload = match serde_json::to_value(&response) {
|
||||||
id.clone(),
|
Ok(value) => value,
|
||||||
serde_json::to_value(response).unwrap(),
|
Err(e) => {
|
||||||
);
|
let err_resp = RpcErrorResponse::new(
|
||||||
|
id.clone(),
|
||||||
|
RpcError::internal_error(format!(
|
||||||
|
"Failed to serialize resource response: {}",
|
||||||
|
e
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
let s = serde_json::to_string(&err_resp)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||||
let s = serde_json::to_string(&final_resp)?;
|
let s = serde_json::to_string(&final_resp)?;
|
||||||
stdout.write_all(s.as_bytes()).await?;
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
stdout.write_all(b"\n").await?;
|
stdout.write_all(b"\n").await?;
|
||||||
@@ -327,10 +396,24 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
metadata: HashMap::new(),
|
metadata: HashMap::new(),
|
||||||
duration_ms: 0,
|
duration_ms: 0,
|
||||||
};
|
};
|
||||||
let final_resp = RpcResponse::new(
|
let payload = match serde_json::to_value(&response) {
|
||||||
id.clone(),
|
Ok(value) => value,
|
||||||
serde_json::to_value(response).unwrap(),
|
Err(e) => {
|
||||||
);
|
let err_resp = RpcErrorResponse::new(
|
||||||
|
id.clone(),
|
||||||
|
RpcError::internal_error(format!(
|
||||||
|
"Failed to serialize directory listing: {}",
|
||||||
|
e
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
let s = serde_json::to_string(&err_resp)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||||
let s = serde_json::to_string(&final_resp)?;
|
let s = serde_json::to_string(&final_resp)?;
|
||||||
stdout.write_all(s.as_bytes()).await?;
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
stdout.write_all(b"\n").await?;
|
stdout.write_all(b"\n").await?;
|
||||||
@@ -376,14 +459,14 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize Ollama provider and start streaming
|
// Initialize provider and start streaming
|
||||||
let provider = match OllamaProvider::new("http://localhost:11434") {
|
let provider = match create_provider() {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let err_resp = RpcErrorResponse::new(
|
let err_resp = RpcErrorResponse::new(
|
||||||
id.clone(),
|
id.clone(),
|
||||||
RpcError::internal_error(format!(
|
RpcError::internal_error(format!(
|
||||||
"Failed to init OllamaProvider: {}",
|
"Failed to initialize provider: {:?}",
|
||||||
e
|
e
|
||||||
)),
|
)),
|
||||||
);
|
);
|
||||||
@@ -406,7 +489,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
parameters,
|
parameters,
|
||||||
tools: None,
|
tools: None,
|
||||||
};
|
};
|
||||||
let mut stream = match provider.chat_stream(request).await {
|
let mut stream = match provider.stream_prompt(request).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let err_resp = RpcErrorResponse::new(
|
let err_resp = RpcErrorResponse::new(
|
||||||
@@ -462,8 +545,24 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
metadata: HashMap::new(),
|
metadata: HashMap::new(),
|
||||||
duration_ms: 0,
|
duration_ms: 0,
|
||||||
};
|
};
|
||||||
let final_resp =
|
let payload = match serde_json::to_value(&response) {
|
||||||
RpcResponse::new(id.clone(), serde_json::to_value(response).unwrap());
|
Ok(value) => value,
|
||||||
|
Err(e) => {
|
||||||
|
let err_resp = RpcErrorResponse::new(
|
||||||
|
id.clone(),
|
||||||
|
RpcError::internal_error(format!(
|
||||||
|
"Failed to serialize final streaming response: {}",
|
||||||
|
e
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
let s = serde_json::to_string(&err_resp)?;
|
||||||
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
|
stdout.write_all(b"\n").await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||||
let s = serde_json::to_string(&final_resp)?;
|
let s = serde_json::to_string(&final_resp)?;
|
||||||
stdout.write_all(s.as_bytes()).await?;
|
stdout.write_all(s.as_bytes()).await?;
|
||||||
stdout.write_all(b"\n").await?;
|
stdout.write_all(b"\n").await?;
|
||||||
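The provider selection introduced above resolves in a fixed order: `OLLAMA_URL` short-circuits everything, otherwise `OWLEN_PROVIDER` (or the config default) is normalised by `canonical_provider_name` before the config lookup. The assertions below reproduce that normalisation for illustration; the function body is copied from the diff above.

```rust
// Sketch: how canonical_provider_name() (reproduced from the server above)
// normalises the provider names an operator might set.
fn canonical_provider_name(name: &str) -> String {
    let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
    match normalized.as_str() {
        "" => "ollama_local".to_string(),
        "ollama" | "ollama_local" => "ollama_local".to_string(),
        "ollama_cloud" => "ollama_cloud".to_string(),
        other => other.to_string(),
    }
}

fn main() {
    assert_eq!(canonical_provider_name("Ollama"), "ollama_local");
    assert_eq!(canonical_provider_name("  ollama-cloud "), "ollama_cloud");
    assert_eq!(canonical_provider_name(""), "ollama_local");
    assert_eq!(canonical_provider_name("OpenAI"), "openai");
    println!("all normalisations match the server's behaviour");
}
```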
21 crates/mcp/prompt-server/Cargo.toml Normal file
@@ -0,0 +1,21 @@
[package]
name = "owlen-mcp-prompt-server"
version = "0.1.0"
edition.workspace = true
description = "MCP server that renders prompt templates (YAML) for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
handlebars = { workspace = true }
dirs = { workspace = true }
futures = { workspace = true }

[lib]
name = "owlen_mcp_prompt_server"
path = "src/lib.rs"
415 crates/mcp/prompt-server/src/lib.rs Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
//! MCP server for rendering prompt templates with YAML storage and Handlebars rendering.
|
||||||
|
//!
|
||||||
|
//! Templates are stored in `~/.config/owlen/prompts/` as YAML files.
|
||||||
|
//! Provides full Handlebars templating support for dynamic prompt generation.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use handlebars::Handlebars;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use owlen_core::mcp::protocol::{
|
||||||
|
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||||
|
RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
|
||||||
|
};
|
||||||
|
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||||
|
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
/// Prompt template definition
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PromptTemplate {
|
||||||
|
/// Template name
|
||||||
|
pub name: String,
|
||||||
|
/// Template version
|
||||||
|
pub version: String,
|
||||||
|
/// Optional mode restriction
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mode: Option<String>,
|
||||||
|
/// Handlebars template content
|
||||||
|
pub template: String,
|
||||||
|
/// Template description
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prompt server managing templates
|
||||||
|
pub struct PromptServer {
|
    templates: Arc<RwLock<HashMap<String, PromptTemplate>>>,
    handlebars: Handlebars<'static>,
    templates_dir: PathBuf,
}

impl PromptServer {
    /// Create a new prompt server
    pub fn new() -> Result<Self> {
        let templates_dir = Self::get_templates_dir()?;

        // Create templates directory if it doesn't exist
        if !templates_dir.exists() {
            fs::create_dir_all(&templates_dir)?;
            Self::create_default_templates(&templates_dir)?;
        }

        let mut server = Self {
            templates: Arc::new(RwLock::new(HashMap::new())),
            handlebars: Handlebars::new(),
            templates_dir,
        };

        // Load all templates
        server.load_templates()?;

        Ok(server)
    }

    /// Get the templates directory path
    fn get_templates_dir() -> Result<PathBuf> {
        let config_dir = dirs::config_dir().context("Could not determine config directory")?;
        Ok(config_dir.join("owlen").join("prompts"))
    }

    /// Create default template examples
    fn create_default_templates(dir: &Path) -> Result<()> {
        let chat_mode_system = PromptTemplate {
            name: "chat_mode_system".to_string(),
            version: "1.0".to_string(),
            mode: Some("chat".to_string()),
            description: Some("System prompt for chat mode".to_string()),
            template: r#"You are Owlen, a helpful AI assistant. You have access to these tools:
{{#each tools}}
- {{name}}: {{description}}
{{/each}}

Use the ReAct pattern:
THOUGHT: Your reasoning
ACTION: tool_name
ACTION_INPUT: {"param": "value"}

When you have enough information:
FINAL_ANSWER: Your response"#
                .to_string(),
        };

        let code_mode_system = PromptTemplate {
            name: "code_mode_system".to_string(),
            version: "1.0".to_string(),
            mode: Some("code".to_string()),
            description: Some("System prompt for code mode".to_string()),
            template: r#"You are Owlen in code mode, with full development capabilities. You have access to:
{{#each tools}}
- {{name}}: {{description}}
{{/each}}

Use the ReAct pattern to solve coding tasks:
THOUGHT: Analyze what needs to be done
ACTION: tool_name (compile_project, run_tests, format_code, lint_code, etc.)
ACTION_INPUT: {"param": "value"}

Continue iterating until the task is complete, then provide:
FINAL_ANSWER: Summary of what was done"#
                .to_string(),
        };

        // Save templates
        let chat_path = dir.join("chat_mode_system.yaml");
        let code_path = dir.join("code_mode_system.yaml");

        fs::write(chat_path, serde_yaml::to_string(&chat_mode_system)?)?;
        fs::write(code_path, serde_yaml::to_string(&code_mode_system)?)?;

        Ok(())
    }

    /// Load all templates from the templates directory
    fn load_templates(&mut self) -> Result<()> {
        let entries = fs::read_dir(&self.templates_dir)?;

        for entry in entries {
            let entry = entry?;
            let path = entry.path();

            if path.extension().and_then(|s| s.to_str()) == Some("yaml")
                || path.extension().and_then(|s| s.to_str()) == Some("yml")
            {
                match self.load_template(&path) {
                    Ok(template) => {
                        // Register with Handlebars
                        if let Err(e) = self
                            .handlebars
                            .register_template_string(&template.name, &template.template)
                        {
                            eprintln!(
                                "Warning: Failed to register template {}: {}",
                                template.name, e
                            );
                        } else {
                            let mut templates = self.templates.blocking_write();
                            templates.insert(template.name.clone(), template);
                        }
                    }
                    Err(e) => {
                        eprintln!("Warning: Failed to load template {:?}: {}", path, e);
                    }
                }
            }
        }

        Ok(())
    }

    /// Load a single template from file
    fn load_template(&self, path: &Path) -> Result<PromptTemplate> {
        let content = fs::read_to_string(path)?;
        let template: PromptTemplate = serde_yaml::from_str(&content)?;
        Ok(template)
    }

    /// Get a template by name
    pub async fn get_template(&self, name: &str) -> Option<PromptTemplate> {
        let templates = self.templates.read().await;
        templates.get(name).cloned()
    }

    /// List all available templates
    pub async fn list_templates(&self) -> Vec<String> {
        let templates = self.templates.read().await;
        templates.keys().cloned().collect()
    }

    /// Render a template with given variables
    pub fn render_template(&self, name: &str, vars: &Value) -> Result<String> {
        self.handlebars
            .render(name, vars)
            .context("Failed to render template")
    }

    /// Reload all templates from disk
    pub async fn reload_templates(&mut self) -> Result<()> {
        {
            let mut templates = self.templates.write().await;
            templates.clear();
        }
        self.handlebars = Handlebars::new();
        self.load_templates()
    }
}

#[allow(dead_code)]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut stdin = io::BufReader::new(io::stdin());
    let mut stdout = io::stdout();

    let server = Arc::new(tokio::sync::Mutex::new(PromptServer::new()?));

    loop {
        let mut line = String::new();
        match stdin.read_line(&mut line).await {
            Ok(0) => break, // EOF
            Ok(_) => {
                let req: RpcRequest = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        let err = RpcErrorResponse::new(
                            RequestId::Number(0),
                            RpcError::parse_error(format!("Parse error: {}", e)),
                        );
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                };

                let resp = handle_request(req.clone(), server.clone()).await;
                match resp {
                    Ok(r) => {
                        let s = serde_json::to_string(&r)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                    Err(e) => {
                        let err = RpcErrorResponse::new(req.id.clone(), e);
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                }
            }
            Err(e) => {
                eprintln!("Error reading stdin: {}", e);
                break;
            }
        }
    }
    Ok(())
}

#[allow(dead_code)]
async fn handle_request(
    req: RpcRequest,
    server: Arc<tokio::sync::Mutex<PromptServer>>,
) -> Result<RpcResponse, RpcError> {
    match req.method.as_str() {
        methods::INITIALIZE => {
            let params: InitializeParams =
                serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
                    .map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
            if !params.protocol_version.eq(PROTOCOL_VERSION) {
                return Err(RpcError::new(
                    ErrorCode::INVALID_REQUEST,
                    format!(
                        "Incompatible protocol version. Client: {}, Server: {}",
                        params.protocol_version, PROTOCOL_VERSION
                    ),
                ));
            }
            let result = InitializeResult {
                protocol_version: PROTOCOL_VERSION.to_string(),
                server_info: ServerInfo {
                    name: "owlen-mcp-prompt-server".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                capabilities: ServerCapabilities {
                    supports_tools: Some(true),
                    supports_resources: Some(false),
                    supports_streaming: Some(false),
                },
            };
            let payload = serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        methods::TOOLS_LIST => {
            let tools = vec![
                McpToolDescriptor {
                    name: "get_prompt".to_string(),
                    description: "Retrieve a prompt template by name".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Template name"}
                        },
                        "required": ["name"]
                    }),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "render_prompt".to_string(),
                    description: "Render a prompt template with Handlebars variables".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Template name"},
                            "vars": {"type": "object", "description": "Variables for Handlebars rendering"}
                        },
                        "required": ["name"]
                    }),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "list_prompts".to_string(),
                    description: "List all available prompt templates".to_string(),
                    input_schema: json!({"type": "object", "properties": {}}),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "reload_prompts".to_string(),
                    description: "Reload all prompts from disk".to_string(),
                    input_schema: json!({"type": "object", "properties": {}}),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
            ];
            Ok(RpcResponse::new(req.id, json!(tools)))
        }
        methods::TOOLS_CALL => {
            let call: McpToolCall = serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
                .map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;

            let result = match call.name.as_str() {
                "get_prompt" => {
                    let name = call
                        .arguments
                        .get("name")
                        .and_then(|v| v.as_str())
                        .ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;

                    let srv = server.lock().await;
                    match srv.get_template(name).await {
                        Some(template) => match serde_json::to_value(template) {
                            Ok(serialized) => {
                                json!({"success": true, "template": serialized})
                            }
                            Err(e) => {
                                return Err(RpcError::internal_error(format!(
                                    "Failed to serialize template '{}': {}",
                                    name, e
                                )));
                            }
                        },
                        None => json!({"success": false, "error": "Template not found"}),
                    }
                }
                "render_prompt" => {
                    let name = call
                        .arguments
                        .get("name")
                        .and_then(|v| v.as_str())
                        .ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;

                    let default_vars = json!({});
                    let vars = call.arguments.get("vars").unwrap_or(&default_vars);

                    let srv = server.lock().await;
                    match srv.render_template(name, vars) {
                        Ok(rendered) => json!({"success": true, "rendered": rendered}),
                        Err(e) => json!({"success": false, "error": e.to_string()}),
                    }
                }
                "list_prompts" => {
                    let srv = server.lock().await;
                    let templates = srv.list_templates().await;
                    json!({"success": true, "templates": templates})
                }
                "reload_prompts" => {
                    let mut srv = server.lock().await;
                    match srv.reload_templates().await {
                        Ok(_) => json!({"success": true, "message": "Prompts reloaded"}),
                        Err(e) => json!({"success": false, "error": e.to_string()}),
                    }
                }
                _ => return Err(RpcError::method_not_found(&call.name)),
            };

            let resp = McpToolResponse {
                name: call.name,
                success: result
                    .get("success")
                    .and_then(|v| v.as_bool())
                    .unwrap_or(false),
                output: result,
                metadata: HashMap::new(),
                duration_ms: 0,
            };

            let payload = serde_json::to_value(resp).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
}
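Orientation note (not part of the diff): a minimal sketch of one request/response round trip over the stdio loop above. It assumes the envelope is plain JSON-RPC 2.0, that the constant `methods::TOOLS_CALL` resolves to the string "tools/call", and that `McpToolCall` serialises its `name`/`arguments` fields as-is; those definitions live in owlen-core, outside this hunk, so treat the exact strings as assumptions.

  client -> server (one JSON line on stdin):
    {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"render_prompt","arguments":{"name":"chat_mode_system","vars":{"tools":[{"name":"get_prompt","description":"Retrieve a prompt template by name"}]}}}}
  server -> client (one JSON line on stdout, shape taken from McpToolResponse above):
    {"jsonrpc":"2.0","id":1,"result":{"name":"render_prompt","success":true,"output":{"success":true,"rendered":"You are Owlen, a helpful AI assistant. ..."},"metadata":{},"duration_ms":0}}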
3 crates/mcp/prompt-server/templates/example.yaml Normal file
@@ -0,0 +1,3 @@
prompt: |
  Hello {{name}}!
  Your role is: {{role}}.
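Side note, not part of the diff: rendered through the Handlebars engine registered above with hypothetical variables name=Ada and role=reviewer, this template would expand to:

  Hello Ada!
  Your role is: reviewer.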
12 crates/mcp/server/Cargo.toml Normal file
@@ -0,0 +1,12 @@
[package]
name = "owlen-mcp-server"
version = "0.1.0"
edition.workspace = true

[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
anyhow = { workspace = true }
path-clean = "1.0"
owlen-core = { path = "../../owlen-core" }
@@ -1,6 +1,6 @@
 use owlen_core::mcp::protocol::{
-    is_compatible, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError,
-    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
+    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
+    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, is_compatible,
 };
 use path_clean::PathClean;
 use serde::Deserialize;
@@ -11,7 +11,6 @@ description = "Command-line interface for OWLEN LLM client"
 [features]
 default = ["chat-client"]
 chat-client = ["owlen-tui"]
-code-client = []

 [[bin]]
 name = "owlen"
@@ -21,7 +20,7 @@ required-features = ["chat-client"]
 [[bin]]
 name = "owlen-code"
 path = "src/code_main.rs"
-required-features = ["code-client"]
+required-features = ["chat-client"]

 [[bin]]
 name = "owlen-agent"
@@ -30,12 +29,15 @@ required-features = ["chat-client"]

 [dependencies]
 owlen-core = { path = "../owlen-core" }
-owlen-ollama = { path = "../owlen-ollama" }
+owlen-providers = { path = "../owlen-providers" }
 # Optional TUI dependency, enabled by the "chat-client" feature.
 owlen-tui = { path = "../owlen-tui", optional = true }
+log = { workspace = true }
+async-trait = { workspace = true }
+futures = { workspace = true }

 # CLI framework
-clap = { version = "4.0", features = ["derive"] }
+clap = { workspace = true, features = ["derive"] }

 # Async runtime
 tokio = { workspace = true }
@@ -49,6 +51,10 @@ crossterm = { workspace = true }
 anyhow = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
-regex = "1"
-thiserror = "1"
-dirs = "5"
+regex = { workspace = true }
+thiserror = { workspace = true }
+dirs = { workspace = true }
+
+[dev-dependencies]
+tokio = { workspace = true }
+tokio-test = { workspace = true }
31 crates/owlen-cli/build.rs Normal file
@@ -0,0 +1,31 @@
use std::process::Command;

fn main() {
    const MIN_VERSION: (u32, u32, u32) = (1, 75, 0);

    let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".into());
    let output = Command::new(&rustc)
        .arg("--version")
        .output()
        .expect("failed to invoke rustc");

    let version_line = String::from_utf8_lossy(&output.stdout);
    let version_str = version_line.split_whitespace().nth(1).unwrap_or("0.0.0");
    let sanitized = version_str.split('-').next().unwrap_or(version_str);

    let mut parts = sanitized
        .split('.')
        .map(|part| part.parse::<u32>().unwrap_or(0));
    let current = (
        parts.next().unwrap_or(0),
        parts.next().unwrap_or(0),
        parts.next().unwrap_or(0),
    );

    if current < MIN_VERSION {
        panic!(
            "owlen requires rustc {}.{}.{} or newer (found {version_line})",
            MIN_VERSION.0, MIN_VERSION.1, MIN_VERSION.2
        );
    }
}
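For illustration only, not part of the diff: the same parsing steps applied to a hypothetical `rustc --version` banner, to show what the MSRV check compares.

    fn main() {
        // Hypothetical banner; the real build script reads this from `rustc --version`.
        let version_line = "rustc 1.82.0 (abcdef123 2024-10-15)";
        let version_str = version_line.split_whitespace().nth(1).unwrap_or("0.0.0"); // "1.82.0"
        let sanitized = version_str.split('-').next().unwrap_or(version_str);        // strips "-nightly" etc.
        let mut parts = sanitized.split('.').map(|p| p.parse::<u32>().unwrap_or(0));
        let current = (
            parts.next().unwrap_or(0),
            parts.next().unwrap_or(0),
            parts.next().unwrap_or(0),
        );
        assert_eq!(current, (1, 82, 0)); // >= MIN_VERSION (1, 75, 0), so the build proceeds
    }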
@@ -11,11 +11,15 @@ use std::sync::Arc;
 use clap::Parser;
 use owlen_cli::agent::{AgentConfig, AgentExecutor};
 use owlen_core::mcp::remote_client::RemoteMcpClient;
-use owlen_ollama::OllamaProvider;

 /// Command‑line arguments for the agent binary.
 #[derive(Parser, Debug)]
-#[command(name = "owlen-agent", author, version, about = "Run the ReAct agent")]
+#[command(
+    name = "owlen-agent",
+    author,
+    version,
+    about = "Run the ReAct agent via MCP"
+)]
 struct Args {
     /// The initial user query.
     prompt: String,
@@ -31,11 +35,13 @@ struct Args {
 async fn main() -> anyhow::Result<()> {
     let args = Args::parse();

-    // Initialise the LLM provider (Ollama) – uses default local URL.
-    let provider = Arc::new(OllamaProvider::new("http://localhost:11434")?);
-    // Initialise the MCP client (remote LLM server) – this client also knows how
-    // to call the built‑in resource tools.
-    let mcp_client = Arc::new(RemoteMcpClient::new()?);
+    // Initialise the MCP LLM client – it implements Provider and talks to the
+    // MCP LLM server which wraps Ollama. This ensures all communication goes
+    // through the MCP architecture (Phase 10 requirement).
+    let provider = Arc::new(RemoteMcpClient::new()?);
+
+    // The MCP client also serves as the tool client for resource operations
+    let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;

     let config = AgentConfig {
         max_iterations: args.max_iter,
@@ -43,10 +49,11 @@ async fn main() -> anyhow::Result<()> {
         ..AgentConfig::default()
     };

-    let executor = AgentExecutor::new(provider, mcp_client, config, None);
+    let executor = AgentExecutor::new(provider, mcp_client, config);
     match executor.run(args.prompt).await {
-        Ok(answer) => {
-            println!("\nFinal answer:\n{}", answer);
+        Ok(result) => {
+            println!("\n✓ Agent completed in {} iterations", result.iterations);
+            println!("\nFinal answer:\n{}", result.answer);
             Ok(())
         }
         Err(e) => Err(anyhow::anyhow!(e)),
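Usage sketch, not part of the diff (the flag wired to `max_iter` is declared outside this hunk, and the printed values below are hypothetical): after this change the binary is invoked the same way and reports the iteration count before the answer.

    $ owlen-agent "Which crates in this workspace depend on tokio?"

    ✓ Agent completed in 3 iterations

    Final answer:
    ...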
326 crates/owlen-cli/src/bootstrap.rs Normal file
@@ -0,0 +1,326 @@
use std::borrow::Cow;
use std::io;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use crossterm::{
    event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
    execute,
    terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
};
use futures::stream;
use owlen_core::{
    ChatStream, Error, Provider,
    config::{Config, McpMode},
    mcp::remote_client::RemoteMcpClient,
    mode::Mode,
    provider::ProviderManager,
    providers::OllamaProvider,
    session::{ControllerEvent, SessionController},
    storage::StorageManager,
    types::{ChatRequest, ChatResponse, Message, ModelInfo},
};
use owlen_tui::{
    ChatApp, SessionEvent,
    app::App as RuntimeApp,
    config,
    tui_controller::{TuiController, TuiRequest},
    ui,
};
use ratatui::{Terminal, prelude::CrosstermBackend};
use tokio::sync::mpsc;

use crate::commands::cloud::{load_runtime_credentials, set_env_var};

pub async fn launch(initial_mode: Mode) -> Result<()> {
    set_env_var("OWLEN_AUTO_CONSENT", "1");

    let color_support = detect_terminal_color_support();
    let mut cfg = config::try_load_config().unwrap_or_default();
    let _ = cfg.refresh_mcp_servers(None);

    if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
        let term_label = match &color_support {
            TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
            TerminalColorSupport::Full => Cow::from("current terminal"),
        };
        eprintln!(
            "Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
            term_label, BASIC_THEME_NAME, previous_theme
        );
    } else if let TerminalColorSupport::Limited { term } = &color_support {
        eprintln!(
            "Warning: terminal '{}' may not fully support 256-color themes.",
            term
        );
    }

    cfg.validate()?;
    let storage = Arc::new(StorageManager::new().await?);
    load_runtime_credentials(&mut cfg, storage.clone()).await?;

    let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
    let tui_controller = Arc::new(TuiController::new(tui_tx));

    let provider = build_provider(&cfg)?;
    let mut offline_notice: Option<String> = None;
    let provider = match provider.health_check().await {
        Ok(_) => provider,
        Err(err) => {
            let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
                && !cfg.effective_mcp_servers().is_empty()
            {
                "Ensure the configured MCP server is running and reachable."
            } else {
                "Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
            };
            let notice =
                format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
            eprintln!("{notice}");
            offline_notice = Some(notice.clone());
            let fallback_model = cfg
                .general
                .default_model
                .clone()
                .unwrap_or_else(|| "offline".to_string());
            Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
        }
    };

    let (controller_event_tx, controller_event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
    let controller = SessionController::new(
        provider,
        cfg,
        storage.clone(),
        tui_controller,
        false,
        Some(controller_event_tx),
    )
    .await?;
    let provider_manager = Arc::new(ProviderManager::default());
    let mut runtime = RuntimeApp::new(provider_manager);
    let (mut app, mut session_rx) = ChatApp::new(controller, controller_event_rx).await?;
    app.initialize_models().await?;
    if let Some(notice) = offline_notice.clone() {
        app.set_status_message(&notice);
        app.set_system_status(notice);
    }

    app.set_mode(initial_mode).await;

    enable_raw_mode()?;
    let mut stdout = io::stdout();
    execute!(
        stdout,
        EnterAlternateScreen,
        EnableMouseCapture,
        EnableBracketedPaste
    )?;
    let backend = CrosstermBackend::new(stdout);
    let mut terminal = Terminal::new(backend)?;

    let result = run_app(&mut terminal, &mut runtime, &mut app, &mut session_rx).await;

    config::save_config(&app.config())?;

    disable_raw_mode()?;
    execute!(
        terminal.backend_mut(),
        LeaveAlternateScreen,
        DisableMouseCapture,
        DisableBracketedPaste
    )?;
    terminal.show_cursor()?;

    if let Err(err) = result {
        println!("{err:?}");
    }

    Ok(())
}

fn build_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
    match cfg.mcp.mode {
        McpMode::RemotePreferred => {
            let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
                RemoteMcpClient::new_with_config(mcp_server)
            } else {
                RemoteMcpClient::new()
            };

            match remote_result {
                Ok(client) => Ok(Arc::new(client) as Arc<dyn Provider>),
                Err(err) if cfg.mcp.allow_fallback => {
                    log::warn!(
                        "Remote MCP client unavailable ({}); falling back to local provider.",
                        err
                    );
                    build_local_provider(cfg)
                }
                Err(err) => Err(anyhow!(err)),
            }
        }
        McpMode::RemoteOnly => {
            let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
                anyhow!("[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"")
            })?;
            let client = RemoteMcpClient::new_with_config(mcp_server)?;
            Ok(Arc::new(client) as Arc<dyn Provider>)
        }
        McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
        McpMode::Disabled => Err(anyhow!(
            "MCP mode 'disabled' is not supported by the owlen TUI"
        )),
    }
}

fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
    let provider_name = cfg.general.default_provider.clone();
    let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
        anyhow!(format!(
            "No provider configuration found for '{provider_name}' in [providers]"
        ))
    })?;

    match provider_cfg.provider_type.as_str() {
        "ollama" | "ollama_cloud" => {
            let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
            Ok(Arc::new(provider) as Arc<dyn Provider>)
        }
        other => Err(anyhow!(format!(
            "Provider type '{other}' is not supported in legacy/local MCP mode"
        ))),
    }
}

const BASIC_THEME_NAME: &str = "ansi_basic";

#[derive(Debug, Clone)]
enum TerminalColorSupport {
    Full,
    Limited { term: String },
}

fn detect_terminal_color_support() -> TerminalColorSupport {
    let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
    let colorterm = std::env::var("COLORTERM").unwrap_or_default();
    let term_lower = term.to_lowercase();
    let color_lower = colorterm.to_lowercase();

    let supports_extended = term_lower.contains("256color")
        || color_lower.contains("truecolor")
        || color_lower.contains("24bit")
        || color_lower.contains("fullcolor");

    if supports_extended {
        TerminalColorSupport::Full
    } else {
        TerminalColorSupport::Limited { term }
    }
}

fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
    match support {
        TerminalColorSupport::Full => None,
        TerminalColorSupport::Limited { .. } => {
            if cfg.ui.theme != BASIC_THEME_NAME {
                let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
                Some(previous)
            } else {
                None
            }
        }
    }
}

struct OfflineProvider {
    reason: String,
    placeholder_model: String,
}

impl OfflineProvider {
    fn new(reason: String, placeholder_model: String) -> Self {
        Self {
            reason,
            placeholder_model,
        }
    }

    fn friendly_response(&self, requested_model: &str) -> ChatResponse {
        let mut message = String::new();
        message.push_str("⚠️ Owlen is running in offline mode.\n\n");
        message.push_str(&self.reason);
        if !requested_model.is_empty() && requested_model != self.placeholder_model {
            message.push_str(&format!(
                "\n\nYou requested model '{}', but no providers are reachable.",
                requested_model
            ));
        }
        message.push_str(
            "\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
        );

        ChatResponse {
            message: Message::assistant(message),
            usage: None,
            is_streaming: false,
            is_final: true,
        }
    }
}

#[async_trait]
impl Provider for OfflineProvider {
    fn name(&self) -> &str {
        "offline"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
        Ok(vec![ModelInfo {
            id: self.placeholder_model.clone(),
            provider: "offline".to_string(),
            name: format!("Offline (fallback: {})", self.placeholder_model),
            description: Some("Placeholder model used while no providers are reachable".into()),
            context_window: None,
            capabilities: vec![],
            supports_tools: false,
        }])
    }

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
        Ok(self.friendly_response(&request.model))
    }

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
        let response = self.friendly_response(&request.model);
        Ok(Box::pin(stream::iter(vec![Ok(response)])))
    }

    async fn health_check(&self) -> Result<(), Error> {
        Err(Error::Provider(anyhow!(
            "offline provider cannot reach any backing models"
        )))
    }

    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
        self
    }
}

async fn run_app(
    terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
    runtime: &mut RuntimeApp,
    app: &mut ChatApp,
    session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
) -> Result<()> {
    let mut render = |terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
                      state: &mut ChatApp|
     -> Result<()> {
        terminal.draw(|f| ui::render_chat(f, state))?;
        Ok(())
    };

    runtime.run(terminal, app, session_rx, &mut render).await?;
    Ok(())
}
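Configuration sketch, not part of the diff: the settings consumed by build_provider() above. Only the "remote_only" spelling and the `[[mcp_servers]]` table name are confirmed by the error message in the code; the other mode strings are assumed snake_case forms of the McpMode variants, and the server table's fields are not shown in this hunk.

    # Hypothetical ~/.config/owlen config fragment
    [mcp]
    mode = "remote_preferred"   # try RemoteMcpClient first...
    allow_fallback = true       # ...and fall back to the local Ollama provider on failure

    [[mcp_servers]]
    # connection details passed to RemoteMcpClient::new_with_config(...)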
@@ -1,143 +1,16 @@
|
|||||||
//! OWLEN Code Mode - TUI client optimized for coding assistance
|
//! Owlen CLI entrypoint optimised for code-first workflows.
|
||||||
|
#![allow(dead_code, unused_imports)]
|
||||||
|
|
||||||
|
mod bootstrap;
|
||||||
|
mod commands;
|
||||||
|
mod mcp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::{Arg, Command};
|
use owlen_core::config as core_config;
|
||||||
use owlen_core::{session::SessionController, storage::StorageManager};
|
use owlen_core::mode::Mode;
|
||||||
use owlen_ollama::OllamaProvider;
|
use owlen_tui::config;
|
||||||
use owlen_tui::{config, ui, AppState, CodeApp, Event, EventHandler, SessionEvent};
|
|
||||||
use std::io;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
|
|
||||||
use crossterm::{
|
|
||||||
event::{DisableMouseCapture, EnableMouseCapture},
|
|
||||||
execute,
|
|
||||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
|
||||||
};
|
|
||||||
use ratatui::{backend::CrosstermBackend, Terminal};
|
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
let matches = Command::new("owlen-code")
|
bootstrap::launch(Mode::Code).await
|
||||||
.about("OWLEN Code Mode - TUI optimized for programming assistance")
|
|
||||||
.version(env!("CARGO_PKG_VERSION"))
|
|
||||||
.arg(
|
|
||||||
Arg::new("model")
|
|
||||||
.short('m')
|
|
||||||
.long("model")
|
|
||||||
.value_name("MODEL")
|
|
||||||
.help("Preferred model to use for this session"),
|
|
||||||
)
|
|
||||||
.get_matches();
|
|
||||||
|
|
||||||
let mut config = config::try_load_config().unwrap_or_default();
|
|
||||||
// Disable encryption for code mode.
|
|
||||||
config.privacy.encrypt_local_data = false;
|
|
||||||
|
|
||||||
if let Some(model) = matches.get_one::<String>("model") {
|
|
||||||
config.general.default_model = Some(model.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let provider_name = config.general.default_provider.clone();
|
|
||||||
let provider_cfg = config::ensure_provider_config(&mut config, &provider_name).clone();
|
|
||||||
|
|
||||||
let provider_type = provider_cfg.provider_type.to_ascii_lowercase();
|
|
||||||
if provider_type != "ollama" && provider_type != "ollama-cloud" {
|
|
||||||
anyhow::bail!(
|
|
||||||
"Unsupported provider type '{}' configured for provider '{}'",
|
|
||||||
provider_cfg.provider_type,
|
|
||||||
provider_name
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let provider = Arc::new(OllamaProvider::from_config(
|
|
||||||
&provider_cfg,
|
|
||||||
Some(&config.general),
|
|
||||||
)?);
|
|
||||||
|
|
||||||
let storage = Arc::new(StorageManager::new().await?);
|
|
||||||
// Code client - code execution tools enabled
|
|
||||||
use owlen_core::ui::NoOpUiController;
|
|
||||||
let controller = SessionController::new(
|
|
||||||
provider,
|
|
||||||
config.clone(),
|
|
||||||
storage.clone(),
|
|
||||||
Arc::new(NoOpUiController),
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
let (mut app, mut session_rx) = CodeApp::new(controller).await?;
|
|
||||||
app.inner_mut().initialize_models().await?;
|
|
||||||
|
|
||||||
let cancellation_token = CancellationToken::new();
|
|
||||||
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
|
||||||
let event_handler = EventHandler::new(event_tx, cancellation_token.clone());
|
|
||||||
let event_handle = tokio::spawn(async move { event_handler.run().await });
|
|
||||||
|
|
||||||
enable_raw_mode()?;
|
|
||||||
let mut stdout = io::stdout();
|
|
||||||
execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?;
|
|
||||||
let backend = CrosstermBackend::new(stdout);
|
|
||||||
let mut terminal = Terminal::new(backend)?;
|
|
||||||
|
|
||||||
let result = run_app(&mut terminal, &mut app, event_rx, &mut session_rx).await;
|
|
||||||
|
|
||||||
cancellation_token.cancel();
|
|
||||||
event_handle.await?;
|
|
||||||
|
|
||||||
config::save_config(&app.inner().config())?;
|
|
||||||
|
|
||||||
disable_raw_mode()?;
|
|
||||||
execute!(
|
|
||||||
terminal.backend_mut(),
|
|
||||||
LeaveAlternateScreen,
|
|
||||||
DisableMouseCapture
|
|
||||||
)?;
|
|
||||||
terminal.show_cursor()?;
|
|
||||||
|
|
||||||
if let Err(err) = result {
|
|
||||||
println!("{err:?}");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_app(
|
|
||||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
|
||||||
app: &mut CodeApp,
|
|
||||||
mut event_rx: mpsc::UnboundedReceiver<Event>,
|
|
||||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
|
||||||
) -> Result<()> {
|
|
||||||
loop {
|
|
||||||
// Advance loading animation frame
|
|
||||||
app.inner_mut().advance_loading_animation();
|
|
||||||
|
|
||||||
terminal.draw(|f| ui::render_chat(f, app.inner_mut()))?;
|
|
||||||
|
|
||||||
// Process any pending LLM requests AFTER UI has been drawn
|
|
||||||
if let Err(e) = app.inner_mut().process_pending_llm_request().await {
|
|
||||||
eprintln!("Error processing LLM request: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process any pending tool executions AFTER UI has been drawn
|
|
||||||
if let Err(e) = app.inner_mut().process_pending_tool_execution().await {
|
|
||||||
eprintln!("Error processing tool execution: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
Some(event) = event_rx.recv() => {
|
|
||||||
if let AppState::Quit = app.handle_event(event).await? {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(session_event) = session_rx.recv() => {
|
|
||||||
app.handle_session_event(session_event)?;
|
|
||||||
}
|
|
||||||
// Add a timeout to keep the animation going even when there are no events
|
|
||||||
_ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
|
|
||||||
// This will cause the loop to continue and advance the animation
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
479 crates/owlen-cli/src/commands/cloud.rs Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
use std::ffi::OsStr;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
|
use clap::Subcommand;
|
||||||
|
use owlen_core::LlmProvider;
|
||||||
|
use owlen_core::ProviderConfig;
|
||||||
|
use owlen_core::config::{
|
||||||
|
self as core_config, Config, OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
|
||||||
|
OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
|
||||||
|
};
|
||||||
|
use owlen_core::credentials::{ApiCredentials, CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
|
||||||
|
use owlen_core::encryption;
|
||||||
|
use owlen_core::providers::OllamaProvider;
|
||||||
|
use owlen_core::storage::StorageManager;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
const DEFAULT_CLOUD_ENDPOINT: &str = OLLAMA_CLOUD_BASE_URL;
|
||||||
|
const CLOUD_ENDPOINT_KEY: &str = OLLAMA_CLOUD_ENDPOINT_KEY;
|
||||||
|
const CLOUD_PROVIDER_KEY: &str = "ollama_cloud";
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
pub enum CloudCommand {
|
||||||
|
/// Configure Ollama Cloud credentials
|
||||||
|
Setup {
|
||||||
|
/// API key passed directly on the command line (prompted when omitted)
|
||||||
|
#[arg(long)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
/// Override the cloud endpoint (default: https://ollama.com)
|
||||||
|
#[arg(long)]
|
||||||
|
endpoint: Option<String>,
|
||||||
|
/// Provider name to configure (default: ollama_cloud)
|
||||||
|
#[arg(long, default_value = "ollama_cloud")]
|
||||||
|
provider: String,
|
||||||
|
/// Overwrite the provider base URL with the cloud endpoint
|
||||||
|
#[arg(long)]
|
||||||
|
force_cloud_base_url: bool,
|
||||||
|
},
|
||||||
|
/// Check connectivity to Ollama Cloud
|
||||||
|
Status {
|
||||||
|
/// Provider name to check (default: ollama_cloud)
|
||||||
|
#[arg(long, default_value = "ollama_cloud")]
|
||||||
|
provider: String,
|
||||||
|
},
|
||||||
|
/// List available cloud-hosted models
|
||||||
|
Models {
|
||||||
|
/// Provider name to query (default: ollama_cloud)
|
||||||
|
#[arg(long, default_value = "ollama_cloud")]
|
||||||
|
provider: String,
|
||||||
|
},
|
||||||
|
/// Remove stored Ollama Cloud credentials
|
||||||
|
Logout {
|
||||||
|
/// Provider name to clear (default: ollama_cloud)
|
||||||
|
#[arg(long, default_value = "ollama_cloud")]
|
||||||
|
provider: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_cloud_command(command: CloudCommand) -> Result<()> {
|
||||||
|
match command {
|
||||||
|
CloudCommand::Setup {
|
||||||
|
api_key,
|
||||||
|
endpoint,
|
||||||
|
provider,
|
||||||
|
force_cloud_base_url,
|
||||||
|
} => setup(provider, api_key, endpoint, force_cloud_base_url).await,
|
||||||
|
CloudCommand::Status { provider } => status(provider).await,
|
||||||
|
CloudCommand::Models { provider } => models(provider).await,
|
||||||
|
CloudCommand::Logout { provider } => logout(provider).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn setup(
|
||||||
|
provider: String,
|
||||||
|
api_key: Option<String>,
|
||||||
|
endpoint: Option<String>,
|
||||||
|
force_cloud_base_url: bool,
|
||||||
|
) -> Result<()> {
|
||||||
|
let provider = canonical_provider_name(&provider);
|
||||||
|
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||||
|
let endpoint =
|
||||||
|
normalize_endpoint(&endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string()));
|
||||||
|
|
||||||
|
let base_changed = {
|
||||||
|
let entry = ensure_provider_entry(&mut config, &provider);
|
||||||
|
entry.enabled = true;
|
||||||
|
configure_cloud_endpoint(entry, &endpoint, force_cloud_base_url)
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = match api_key {
|
||||||
|
Some(value) if !value.trim().is_empty() => value,
|
||||||
|
_ => {
|
||||||
|
let prompt = format!("Enter API key for {provider}: ");
|
||||||
|
encryption::prompt_password(&prompt)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if config.privacy.encrypt_local_data {
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
let manager = unlock_credential_manager(&config, storage.clone())?;
|
||||||
|
let credentials = ApiCredentials {
|
||||||
|
api_key: key.clone(),
|
||||||
|
endpoint: endpoint.clone(),
|
||||||
|
};
|
||||||
|
manager
|
||||||
|
.store_credentials(OLLAMA_CLOUD_CREDENTIAL_ID, &credentials)
|
||||||
|
.await?;
|
||||||
|
// Ensure plaintext key is not persisted to disk.
|
||||||
|
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||||
|
entry.api_key = None;
|
||||||
|
}
|
||||||
|
} else if let Some(entry) = config.providers.get_mut(&provider) {
|
||||||
|
entry.api_key = Some(key.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::config::save_config(&config)?;
|
||||||
|
println!("Saved Ollama configuration for provider '{provider}'.");
|
||||||
|
if config.privacy.encrypt_local_data {
|
||||||
|
println!("API key stored securely in the encrypted credential vault.");
|
||||||
|
} else {
|
||||||
|
println!("API key stored in plaintext configuration (encryption disabled).");
|
||||||
|
}
|
||||||
|
if !force_cloud_base_url && !base_changed {
|
||||||
|
println!(
|
||||||
|
"Local base URL preserved; cloud endpoint stored as {}.",
|
||||||
|
CLOUD_ENDPOINT_KEY
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn status(provider: String) -> Result<()> {
|
||||||
|
let provider = canonical_provider_name(&provider);
|
||||||
|
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
let manager = if config.privacy.encrypt_local_data {
|
||||||
|
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let api_key = hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||||
|
{
|
||||||
|
let entry = ensure_provider_entry(&mut config, &provider);
|
||||||
|
entry.enabled = true;
|
||||||
|
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider_cfg = config
|
||||||
|
.provider(&provider)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||||
|
|
||||||
|
let endpoint =
|
||||||
|
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||||
|
let mut runtime_cfg = provider_cfg.clone();
|
||||||
|
runtime_cfg.base_url = Some(endpoint.clone());
|
||||||
|
runtime_cfg.extra.insert(
|
||||||
|
OLLAMA_MODE_KEY.to_string(),
|
||||||
|
Value::String("cloud".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||||
|
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||||
|
|
||||||
|
match ollama.health_check().await {
|
||||||
|
Ok(_) => {
|
||||||
|
println!("✓ Connected to {provider} ({})", endpoint);
|
||||||
|
if api_key.is_none() && config.privacy.encrypt_local_data {
|
||||||
|
println!(
|
||||||
|
"Warning: No API key stored; connection succeeded via environment variables."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
println!("✗ Failed to reach {provider}: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn models(provider: String) -> Result<()> {
|
||||||
|
let provider = canonical_provider_name(&provider);
|
||||||
|
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
let manager = if config.privacy.encrypt_local_data {
|
||||||
|
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||||
|
|
||||||
|
{
|
||||||
|
let entry = ensure_provider_entry(&mut config, &provider);
|
||||||
|
entry.enabled = true;
|
||||||
|
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider_cfg = config
|
||||||
|
.provider(&provider)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||||
|
|
||||||
|
let endpoint =
|
||||||
|
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||||
|
let mut runtime_cfg = provider_cfg.clone();
|
||||||
|
runtime_cfg.base_url = Some(endpoint);
|
||||||
|
runtime_cfg.extra.insert(
|
||||||
|
OLLAMA_MODE_KEY.to_string(),
|
||||||
|
Value::String("cloud".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||||
|
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||||
|
|
||||||
|
match ollama.list_models().await {
|
||||||
|
Ok(models) => {
|
||||||
|
if models.is_empty() {
|
||||||
|
println!("No cloud models reported by '{}'.", provider);
|
||||||
|
} else {
|
||||||
|
println!("Models available via '{}':", provider);
|
||||||
|
for model in models {
|
||||||
|
if let Some(description) = &model.description {
|
||||||
|
println!(" - {} ({})", model.id, description);
|
||||||
|
} else {
|
||||||
|
println!(" - {}", model.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
bail!("Failed to list models: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn logout(provider: String) -> Result<()> {
|
||||||
|
let provider = canonical_provider_name(&provider);
|
||||||
|
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
|
||||||
|
if config.privacy.encrypt_local_data {
|
||||||
|
let manager = unlock_credential_manager(&config, storage.clone())?;
|
||||||
|
manager
|
||||||
|
.delete_credentials(OLLAMA_CLOUD_CREDENTIAL_ID)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||||
|
entry.api_key = None;
|
||||||
|
entry.enabled = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::config::save_config(&config)?;
|
||||||
|
println!("Cleared credentials for provider '{provider}'.");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_provider_entry<'a>(config: &'a mut Config, provider: &str) -> &'a mut ProviderConfig {
|
||||||
|
core_config::ensure_provider_config_mut(config, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_cloud_endpoint(entry: &mut ProviderConfig, endpoint: &str, force: bool) -> bool {
|
||||||
|
let normalized = normalize_endpoint(endpoint);
|
||||||
|
let previous_base = entry.base_url.clone();
|
||||||
|
entry.extra.insert(
|
||||||
|
CLOUD_ENDPOINT_KEY.to_string(),
|
||||||
|
Value::String(normalized.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
if entry.api_key_env.is_none() {
|
||||||
|
entry.api_key_env = Some(OLLAMA_CLOUD_API_KEY_ENV.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if force
|
||||||
|
|| entry
|
||||||
|
.base_url
|
||||||
|
.as_ref()
|
||||||
|
.map(|value| value.trim().is_empty())
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
entry.base_url = Some(normalized.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if force {
|
||||||
|
entry.enabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.base_url != previous_base
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_cloud_endpoint(cfg: &ProviderConfig) -> Option<String> {
|
||||||
|
if let Some(value) = cfg
|
||||||
|
.extra
|
||||||
|
.get(CLOUD_ENDPOINT_KEY)
|
||||||
|
.and_then(|value| value.as_str())
|
||||||
|
.map(normalize_endpoint)
|
||||||
|
{
|
||||||
|
return Some(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.base_url
|
||||||
|
.as_ref()
|
||||||
|
.map(|value| value.trim_end_matches('/').to_string())
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_endpoint(endpoint: &str) -> String {
|
||||||
|
let trimmed = endpoint.trim().trim_end_matches('/');
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
DEFAULT_CLOUD_ENDPOINT.to_string()
|
||||||
|
} else {
|
||||||
|
trimmed.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn canonical_provider_name(provider: &str) -> String {
|
||||||
|
let normalized = provider.trim().to_ascii_lowercase().replace('-', "_");
|
||||||
|
match normalized.as_str() {
|
||||||
|
"" => CLOUD_PROVIDER_KEY.to_string(),
|
||||||
|
"ollama" => CLOUD_PROVIDER_KEY.to_string(),
|
||||||
|
"ollama_cloud" => CLOUD_PROVIDER_KEY.to_string(),
|
||||||
|
value => value.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_env_var<K, V>(key: K, value: V)
|
||||||
|
where
|
||||||
|
K: AsRef<OsStr>,
|
||||||
|
V: AsRef<OsStr>,
|
||||||
|
{
|
||||||
|
// Safety: the CLI updates process-wide environment variables during startup while no
|
||||||
|
// other threads are mutating the environment.
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_env_if_missing(var: &str, value: &str) {
|
||||||
|
if std::env::var(var)
|
||||||
|
.map(|v| v.trim().is_empty())
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
set_env_var(var, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unlock_credential_manager(
|
||||||
|
config: &Config,
|
||||||
|
storage: Arc<StorageManager>,
|
||||||
|
) -> Result<Arc<CredentialManager>> {
|
||||||
|
if !config.privacy.encrypt_local_data {
|
||||||
|
bail!("Credential manager requested but encryption is disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
let secure_path = vault_path(&storage)?;
|
||||||
|
let handle = unlock_vault(&secure_path)?;
|
||||||
|
let master_key = Arc::new(handle.data.master_key.clone());
|
||||||
|
Ok(Arc::new(CredentialManager::new(
|
||||||
|
storage,
|
||||||
|
master_key.clone(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vault_path(storage: &StorageManager) -> Result<PathBuf> {
|
||||||
|
let base_dir = storage
|
||||||
|
.database_path()
|
||||||
|
.parent()
|
||||||
|
.map(|p| p.to_path_buf())
|
||||||
|
.or_else(dirs::data_local_dir)
|
||||||
|
.unwrap_or_else(|| PathBuf::from("."));
|
||||||
|
Ok(base_dir.join("encrypted_data.json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
if path.exists() {
|
||||||
|
if let Some(password) = env::var("OWLEN_MASTER_PASSWORD")
|
||||||
|
.ok()
|
||||||
|
.map(|value| value.trim().to_string())
|
||||||
|
.filter(|password| !password.is_empty())
|
||||||
|
{
|
||||||
|
return encryption::unlock_with_password(path.to_path_buf(), &password)
|
||||||
|
.context("Failed to unlock vault with OWLEN_MASTER_PASSWORD");
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt in 0..3 {
|
||||||
|
let password = encryption::prompt_password("Enter master password: ")?;
|
||||||
|
match encryption::unlock_with_password(path.to_path_buf(), &password) {
|
||||||
|
Ok(handle) => {
|
||||||
|
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||||
|
return Ok(handle);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("Failed to unlock vault: {err}");
|
||||||
|
if attempt == 2 {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Unable to unlock encrypted credential vault");
|
||||||
|
}
|
||||||
|
|
||||||
|
let handle = encryption::unlock_interactive(path.to_path_buf())?;
|
||||||
|
if env::var("OWLEN_MASTER_PASSWORD")
|
||||||
|
.map(|v| v.trim().is_empty())
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
let password = encryption::prompt_password("Cache master password for this session: ")?;
|
||||||
|
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||||
|
}
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn hydrate_api_key(
|
||||||
|
config: &mut Config,
|
||||||
|
manager: Option<&Arc<CredentialManager>>,
|
||||||
|
) -> Result<Option<String>> {
|
||||||
|
let credentials = match manager {
|
||||||
|
Some(manager) => manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?,
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(credentials) = credentials {
|
||||||
|
let key = credentials.api_key.trim().to_string();
|
||||||
|
if !key.is_empty() {
|
||||||
|
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||||
|
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cfg = core_config::ensure_provider_config_mut(config, CLOUD_PROVIDER_KEY);
|
||||||
|
configure_cloud_endpoint(cfg, &credentials.endpoint, false);
|
||||||
|
return Ok(Some(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(key) = config
|
||||||
|
.provider(CLOUD_PROVIDER_KEY)
|
||||||
|
.and_then(|cfg| cfg.api_key.as_ref())
|
||||||
|
.map(|value| value.trim())
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
{
|
||||||
|
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||||
|
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||||
|
return Ok(Some(key.to_string()));
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn load_runtime_credentials(
|
||||||
|
config: &mut Config,
|
||||||
|
storage: Arc<StorageManager>,
|
||||||
|
) -> Result<()> {
|
||||||
|
if config.privacy.encrypt_local_data {
|
||||||
|
let manager = unlock_credential_manager(config, storage.clone())?;
|
||||||
|
hydrate_api_key(config, Some(&manager)).await?;
|
||||||
|
} else {
|
||||||
|
hydrate_api_key(config, None).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn canonicalises_provider_names() {
|
||||||
|
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), CLOUD_PROVIDER_KEY);
|
||||||
|
assert_eq!(canonical_provider_name(" ollama-cloud"), CLOUD_PROVIDER_KEY);
|
||||||
|
assert_eq!(canonical_provider_name(""), CLOUD_PROVIDER_KEY);
|
||||||
|
}
|
||||||
|
}
|
||||||
4 crates/owlen-cli/src/commands/mod.rs Normal file
@@ -0,0 +1,4 @@
//! Command implementations for the `owlen` CLI.

pub mod cloud;
pub mod providers;
651 crates/owlen-cli/src/commands/providers.rs Normal file
@@ -0,0 +1,651 @@
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use clap::{Args, Subcommand};
use owlen_core::ProviderConfig;
use owlen_core::config::{self as core_config, Config};
use owlen_core::provider::{
    AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
};
use owlen_core::storage::StorageManager;
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
use owlen_tui::config as tui_config;

use super::cloud;

/// CLI subcommands for provider management.
#[derive(Debug, Subcommand)]
pub enum ProvidersCommand {
    /// List configured providers and their metadata.
    List,
    /// Run health checks against providers.
    Status {
        /// Optional provider identifier to check.
        #[arg(value_name = "PROVIDER")]
        provider: Option<String>,
    },
    /// Enable a provider in the configuration.
    Enable {
        /// Provider identifier to enable.
        provider: String,
    },
    /// Disable a provider in the configuration.
    Disable {
        /// Provider identifier to disable.
        provider: String,
    },
}

/// Arguments for the `owlen models` command.
#[derive(Debug, Default, Args)]
pub struct ModelsArgs {
    /// Restrict output to a specific provider.
    #[arg(long)]
    pub provider: Option<String>,
}

pub async fn run_providers_command(command: ProvidersCommand) -> Result<()> {
    match command {
        ProvidersCommand::List => list_providers(),
        ProvidersCommand::Status { provider } => status_providers(provider.as_deref()).await,
        ProvidersCommand::Enable { provider } => toggle_provider(&provider, true),
        ProvidersCommand::Disable { provider } => toggle_provider(&provider, false),
    }
}

pub async fn run_models_command(args: ModelsArgs) -> Result<()> {
    list_models(args.provider.as_deref()).await
}

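The two `run_*` entry points above are thin async wrappers; a caller that has already parsed a subcommand can drive them directly. An illustrative sketch only (the real wiring through clap lives in main.rs):

    // Enable the cloud provider, then list only its models.
    async fn example_dispatch() -> Result<()> {
        run_providers_command(ProvidersCommand::Enable {
            provider: "ollama_cloud".to_string(),
        })
        .await?;

        run_models_command(ModelsArgs {
            provider: Some("ollama_cloud".to_string()),
        })
        .await
    }
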
fn list_providers() -> Result<()> {
|
||||||
|
let config = tui_config::try_load_config().unwrap_or_default();
|
||||||
|
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||||
|
|
||||||
|
let mut rows = Vec::new();
|
||||||
|
for (id, cfg) in &config.providers {
|
||||||
|
let type_label = describe_provider_type(id, cfg);
|
||||||
|
let auth_label = describe_auth(cfg, requires_auth(id, cfg));
|
||||||
|
let enabled = if cfg.enabled { "yes" } else { "no" };
|
||||||
|
let default = if id == &default_provider { "*" } else { "" };
|
||||||
|
let base = cfg
|
||||||
|
.base_url
|
||||||
|
.as_ref()
|
||||||
|
.map(|value| value.trim().to_string())
|
||||||
|
.unwrap_or_else(|| "-".to_string());
|
||||||
|
|
||||||
|
rows.push(ProviderListRow {
|
||||||
|
id: id.to_string(),
|
||||||
|
type_label,
|
||||||
|
enabled: enabled.to_string(),
|
||||||
|
default: default.to_string(),
|
||||||
|
auth: auth_label,
|
||||||
|
base_url: base,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||||
|
|
||||||
|
let id_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.id.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(8)
|
||||||
|
.max("Provider".len());
|
||||||
|
let enabled_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.enabled.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(7)
|
||||||
|
.max("Enabled".len());
|
||||||
|
let default_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.default.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(7)
|
||||||
|
.max("Default".len());
|
||||||
|
let type_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.type_label.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(4)
|
||||||
|
.max("Type".len());
|
||||||
|
let auth_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.auth.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(4)
|
||||||
|
.max("Auth".len());
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} Base URL",
|
||||||
|
"Provider",
|
||||||
|
"Enabled",
|
||||||
|
"Default",
|
||||||
|
"Type",
|
||||||
|
"Auth",
|
||||||
|
id_width = id_width,
|
||||||
|
enabled_width = enabled_width,
|
||||||
|
default_width = default_width,
|
||||||
|
type_width = type_width,
|
||||||
|
auth_width = auth_width,
|
||||||
|
);
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
println!(
|
||||||
|
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} {}",
|
||||||
|
row.id,
|
||||||
|
row.enabled,
|
||||||
|
row.default,
|
||||||
|
row.type_label,
|
||||||
|
row.auth,
|
||||||
|
row.base_url,
|
||||||
|
id_width = id_width,
|
||||||
|
enabled_width = enabled_width,
|
||||||
|
default_width = default_width,
|
||||||
|
type_width = type_width,
|
||||||
|
auth_width = auth_width,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn status_providers(filter: Option<&str>) -> Result<()> {
|
||||||
|
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||||
|
let filter = filter.map(canonical_provider_id);
|
||||||
|
verify_provider_filter(&config, filter.as_deref())?;
|
||||||
|
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||||
|
|
||||||
|
let manager = ProviderManager::new(&config);
|
||||||
|
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||||
|
let health = manager.refresh_health().await;
|
||||||
|
|
||||||
|
let mut rows = Vec::new();
|
||||||
|
for record in records {
|
||||||
|
let status = health.get(&record.id).copied();
|
||||||
|
rows.push(ProviderStatusRow::from_record(record, status));
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||||
|
print_status_rows(&rows);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(filter: Option<&str>) -> Result<()> {
|
||||||
|
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||||
|
let filter = filter.map(canonical_provider_id);
|
||||||
|
verify_provider_filter(&config, filter.as_deref())?;
|
||||||
|
|
||||||
|
let storage = Arc::new(StorageManager::new().await?);
|
||||||
|
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||||
|
|
||||||
|
let manager = ProviderManager::new(&config);
|
||||||
|
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||||
|
let models = manager
|
||||||
|
.list_all_models()
|
||||||
|
.await
|
||||||
|
.map_err(|err| anyhow!(err))?;
|
||||||
|
let statuses = manager.provider_statuses().await;
|
||||||
|
|
||||||
|
print_models(records, models, statuses);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_provider_filter(config: &Config, filter: Option<&str>) -> Result<()> {
|
||||||
|
if let Some(filter) = filter
|
||||||
|
&& !config.providers.contains_key(filter)
|
||||||
|
{
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Provider '{}' is not defined in configuration.",
|
||||||
|
filter
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn toggle_provider(provider: &str, enable: bool) -> Result<()> {
|
||||||
|
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||||
|
let canonical = canonical_provider_id(provider);
|
||||||
|
if canonical.is_empty() {
|
||||||
|
return Err(anyhow!("Provider name cannot be empty."));
|
||||||
|
}
|
||||||
|
|
||||||
|
let previous_default = config.general.default_provider.clone();
|
||||||
|
let previous_fallback_enabled = config.providers.get("ollama_local").map(|cfg| cfg.enabled);
|
||||||
|
|
||||||
|
let previous_enabled;
|
||||||
|
{
|
||||||
|
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||||
|
previous_enabled = entry.enabled;
|
||||||
|
if previous_enabled == enable {
|
||||||
|
println!(
|
||||||
|
"Provider '{}' is already {}.",
|
||||||
|
canonical,
|
||||||
|
if enable { "enabled" } else { "disabled" }
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
entry.enabled = enable;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enable && config.general.default_provider == canonical {
|
||||||
|
if let Some(candidate) = choose_fallback_provider(&config, &canonical) {
|
||||||
|
config.general.default_provider = candidate.clone();
|
||||||
|
println!(
|
||||||
|
"Default provider set to '{}' because '{}' was disabled.",
|
||||||
|
candidate, canonical
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||||
|
entry.enabled = true;
|
||||||
|
config.general.default_provider = "ollama_local".to_string();
|
||||||
|
println!(
|
||||||
|
"Enabled 'ollama_local' and made it default because no other providers are active."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = config.validate() {
|
||||||
|
{
|
||||||
|
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||||
|
entry.enabled = previous_enabled;
|
||||||
|
}
|
||||||
|
config.general.default_provider = previous_default;
|
||||||
|
if let Some(enabled) = previous_fallback_enabled
|
||||||
|
&& let Some(entry) = config.providers.get_mut("ollama_local")
|
||||||
|
{
|
||||||
|
entry.enabled = enabled;
|
||||||
|
}
|
||||||
|
return Err(anyhow!(err));
|
||||||
|
}
|
||||||
|
|
||||||
|
tui_config::save_config(&config).map_err(|err| anyhow!(err))?;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{} provider '{}'.",
|
||||||
|
if enable { "Enabled" } else { "Disabled" },
|
||||||
|
canonical
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_fallback_provider(config: &Config, exclude: &str) -> Option<String> {
|
||||||
|
if exclude != "ollama_local"
|
||||||
|
&& let Some(cfg) = config.providers.get("ollama_local")
|
||||||
|
&& cfg.enabled
|
||||||
|
{
|
||||||
|
return Some("ollama_local".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut candidates: Vec<String> = config
|
||||||
|
.providers
|
||||||
|
.iter()
|
||||||
|
.filter(|(id, cfg)| cfg.enabled && id.as_str() != exclude)
|
||||||
|
.map(|(id, _)| id.clone())
|
||||||
|
.collect();
|
||||||
|
candidates.sort();
|
||||||
|
candidates.into_iter().next()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register_enabled_providers(
|
||||||
|
manager: &ProviderManager,
|
||||||
|
config: &Config,
|
||||||
|
filter: Option<&str>,
|
||||||
|
) -> Result<Vec<ProviderRecord>> {
|
||||||
|
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||||
|
let mut records = Vec::new();
|
||||||
|
|
||||||
|
for (id, cfg) in &config.providers {
|
||||||
|
if let Some(filter) = filter
|
||||||
|
&& id != filter
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut record = ProviderRecord::from_config(id, cfg, id == &default_provider);
|
||||||
|
if !cfg.enabled {
|
||||||
|
records.push(record);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match instantiate_provider(id, cfg) {
|
||||||
|
Ok(provider) => {
|
||||||
|
let metadata = provider.metadata().clone();
|
||||||
|
record.provider_type_label = provider_type_label(metadata.provider_type);
|
||||||
|
record.requires_auth = metadata.requires_auth;
|
||||||
|
record.metadata = Some(metadata);
|
||||||
|
manager.register_provider(provider).await;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
record.registration_error = Some(err.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
records.push(record);
|
||||||
|
}
|
||||||
|
|
||||||
|
records.sort_by(|a, b| a.id.cmp(&b.id));
|
||||||
|
Ok(records)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn instantiate_provider(id: &str, cfg: &ProviderConfig) -> Result<Arc<dyn ModelProvider>> {
|
||||||
|
let kind = cfg.provider_type.trim().to_ascii_lowercase();
|
||||||
|
if kind == "ollama" || id == "ollama_local" {
|
||||||
|
let provider = OllamaLocalProvider::new(cfg.base_url.clone(), None, None)
|
||||||
|
.map_err(|err| anyhow!(err))?;
|
||||||
|
Ok(Arc::new(provider))
|
||||||
|
} else if kind == "ollama_cloud" || id == "ollama_cloud" {
|
||||||
|
let provider = OllamaCloudProvider::new(cfg.base_url.clone(), cfg.api_key.clone(), None)
|
||||||
|
.map_err(|err| anyhow!(err))?;
|
||||||
|
Ok(Arc::new(provider))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(
|
||||||
|
"Provider '{}' uses unsupported type '{}'.",
|
||||||
|
id,
|
||||||
|
if kind.is_empty() {
|
||||||
|
"unknown"
|
||||||
|
} else {
|
||||||
|
kind.as_str()
|
||||||
|
}
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}

fn describe_provider_type(id: &str, cfg: &ProviderConfig) -> String {
    if cfg.provider_type.trim().eq_ignore_ascii_case("ollama") || id.ends_with("_local") {
        "Local".to_string()
    } else if cfg
        .provider_type
        .trim()
        .eq_ignore_ascii_case("ollama_cloud")
        || id.contains("cloud")
    {
        "Cloud".to_string()
    } else {
        "Custom".to_string()
    }
}

fn requires_auth(id: &str, cfg: &ProviderConfig) -> bool {
    cfg.api_key.is_some()
        || cfg.api_key_env.is_some()
        || matches!(id, "ollama_cloud" | "openai" | "anthropic")
}

fn describe_auth(cfg: &ProviderConfig, required: bool) -> String {
    if let Some(env) = cfg
        .api_key_env
        .as_ref()
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
    {
        format!("env:{env}")
    } else if cfg
        .api_key
        .as_ref()
        .map(|value| !value.trim().is_empty())
        .unwrap_or(false)
    {
        "config".to_string()
    } else if required {
        "required".to_string()
    } else {
        "-".to_string()
    }
}

fn canonical_provider_id(raw: &str) -> String {
    let trimmed = raw.trim().to_ascii_lowercase();
    if trimmed.is_empty() {
        return trimmed;
    }

    match trimmed.as_str() {
        "ollama" | "ollama-local" => "ollama_local".to_string(),
        "ollama_cloud" | "ollama-cloud" => "ollama_cloud".to_string(),
        other => other.replace('-', "_"),
    }
}

fn provider_type_label(provider_type: ProviderType) -> String {
    match provider_type {
        ProviderType::Local => "Local".to_string(),
        ProviderType::Cloud => "Cloud".to_string(),
    }
}

fn provider_status_strings(status: ProviderStatus) -> (&'static str, &'static str) {
    match status {
        ProviderStatus::Available => ("OK", "available"),
        ProviderStatus::Unavailable => ("ERR", "unavailable"),
        ProviderStatus::RequiresSetup => ("SETUP", "requires setup"),
    }
}

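The normalisation rules above (trim, lowercase, alias mapping, then `-` replaced by `_`) can be pinned down with a small illustrative test (not part of the diff):

    #[cfg(test)]
    mod canonical_id_examples {
        use super::canonical_provider_id;

        #[test]
        fn maps_aliases_and_normalises_dashes() {
            assert_eq!(canonical_provider_id(" Ollama "), "ollama_local");
            assert_eq!(canonical_provider_id("ollama-cloud"), "ollama_cloud");
            assert_eq!(canonical_provider_id("My-Provider"), "my_provider");
            assert_eq!(canonical_provider_id(""), "");
        }
    }
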
fn print_status_rows(rows: &[ProviderStatusRow]) {
|
||||||
|
let id_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.id.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(8)
|
||||||
|
.max("Provider".len());
|
||||||
|
let type_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.provider_type.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(4)
|
||||||
|
.max("Type".len());
|
||||||
|
let status_width = rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.indicator.len() + 1 + row.status_label.len())
|
||||||
|
.max()
|
||||||
|
.unwrap_or(6)
|
||||||
|
.max("State".len());
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} Details",
|
||||||
|
"Provider",
|
||||||
|
"Def",
|
||||||
|
"Type",
|
||||||
|
"State",
|
||||||
|
id_width = id_width,
|
||||||
|
type_width = type_width,
|
||||||
|
status_width = status_width,
|
||||||
|
);
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let def = if row.default_provider { "*" } else { "-" };
|
||||||
|
let details = row.detail.as_deref().unwrap_or("-");
|
||||||
|
println!(
|
||||||
|
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} {}",
|
||||||
|
row.id,
|
||||||
|
def,
|
||||||
|
row.provider_type,
|
||||||
|
format!("{} {}", row.indicator, row.status_label),
|
||||||
|
details,
|
||||||
|
id_width = id_width,
|
||||||
|
type_width = type_width,
|
||||||
|
status_width = status_width,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_models(
|
||||||
|
records: Vec<ProviderRecord>,
|
||||||
|
models: Vec<AnnotatedModelInfo>,
|
||||||
|
statuses: HashMap<String, ProviderStatus>,
|
||||||
|
) {
|
||||||
|
let mut grouped: HashMap<String, Vec<AnnotatedModelInfo>> = HashMap::new();
|
||||||
|
for info in models {
|
||||||
|
grouped
|
||||||
|
.entry(info.provider_id.clone())
|
||||||
|
.or_default()
|
||||||
|
.push(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
for record in records {
|
||||||
|
let status = statuses.get(&record.id).copied().or_else(|| {
|
||||||
|
if record.metadata.is_some() && record.registration_error.is_none() && record.enabled {
|
||||||
|
Some(ProviderStatus::Unavailable)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let (indicator, label, status_value) = if !record.enabled {
|
||||||
|
("-", "disabled", None)
|
||||||
|
} else if record.registration_error.is_some() {
|
||||||
|
("ERR", "error", None)
|
||||||
|
} else if let Some(status) = status {
|
||||||
|
let (indicator, label) = provider_status_strings(status);
|
||||||
|
(indicator, label, Some(status))
|
||||||
|
} else {
|
||||||
|
("?", "unknown", None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let title = if record.default_provider {
|
||||||
|
format!("{} (default)", record.id)
|
||||||
|
} else {
|
||||||
|
record.id.clone()
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"{} {} [{}] {}",
|
||||||
|
indicator, title, record.provider_type_label, label
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(err) = &record.registration_error {
|
||||||
|
println!(" error: {}", err);
|
||||||
|
println!();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !record.enabled {
|
||||||
|
println!(" provider disabled");
|
||||||
|
println!();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(entries) = grouped.get(&record.id) {
|
||||||
|
let mut entries = entries.clone();
|
||||||
|
entries.sort_by(|a, b| a.model.name.cmp(&b.model.name));
|
||||||
|
if entries.is_empty() {
|
||||||
|
println!(" (no models reported)");
|
||||||
|
} else {
|
||||||
|
for entry in entries {
|
||||||
|
let mut line = format!(" - {}", entry.model.name);
|
||||||
|
if let Some(description) = &entry.model.description
|
||||||
|
&& !description.trim().is_empty()
|
||||||
|
{
|
||||||
|
line.push_str(&format!(" — {}", description.trim()));
|
||||||
|
}
|
||||||
|
println!("{}", line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println!(" (no models reported)");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ProviderStatus::RequiresSetup) = status_value
|
||||||
|
&& record.requires_auth
|
||||||
|
{
|
||||||
|
println!(" configure provider credentials or API key");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProviderListRow {
|
||||||
|
id: String,
|
||||||
|
type_label: String,
|
||||||
|
enabled: String,
|
||||||
|
default: String,
|
||||||
|
auth: String,
|
||||||
|
base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProviderRecord {
|
||||||
|
id: String,
|
||||||
|
enabled: bool,
|
||||||
|
default_provider: bool,
|
||||||
|
provider_type_label: String,
|
||||||
|
requires_auth: bool,
|
||||||
|
registration_error: Option<String>,
|
||||||
|
metadata: Option<owlen_core::provider::ProviderMetadata>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderRecord {
|
||||||
|
fn from_config(id: &str, cfg: &ProviderConfig, default_provider: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
id: id.to_string(),
|
||||||
|
enabled: cfg.enabled,
|
||||||
|
default_provider,
|
||||||
|
provider_type_label: describe_provider_type(id, cfg),
|
||||||
|
requires_auth: requires_auth(id, cfg),
|
||||||
|
registration_error: None,
|
||||||
|
metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProviderStatusRow {
|
||||||
|
id: String,
|
||||||
|
provider_type: String,
|
||||||
|
default_provider: bool,
|
||||||
|
indicator: String,
|
||||||
|
status_label: String,
|
||||||
|
detail: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderStatusRow {
|
||||||
|
fn from_record(record: ProviderRecord, status: Option<ProviderStatus>) -> Self {
|
||||||
|
if !record.enabled {
|
||||||
|
return Self {
|
||||||
|
id: record.id,
|
||||||
|
provider_type: record.provider_type_label,
|
||||||
|
default_provider: record.default_provider,
|
||||||
|
indicator: "-".to_string(),
|
||||||
|
status_label: "disabled".to_string(),
|
||||||
|
detail: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(err) = record.registration_error {
|
||||||
|
return Self {
|
||||||
|
id: record.id,
|
||||||
|
provider_type: record.provider_type_label,
|
||||||
|
default_provider: record.default_provider,
|
||||||
|
indicator: "ERR".to_string(),
|
||||||
|
status_label: "error".to_string(),
|
||||||
|
detail: Some(err),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(status) = status {
|
||||||
|
let (indicator, label) = provider_status_strings(status);
|
||||||
|
return Self {
|
||||||
|
id: record.id,
|
||||||
|
provider_type: record.provider_type_label,
|
||||||
|
default_provider: record.default_provider,
|
||||||
|
indicator: indicator.to_string(),
|
||||||
|
status_label: label.to_string(),
|
||||||
|
detail: if matches!(status, ProviderStatus::RequiresSetup) && record.requires_auth {
|
||||||
|
Some("credentials required".to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
id: record.id,
|
||||||
|
provider_type: record.provider_type_label,
|
||||||
|
default_provider: record.default_provider,
|
||||||
|
indicator: "?".to_string(),
|
||||||
|
status_label: "unknown".to_string(),
|
||||||
|
detail: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,136 +1,228 @@
|
|||||||
|
#![allow(clippy::collapsible_if)] // TODO: Remove once Rust 2024 let-chains are available
|
||||||
|
|
||||||
//! OWLEN CLI - Chat TUI client
|
//! OWLEN CLI - Chat TUI client
|
||||||
|
|
||||||
|
mod bootstrap;
|
||||||
|
mod commands;
|
||||||
|
mod mcp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use owlen_core::{session::SessionController, storage::StorageManager};
|
use clap::{Parser, Subcommand};
|
||||||
use owlen_ollama::OllamaProvider;
|
use commands::{
|
||||||
use owlen_tui::tui_controller::{TuiController, TuiRequest};
|
cloud::{CloudCommand, run_cloud_command},
|
||||||
use owlen_tui::{config, ui, AppState, ChatApp, Event, EventHandler, SessionEvent};
|
providers::{ModelsArgs, ProvidersCommand, run_models_command, run_providers_command},
|
||||||
use std::io;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
|
|
||||||
use crossterm::{
|
|
||||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
|
||||||
execute,
|
|
||||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
|
||||||
};
|
};
|
||||||
use ratatui::{prelude::CrosstermBackend, Terminal};
|
use mcp::{McpCommand, run_mcp_command};
|
||||||
|
use owlen_core::config as core_config;
|
||||||
|
use owlen_core::config::McpMode;
|
||||||
|
use owlen_core::mode::Mode;
|
||||||
|
use owlen_tui::config;
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
/// Owlen - Terminal UI for LLM chat
|
||||||
async fn main() -> Result<()> {
|
#[derive(Parser, Debug)]
|
||||||
// (imports completed above)
|
#[command(name = "owlen")]
|
||||||
|
#[command(about = "Terminal UI for LLM chat via MCP", long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Start in code mode (enables all tools)
|
||||||
|
#[arg(long, short = 'c')]
|
||||||
|
code: bool,
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<OwlenCommand>,
|
||||||
|
}
|
||||||
|
|
||||||
// (main logic starts below)
|
#[derive(Debug, Subcommand)]
|
||||||
// Set auto-consent for TUI mode to prevent blocking stdin reads
|
enum OwlenCommand {
|
||||||
std::env::set_var("OWLEN_AUTO_CONSENT", "1");
|
/// Inspect or upgrade configuration files
|
||||||
|
#[command(subcommand)]
|
||||||
|
Config(ConfigCommand),
|
||||||
|
/// Manage Ollama Cloud credentials
|
||||||
|
#[command(subcommand)]
|
||||||
|
Cloud(CloudCommand),
|
||||||
|
/// Manage model providers
|
||||||
|
#[command(subcommand)]
|
||||||
|
Providers(ProvidersCommand),
|
||||||
|
/// List models exposed by configured providers
|
||||||
|
Models(ModelsArgs),
|
||||||
|
/// Manage MCP server registrations
|
||||||
|
#[command(subcommand)]
|
||||||
|
Mcp(McpCommand),
|
||||||
|
/// Show manual steps for updating Owlen to the latest revision
|
||||||
|
Upgrade,
|
||||||
|
}
|
||||||
|
|
||||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
#[derive(Debug, Subcommand)]
|
||||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
enum ConfigCommand {
|
||||||
|
/// Automatically upgrade legacy configuration values and ensure validity
|
||||||
|
Doctor,
|
||||||
|
/// Print the resolved configuration file path
|
||||||
|
Path,
|
||||||
|
}
|
||||||
|
|
||||||
// Load configuration (or fall back to defaults) for the session controller.
|
async fn run_command(command: OwlenCommand) -> Result<()> {
|
||||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
match command {
|
||||||
// Disable encryption for CLI to avoid password prompts in this environment.
|
OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
|
||||||
cfg.privacy.encrypt_local_data = false;
|
OwlenCommand::Cloud(cloud_cmd) => run_cloud_command(cloud_cmd).await,
|
||||||
// Determine provider configuration
|
OwlenCommand::Providers(provider_cmd) => run_providers_command(provider_cmd).await,
|
||||||
let provider_name = cfg.general.default_provider.clone();
|
OwlenCommand::Models(args) => run_models_command(args).await,
|
||||||
let provider_cfg = config::ensure_provider_config(&mut cfg, &provider_name).clone();
|
OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
|
||||||
let provider_type = provider_cfg.provider_type.to_ascii_lowercase();
|
OwlenCommand::Upgrade => {
|
||||||
if provider_type != "ollama" && provider_type != "ollama-cloud" {
|
println!(
|
||||||
anyhow::bail!(
|
"To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
|
||||||
"Unsupported provider type '{}' configured for provider '{}'",
|
);
|
||||||
provider_cfg.provider_type,
|
println!(
|
||||||
provider_name,
|
"If you installed from the AUR, use your package manager (e.g., yay -S owlen-git)."
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let provider = Arc::new(OllamaProvider::from_config(
|
}
|
||||||
&provider_cfg,
|
|
||||||
Some(&cfg.general),
|
|
||||||
)?);
|
|
||||||
let storage = Arc::new(StorageManager::new().await?);
|
|
||||||
let controller =
|
|
||||||
SessionController::new(provider, cfg, storage.clone(), tui_controller, false).await?;
|
|
||||||
let (mut app, mut session_rx) = ChatApp::new(controller).await?;
|
|
||||||
app.initialize_models().await?;
|
|
||||||
|
|
||||||
// Event infrastructure
|
fn run_config_command(command: ConfigCommand) -> Result<()> {
|
||||||
let cancellation_token = CancellationToken::new();
|
match command {
|
||||||
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
ConfigCommand::Doctor => run_config_doctor(),
|
||||||
let event_handler = EventHandler::new(event_tx, cancellation_token.clone());
|
ConfigCommand::Path => {
|
||||||
let event_handle = tokio::spawn(async move { event_handler.run().await });
|
let path = core_config::default_config_path();
|
||||||
|
println!("{}", path.display());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Terminal setup
|
fn run_config_doctor() -> Result<()> {
|
||||||
enable_raw_mode()?;
|
let config_path = core_config::default_config_path();
|
||||||
let mut stdout = io::stdout();
|
let existed = config_path.exists();
|
||||||
execute!(
|
let mut config = config::try_load_config().unwrap_or_default();
|
||||||
stdout,
|
let _ = config.refresh_mcp_servers(None);
|
||||||
EnterAlternateScreen,
|
let mut changes = Vec::new();
|
||||||
EnableMouseCapture,
|
|
||||||
EnableBracketedPaste
|
|
||||||
)?;
|
|
||||||
let backend = CrosstermBackend::new(stdout);
|
|
||||||
let mut terminal = Terminal::new(backend)?;
|
|
||||||
|
|
||||||
let result = run_app(&mut terminal, &mut app, event_rx, &mut session_rx).await;
|
if !existed {
|
||||||
|
changes.push("created configuration file from defaults".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Shutdown
|
if config.provider(&config.general.default_provider).is_none() {
|
||||||
cancellation_token.cancel();
|
config.general.default_provider = "ollama_local".to_string();
|
||||||
event_handle.await?;
|
changes.push("default provider missing; reset to 'ollama_local'".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Persist configuration updates (e.g., selected model)
|
for key in ["ollama_local", "ollama_cloud", "openai", "anthropic"] {
|
||||||
config::save_config(&app.config())?;
|
if !config.providers.contains_key(key) {
|
||||||
|
core_config::ensure_provider_config_mut(&mut config, key);
|
||||||
|
changes.push(format!("added default configuration for provider '{key}'"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
disable_raw_mode()?;
|
if let Some(entry) = config.providers.get_mut("ollama_local") {
|
||||||
execute!(
|
if entry.provider_type.trim().is_empty() || entry.provider_type != "ollama" {
|
||||||
terminal.backend_mut(),
|
entry.provider_type = "ollama".to_string();
|
||||||
LeaveAlternateScreen,
|
changes.push("normalised providers.ollama_local.provider_type to 'ollama'".to_string());
|
||||||
DisableMouseCapture,
|
}
|
||||||
DisableBracketedPaste
|
}
|
||||||
)?;
|
|
||||||
terminal.show_cursor()?;
|
|
||||||
|
|
||||||
if let Err(err) = result {
|
let mut ensure_default_enabled = true;
|
||||||
println!("{err:?}");
|
|
||||||
|
if !config.providers.values().any(|cfg| cfg.enabled) {
|
||||||
|
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||||
|
if !entry.enabled {
|
||||||
|
entry.enabled = true;
|
||||||
|
changes.push("no providers were enabled; enabled 'ollama_local'".to_string());
|
||||||
|
}
|
||||||
|
if config.general.default_provider != "ollama_local" {
|
||||||
|
config.general.default_provider = "ollama_local".to_string();
|
||||||
|
changes.push(
|
||||||
|
"default provider reset to 'ollama_local' because no providers were enabled"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
ensure_default_enabled = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ensure_default_enabled {
|
||||||
|
let default_id = config.general.default_provider.clone();
|
||||||
|
if let Some(default_cfg) = config.providers.get(&default_id) {
|
||||||
|
if !default_cfg.enabled {
|
||||||
|
if let Some(new_default) = config
|
||||||
|
.providers
|
||||||
|
.iter()
|
||||||
|
.filter(|(id, cfg)| cfg.enabled && *id != &default_id)
|
||||||
|
.map(|(id, _)| id.clone())
|
||||||
|
.min()
|
||||||
|
{
|
||||||
|
config.general.default_provider = new_default.clone();
|
||||||
|
changes.push(format!(
|
||||||
|
"default provider '{default_id}' was disabled; switched default to '{new_default}'"
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
let entry =
|
||||||
|
core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||||
|
if !entry.enabled {
|
||||||
|
entry.enabled = true;
|
||||||
|
changes.push(
|
||||||
|
"enabled 'ollama_local' because default provider was disabled"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if config.general.default_provider != "ollama_local" {
|
||||||
|
config.general.default_provider = "ollama_local".to_string();
|
||||||
|
changes.push(
|
||||||
|
"default provider reset to 'ollama_local' because previous default was disabled"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match config.mcp.mode {
|
||||||
|
McpMode::Legacy => {
|
||||||
|
config.mcp.mode = McpMode::LocalOnly;
|
||||||
|
config.mcp.warn_on_legacy = true;
|
||||||
|
changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
|
||||||
|
}
|
||||||
|
McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
|
||||||
|
config.mcp.mode = McpMode::RemotePreferred;
|
||||||
|
config.mcp.allow_fallback = true;
|
||||||
|
changes.push(
|
||||||
|
"downgraded remote-only configuration to remote_preferred because no servers are defined"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
McpMode::RemotePreferred
|
||||||
|
if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
|
||||||
|
{
|
||||||
|
config.mcp.allow_fallback = true;
|
||||||
|
changes.push(
|
||||||
|
"enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.validate()?;
|
||||||
|
config::save_config(&config)?;
|
||||||
|
|
||||||
|
if changes.is_empty() {
|
||||||
|
println!(
|
||||||
|
"Configuration already up to date: {}",
|
||||||
|
config_path.display()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!("Updated {}:", config_path.display());
|
||||||
|
for change in changes {
|
||||||
|
println!(" - {change}");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_app(
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
async fn main() -> Result<()> {
|
||||||
app: &mut ChatApp,
|
// Parse command-line arguments
|
||||||
mut event_rx: mpsc::UnboundedReceiver<Event>,
|
let Args { code, command } = Args::parse();
|
||||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
if let Some(command) = command {
|
||||||
) -> Result<()> {
|
return run_command(command).await;
|
||||||
loop {
|
|
||||||
// Advance loading animation frame
|
|
||||||
app.advance_loading_animation();
|
|
||||||
|
|
||||||
terminal.draw(|f| ui::render_chat(f, app))?;
|
|
||||||
|
|
||||||
// Process any pending LLM requests AFTER UI has been drawn
|
|
||||||
if let Err(e) = app.process_pending_llm_request().await {
|
|
||||||
eprintln!("Error processing LLM request: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process any pending tool executions AFTER UI has been drawn
|
|
||||||
if let Err(e) = app.process_pending_tool_execution().await {
|
|
||||||
eprintln!("Error processing tool execution: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
Some(event) = event_rx.recv() => {
|
|
||||||
if let AppState::Quit = app.handle_event(event).await? {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(session_event) = session_rx.recv() => {
|
|
||||||
app.handle_session_event(session_event)?;
|
|
||||||
}
|
|
||||||
// Add a timeout to keep the animation going even when there are no events
|
|
||||||
_ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
|
|
||||||
// This will cause the loop to continue and advance the animation
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
let initial_mode = if code { Mode::Code } else { Mode::Chat };
|
||||||
|
bootstrap::launch(initial_mode).await
|
||||||
}
|
}
|
||||||
|
|||||||
259 crates/owlen-cli/src/mcp.rs Normal file
@@ -0,0 +1,259 @@
use std::collections::{HashMap, HashSet};

use anyhow::{Result, anyhow};
use clap::{Args, Subcommand, ValueEnum};
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
use owlen_tui::config as tui_config;

#[derive(Debug, Subcommand)]
pub enum McpCommand {
    /// Add or update an MCP server in the selected scope
    Add(AddArgs),
    /// List MCP servers across scopes
    List(ListArgs),
    /// Remove an MCP server from a scope
    Remove(RemoveArgs),
}

pub fn run_mcp_command(command: McpCommand) -> Result<()> {
    match command {
        McpCommand::Add(args) => handle_add(args),
        McpCommand::List(args) => handle_list(args),
        McpCommand::Remove(args) => handle_remove(args),
    }
}

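Putting the pieces together, registering a server programmatically amounts to building `AddArgs` (defined below) and handing it to `run_mcp_command`. A sketch with made-up server values, inside a function returning `anyhow::Result<()>`:

    // Register a hypothetical stdio MCP server in the project scope.
    let args = AddArgs {
        name: "docs".to_string(),
        command: "npx".to_string(),
        transport: "stdio".to_string(),
        scope: ScopeArg::Project,
        env: vec!["API_KEY=example".to_string()],
        args: vec!["-y".to_string(), "example-mcp-server".to_string()],
    };
    run_mcp_command(McpCommand::Add(args))?;
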
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||||||
|
pub enum ScopeArg {
|
||||||
|
User,
|
||||||
|
#[default]
|
||||||
|
Project,
|
||||||
|
Local,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ScopeArg> for McpConfigScope {
|
||||||
|
fn from(value: ScopeArg) -> Self {
|
||||||
|
match value {
|
||||||
|
ScopeArg::User => McpConfigScope::User,
|
||||||
|
ScopeArg::Project => McpConfigScope::Project,
|
||||||
|
ScopeArg::Local => McpConfigScope::Local,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Args)]
|
||||||
|
pub struct AddArgs {
|
||||||
|
/// Logical name used to reference the server
|
||||||
|
pub name: String,
|
||||||
|
/// Command or endpoint invoked for the server
|
||||||
|
pub command: String,
|
||||||
|
/// Transport mechanism (stdio, http, websocket)
|
||||||
|
#[arg(long, default_value = "stdio")]
|
||||||
|
pub transport: String,
|
||||||
|
/// Configuration scope to write the server into
|
||||||
|
#[arg(long, value_enum, default_value_t = ScopeArg::Project)]
|
||||||
|
pub scope: ScopeArg,
|
||||||
|
/// Environment variables (KEY=VALUE) passed to the server process
|
||||||
|
#[arg(long = "env")]
|
||||||
|
pub env: Vec<String>,
|
||||||
|
/// Additional arguments appended when launching the server
|
||||||
|
#[arg(trailing_var_arg = true, value_name = "ARG")]
|
||||||
|
pub args: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Args, Default)]
|
||||||
|
pub struct ListArgs {
|
||||||
|
/// Restrict output to a specific configuration scope
|
||||||
|
#[arg(long, value_enum)]
|
||||||
|
pub scope: Option<ScopeArg>,
|
||||||
|
/// Display only the effective servers (after precedence resolution)
|
||||||
|
#[arg(long)]
|
||||||
|
pub effective_only: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Args)]
|
||||||
|
pub struct RemoveArgs {
|
||||||
|
/// Name of the server to remove
|
||||||
|
pub name: String,
|
||||||
|
/// Optional explicit scope to remove from
|
||||||
|
#[arg(long, value_enum)]
|
||||||
|
pub scope: Option<ScopeArg>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_add(args: AddArgs) -> Result<()> {
|
||||||
|
let mut config = load_config()?;
|
||||||
|
let scope: McpConfigScope = args.scope.into();
|
||||||
|
let mut env_map = HashMap::new();
|
||||||
|
for pair in &args.env {
|
||||||
|
let (key, value) = pair
|
||||||
|
.split_once('=')
|
||||||
|
.ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
|
||||||
|
if key.trim().is_empty() {
|
||||||
|
return Err(anyhow!("Environment variable name cannot be empty"));
|
||||||
|
}
|
||||||
|
env_map.insert(key.trim().to_string(), value.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let server = McpServerConfig {
|
||||||
|
name: args.name.clone(),
|
||||||
|
command: args.command.clone(),
|
||||||
|
args: args.args.clone(),
|
||||||
|
transport: args.transport.to_lowercase(),
|
||||||
|
env: env_map,
|
||||||
|
oauth: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
config.add_mcp_server(scope, server.clone(), None)?;
|
||||||
|
if matches!(scope, McpConfigScope::User) {
|
||||||
|
tui_config::save_config(&config)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||||
|
println!(
|
||||||
|
"Registered MCP server '{}' in {} scope ({})",
|
||||||
|
server.name,
|
||||||
|
scope,
|
||||||
|
path.display()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!(
|
||||||
|
"Registered MCP server '{}' in {} scope.",
|
||||||
|
server.name, scope
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_list(args: ListArgs) -> Result<()> {
|
||||||
|
let mut config = load_config()?;
|
||||||
|
config.refresh_mcp_servers(None)?;
|
||||||
|
|
||||||
|
let scoped = config.scoped_mcp_servers();
|
||||||
|
if scoped.is_empty() {
|
||||||
|
println!("No MCP servers configured.");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let filter_scope = args.scope.map(|scope| scope.into());
|
||||||
|
let effective = config.effective_mcp_servers();
|
||||||
|
let mut active = HashSet::new();
|
||||||
|
for server in effective {
|
||||||
|
active.insert((
|
||||||
|
server.name.clone(),
|
||||||
|
server.command.clone(),
|
||||||
|
server.transport.to_lowercase(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{:<2} {:<8} {:<20} {:<10} Command",
|
||||||
|
"", "Scope", "Name", "Transport"
|
||||||
|
);
|
||||||
|
for entry in scoped {
|
||||||
|
if filter_scope
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = format_command_line(&entry.config.command, &entry.config.args);
|
||||||
|
let key = (
|
||||||
|
entry.config.name.clone(),
|
||||||
|
entry.config.command.clone(),
|
||||||
|
entry.config.transport.to_lowercase(),
|
||||||
|
);
|
||||||
|
let marker = if active.contains(&key) { "*" } else { " " };
|
||||||
|
|
||||||
|
if args.effective_only && marker != "*" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{} {:<8} {:<20} {:<10} {}",
|
||||||
|
marker, entry.scope, entry.config.name, entry.config.transport, payload
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let scoped_resources = config.scoped_mcp_resources();
|
||||||
|
if !scoped_resources.is_empty() {
|
||||||
|
println!();
|
||||||
|
println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
|
||||||
|
let effective_keys: HashSet<(String, String)> = config
|
||||||
|
.effective_mcp_resources()
|
||||||
|
.iter()
|
||||||
|
.map(|res| (res.server.clone(), res.uri.clone()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for entry in scoped_resources {
|
||||||
|
if filter_scope
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let key = (entry.config.server.clone(), entry.config.uri.clone());
|
||||||
|
let marker = if effective_keys.contains(&key) {
|
||||||
|
"*"
|
||||||
|
} else {
|
||||||
|
" "
|
||||||
|
};
|
||||||
|
if args.effective_only && marker != "*" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
|
||||||
|
let title = entry.config.title.as_deref().unwrap_or("—");
|
||||||
|
|
||||||
|
println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_remove(args: RemoveArgs) -> Result<()> {
|
||||||
|
let mut config = load_config()?;
|
||||||
|
let scope_hint = args.scope.map(|scope| scope.into());
|
||||||
|
let result = config.remove_mcp_server(scope_hint, &args.name, None)?;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Some(scope) => {
|
||||||
|
if matches!(scope, McpConfigScope::User) {
|
||||||
|
tui_config::save_config(&config)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||||
|
println!(
|
||||||
|
"Removed MCP server '{}' from {} scope ({})",
|
||||||
|
args.name,
|
||||||
|
scope,
|
||||||
|
path.display()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!("Removed MCP server '{}' from {} scope.", args.name, scope);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
println!("No MCP server named '{}' was found.", args.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}

fn load_config() -> Result<Config> {
    let mut config = tui_config::try_load_config().unwrap_or_default();
    config.refresh_mcp_servers(None)?;
    Ok(config)
}

fn format_command_line(command: &str, args: &[String]) -> String {
    if args.is_empty() {
        command.to_string()
    } else {
        format!("{} {}", command, args.join(" "))
    }
}
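The helper simply joins the command with its arguments, so the expected behaviour looks like this (illustrative assertions, not part of the diff):

    assert_eq!(format_command_line("npx", &[]), "npx");
    assert_eq!(
        format_command_line("npx", &["-y".to_string(), "server".to_string()]),
        "npx -y server"
    );
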
@@ -9,7 +9,6 @@
|
|||||||
|
|
||||||
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
|
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
|
||||||
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||||
use owlen_ollama::OllamaProvider;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -39,7 +38,7 @@ async fn test_react_parsing_tool_call() {
|
|||||||
async fn test_react_parsing_final_answer() {
|
async fn test_react_parsing_final_answer() {
|
||||||
let executor = create_test_executor();
|
let executor = create_test_executor();
|
||||||
|
|
||||||
let text = "THOUGHT: I have enough information now\nACTION: final_answer\nACTION_INPUT: The answer is 42\n";
|
let text = "THOUGHT: I have enough information now\nFINAL_ANSWER: The answer is 42\n";
|
||||||
|
|
||||||
let result = executor.parse_response(text);
|
let result = executor.parse_response(text);
|
||||||
|
|
||||||
@@ -72,21 +71,20 @@ async fn test_react_parsing_with_multiline_thought() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore] // Requires Ollama to be running
|
#[ignore] // Requires MCP LLM server to be running
|
||||||
async fn test_agent_single_tool_scenario() {
|
async fn test_agent_single_tool_scenario() {
|
||||||
// This test requires a running Ollama instance and MCP server
|
// This test requires a running MCP LLM server (which wraps Ollama)
|
||||||
let provider = Arc::new(OllamaProvider::new("http://localhost:11434").unwrap());
|
let provider = Arc::new(RemoteMcpClient::new().unwrap());
|
||||||
let mcp_client = Arc::new(RemoteMcpClient::new().unwrap());
|
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||||
|
|
||||||
let config = AgentConfig {
|
let config = AgentConfig {
|
||||||
max_iterations: 5,
|
max_iterations: 5,
|
||||||
model: "llama3.2".to_string(),
|
model: "llama3.2".to_string(),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
max_tool_calls: 10,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let executor = AgentExecutor::new(provider, mcp_client, config, None);
|
let executor = AgentExecutor::new(provider, mcp_client, config);
|
||||||
|
|
||||||
// Simple query that should complete in one tool call
|
// Simple query that should complete in one tool call
|
||||||
let result = executor
|
let result = executor
|
||||||
@@ -94,9 +92,12 @@ async fn test_agent_single_tool_scenario() {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(answer) => {
|
Ok(agent_result) => {
|
||||||
assert!(!answer.is_empty(), "Answer should not be empty");
|
assert!(
|
||||||
println!("Agent answer: {}", answer);
|
!agent_result.answer.is_empty(),
|
||||||
|
"Answer should not be empty"
|
||||||
|
);
|
||||||
|
println!("Agent answer: {}", agent_result.answer);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// It's okay if this fails due to LLM not following format
|
// It's okay if this fails due to LLM not following format
|
||||||
@@ -109,18 +110,17 @@ async fn test_agent_single_tool_scenario() {
|
|||||||
#[ignore] // Requires Ollama to be running
|
#[ignore] // Requires Ollama to be running
|
||||||
async fn test_agent_multi_step_workflow() {
|
async fn test_agent_multi_step_workflow() {
|
||||||
// Test a query that requires multiple tool calls
|
// Test a query that requires multiple tool calls
|
||||||
let provider = Arc::new(OllamaProvider::new("http://localhost:11434").unwrap());
|
let provider = Arc::new(RemoteMcpClient::new().unwrap());
|
||||||
let mcp_client = Arc::new(RemoteMcpClient::new().unwrap());
|
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||||
|
|
||||||
let config = AgentConfig {
|
let config = AgentConfig {
|
||||||
max_iterations: 10,
|
max_iterations: 10,
|
||||||
model: "llama3.2".to_string(),
|
model: "llama3.2".to_string(),
|
||||||
temperature: Some(0.5), // Lower temperature for more consistent behavior
|
temperature: Some(0.5), // Lower temperature for more consistent behavior
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
max_tool_calls: 20,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let executor = AgentExecutor::new(provider, mcp_client, config, None);
|
let executor = AgentExecutor::new(provider, mcp_client, config);
|
||||||
|
|
||||||
// Query requiring multiple steps: list -> read -> analyze
|
// Query requiring multiple steps: list -> read -> analyze
|
||||||
let result = executor
|
let result = executor
|
||||||
@@ -128,9 +128,9 @@ async fn test_agent_multi_step_workflow() {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(answer) => {
|
Ok(agent_result) => {
|
||||||
assert!(!answer.is_empty());
|
assert!(!agent_result.answer.is_empty());
|
||||||
println!("Multi-step answer: {}", answer);
|
println!("Multi-step answer: {:?}", agent_result);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("Multi-step test skipped: {}", e);
|
println!("Multi-step test skipped: {}", e);
|
||||||
@@ -141,18 +141,17 @@ async fn test_agent_multi_step_workflow() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore] // Requires Ollama
|
#[ignore] // Requires Ollama
|
||||||
async fn test_agent_iteration_limit() {
|
async fn test_agent_iteration_limit() {
|
||||||
let provider = Arc::new(OllamaProvider::new("http://localhost:11434").unwrap());
|
let provider = Arc::new(RemoteMcpClient::new().unwrap());
|
||||||
let mcp_client = Arc::new(RemoteMcpClient::new().unwrap());
|
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||||
|
|
||||||
let config = AgentConfig {
|
let config = AgentConfig {
|
||||||
max_iterations: 2, // Very low limit to test enforcement
|
max_iterations: 2, // Very low limit to test enforcement
|
||||||
model: "llama3.2".to_string(),
|
model: "llama3.2".to_string(),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
max_tool_calls: 5,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let executor = AgentExecutor::new(provider, mcp_client, config, None);
|
let executor = AgentExecutor::new(provider, mcp_client, config);
|
||||||
|
|
||||||
// Complex query that would require many iterations
|
// Complex query that would require many iterations
|
||||||
let result = executor
|
let result = executor
|
||||||
@@ -183,18 +182,17 @@ async fn test_agent_iteration_limit() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore] // Requires Ollama
|
#[ignore] // Requires Ollama
|
||||||
async fn test_agent_tool_budget_enforcement() {
|
async fn test_agent_tool_budget_enforcement() {
|
||||||
let provider = Arc::new(OllamaProvider::new("http://localhost:11434").unwrap());
|
let provider = Arc::new(RemoteMcpClient::new().unwrap());
|
||||||
let mcp_client = Arc::new(RemoteMcpClient::new().unwrap());
|
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||||
|
|
||||||
let config = AgentConfig {
|
let config = AgentConfig {
|
||||||
max_iterations: 20,
|
max_iterations: 3, // Very low iteration limit to enforce budget
|
||||||
model: "llama3.2".to_string(),
|
model: "llama3.2".to_string(),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
max_tool_calls: 3, // Very low tool call budget
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let executor = AgentExecutor::new(provider, mcp_client, config, None);
|
let executor = AgentExecutor::new(provider, mcp_client, config);
|
||||||
|
|
||||||
// Query that would require many tool calls
|
// Query that would require many tool calls
|
||||||
let result = executor
|
let result = executor
|
||||||
@@ -224,12 +222,9 @@ async fn test_agent_tool_budget_enforcement() {
|
|||||||
// Helper function to create a test executor
|
// Helper function to create a test executor
|
||||||
// For parsing tests, we don't need a real connection
|
// For parsing tests, we don't need a real connection
|
||||||
fn create_test_executor() -> AgentExecutor {
|
fn create_test_executor() -> AgentExecutor {
|
||||||
// Create dummy instances - the parse_response method doesn't actually use them
|
|
||||||
let provider = Arc::new(OllamaProvider::new("http://localhost:11434").unwrap());
|
|
||||||
|
|
||||||
// For parsing tests, we can accept the error from RemoteMcpClient::new()
|
// For parsing tests, we can accept the error from RemoteMcpClient::new()
|
||||||
// since we're only testing parse_response which doesn't use the MCP client
|
// since we're only testing parse_response which doesn't use the MCP client
|
||||||
let mcp_client = match RemoteMcpClient::new() {
|
let provider = match RemoteMcpClient::new() {
|
||||||
Ok(client) => Arc::new(client),
|
Ok(client) => Arc::new(client),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
// If MCP server binary doesn't exist, parsing tests can still run
|
// If MCP server binary doesn't exist, parsing tests can still run
|
||||||
@@ -239,18 +234,20 @@ fn create_test_executor() -> AgentExecutor {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||||
|
|
||||||
let config = AgentConfig::default();
|
let config = AgentConfig::default();
|
||||||
AgentExecutor::new(provider, mcp_client, config, None)
|
AgentExecutor::new(provider, mcp_client, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_agent_config_defaults() {
|
fn test_agent_config_defaults() {
|
||||||
let config = AgentConfig::default();
|
let config = AgentConfig::default();
|
||||||
|
|
||||||
assert_eq!(config.max_iterations, 10);
|
assert_eq!(config.max_iterations, 15);
|
||||||
assert_eq!(config.model, "ollama");
|
assert_eq!(config.model, "llama3.2:latest");
|
||||||
assert_eq!(config.temperature, Some(0.7));
|
assert_eq!(config.temperature, Some(0.7));
|
||||||
assert_eq!(config.max_tool_calls, 20);
|
// max_tool_calls field removed - agent now tracks iterations instead
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -260,12 +257,10 @@ fn test_agent_config_custom() {
|
|||||||
model: "custom-model".to_string(),
|
model: "custom-model".to_string(),
|
||||||
temperature: Some(0.5),
|
temperature: Some(0.5),
|
||||||
max_tokens: Some(2000),
|
max_tokens: Some(2000),
|
||||||
max_tool_calls: 30,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(config.max_iterations, 15);
|
assert_eq!(config.max_iterations, 15);
|
||||||
assert_eq!(config.model, "custom-model");
|
assert_eq!(config.model, "custom-model");
|
||||||
assert_eq!(config.temperature, Some(0.5));
|
assert_eq!(config.temperature, Some(0.5));
|
||||||
assert_eq!(config.max_tokens, Some(2000));
|
assert_eq!(config.max_tokens, Some(2000));
|
||||||
assert_eq!(config.max_tool_calls, 30);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ description = "Core traits and types for OWLEN LLM client"

 [dependencies]
 anyhow = { workspace = true }
-log = "0.4.20"
+log = { workspace = true }
 regex = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -21,10 +21,11 @@ unicode-width = "0.1"
 uuid = { workspace = true }
 textwrap = { workspace = true }
 futures = { workspace = true }
+futures-util = { workspace = true }
 async-trait = { workspace = true }
 toml = { workspace = true }
 shellexpand = { workspace = true }
-dirs = "5.0"
+dirs = { workspace = true }
 ratatui = { workspace = true }
 tempfile = { workspace = true }
 jsonschema = { workspace = true }
@@ -42,7 +43,11 @@ duckduckgo = "0.2.0"
 reqwest = { workspace = true, features = ["default"] }
 reqwest_011 = { version = "0.11", package = "reqwest" }
 path-clean = "1.0"
-tokio-stream = "0.1"
+tokio-stream = { workspace = true }
+tokio-tungstenite = "0.21"
+tungstenite = "0.21"
+ollama-rs = { version = "0.3", features = ["stream", "headers"] }

 [dev-dependencies]
 tokio-test = { workspace = true }
+httpmock = "0.7"
@@ -1,377 +1,421 @@
|
|||||||
//! High‑level agentic executor implementing the ReAct pattern.
|
//! Agentic execution loop with ReAct pattern support.
|
||||||
//!
|
//!
|
||||||
//! The executor coordinates three responsibilities:
|
//! This module provides the core agent orchestration logic that allows an LLM
|
||||||
//! 1. Build a ReAct prompt from the conversation history and the list of
|
//! to reason about tasks, execute tools, and observe results in an iterative loop.
|
||||||
//! available MCP tools.
|
|
||||||
//! 2. Send the prompt to an LLM provider (any type implementing
|
|
||||||
//! `owlen_core::Provider`).
|
|
||||||
//! 3. Parse the LLM response, optionally invoke a tool via an MCP client,
|
|
||||||
//! and feed the observation back into the conversation.
|
|
||||||
//!
|
|
||||||
//! The implementation is intentionally minimal – it provides the core loop
|
|
||||||
//! required by Phase 4 of the roadmap. Integration with the TUI and additional
|
|
||||||
//! safety mechanisms can be added on top of this module.
|
|
||||||
|
|
||||||
|
use crate::Provider;
|
||||||
|
use crate::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||||
|
use crate::types::{ChatParameters, ChatRequest, Message};
|
||||||
|
use crate::{Error, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::ui::UiController;
|
/// Maximum number of agent iterations before stopping
|
||||||
|
const DEFAULT_MAX_ITERATIONS: usize = 15;
|
||||||
|
|
||||||
use dirs;
|
/// Parsed response from the LLM in ReAct format
|
||||||
use regex::Regex;
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
use serde_json::json;
|
|
||||||
use std::fs::OpenOptions;
|
|
||||||
use std::io::Write;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
use tokio::signal;
|
|
||||||
|
|
||||||
use crate::mcp::client::McpClient;
|
|
||||||
use crate::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
|
||||||
use crate::{
|
|
||||||
types::{ChatRequest, Message},
|
|
||||||
Error, Provider, Result as CoreResult,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Configuration for the agent executor.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct AgentConfig {
|
|
||||||
/// Maximum number of ReAct iterations before the executor aborts.
|
|
||||||
pub max_iterations: usize,
|
|
||||||
/// Model name to use for the LLM provider.
|
|
||||||
pub model: String,
|
|
||||||
/// Optional temperature.
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
/// Optional max_tokens.
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
/// Maximum number of tool calls allowed per execution (budget).
|
|
||||||
pub max_tool_calls: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for AgentConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
max_iterations: 10,
|
|
||||||
model: "ollama".into(),
|
|
||||||
temperature: Some(0.7),
|
|
||||||
max_tokens: None,
|
|
||||||
max_tool_calls: 20,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Enum representing the possible parsed LLM responses in ReAct format.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum LlmResponse {
|
pub enum LlmResponse {
|
||||||
/// A reasoning step without action.
|
/// LLM wants to execute a tool
|
||||||
Reasoning { thought: String },
|
|
||||||
/// The model wants to invoke a tool.
|
|
||||||
ToolCall {
|
ToolCall {
|
||||||
thought: String,
|
thought: String,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
arguments: serde_json::Value,
|
arguments: serde_json::Value,
|
||||||
},
|
},
|
||||||
/// The model produced a final answer.
|
/// LLM has reached a final answer
|
||||||
FinalAnswer { thought: String, answer: String },
|
FinalAnswer { thought: String, answer: String },
|
||||||
|
/// LLM is just reasoning without taking action
|
||||||
|
Reasoning { thought: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error type for the agent executor.
|
/// Parse error when LLM response doesn't match expected format
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum AgentError {
|
pub enum ParseError {
|
||||||
#[error("LLM provider error: {0}")]
|
#[error("No recognizable pattern found in response")]
|
||||||
Provider(Error),
|
NoPattern,
|
||||||
#[error("MCP client error: {0}")]
|
#[error("Missing required field: {0}")]
|
||||||
Mcp(Error),
|
MissingField(String),
|
||||||
#[error("Tool execution denied by user")]
|
#[error("Invalid JSON in ACTION_INPUT: {0}")]
|
||||||
ToolDenied,
|
InvalidJson(String),
|
||||||
#[error("Failed to parse LLM response")]
|
|
||||||
Parse,
|
|
||||||
#[error("Maximum iterations ({0}) reached without final answer")]
|
|
||||||
MaxIterationsReached(usize),
|
|
||||||
#[error("Agent execution cancelled by user (Ctrl+C)")]
|
|
||||||
Cancelled,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Core executor handling the ReAct loop.
|
/// Result of an agent execution
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AgentResult {
|
||||||
|
/// Final answer from the agent
|
||||||
|
pub answer: String,
|
||||||
|
/// Number of iterations taken
|
||||||
|
pub iterations: usize,
|
||||||
|
/// All messages exchanged during execution
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
/// Whether the agent completed successfully
|
||||||
|
pub success: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for agent execution
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AgentConfig {
|
||||||
|
/// Maximum number of iterations
|
||||||
|
pub max_iterations: usize,
|
||||||
|
/// Model to use for reasoning
|
||||||
|
pub model: String,
|
||||||
|
/// Temperature for LLM sampling
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
/// Max tokens per LLM call
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AgentConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_iterations: DEFAULT_MAX_ITERATIONS,
|
||||||
|
model: "llama3.2:latest".to_string(),
|
||||||
|
temperature: Some(0.7),
|
||||||
|
max_tokens: Some(4096),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent executor that orchestrates the ReAct loop
|
||||||
pub struct AgentExecutor {
|
pub struct AgentExecutor {
|
||||||
llm_client: Arc<dyn Provider + Send + Sync>,
|
/// LLM provider for reasoning
|
||||||
tool_client: Arc<dyn McpClient + Send + Sync>,
|
llm_client: Arc<dyn Provider>,
|
||||||
|
/// MCP client for tool execution
|
||||||
|
tool_client: Arc<dyn McpClient>,
|
||||||
|
/// Agent configuration
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
ui_controller: Option<Arc<dyn UiController + Send + Sync>>, // optional UI for confirmations
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentExecutor {
|
impl AgentExecutor {
|
||||||
/// Construct a new executor.
|
/// Create a new agent executor
|
||||||
pub fn new(
|
pub fn new(
|
||||||
llm_client: Arc<dyn Provider + Send + Sync>,
|
llm_client: Arc<dyn Provider>,
|
||||||
tool_client: Arc<dyn McpClient + Send + Sync>,
|
tool_client: Arc<dyn McpClient>,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
ui_controller: Option<Arc<dyn UiController + Send + Sync>>, // pass None for headless
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
llm_client,
|
llm_client,
|
||||||
tool_client,
|
tool_client,
|
||||||
config,
|
config,
|
||||||
ui_controller,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Discover tools exposed by the MCP server.
|
/// Run the agent loop with the given query
|
||||||
async fn discover_tools(&self) -> CoreResult<Vec<McpToolDescriptor>> {
|
pub async fn run(&self, query: String) -> Result<AgentResult> {
|
||||||
self.tool_client.list_tools().await
|
let mut messages = vec![Message::user(query)];
|
||||||
}
|
let tools = self.discover_tools().await?;
|
||||||
|
|
||||||
// #[allow(dead_code)]
|
for iteration in 0..self.config.max_iterations {
|
||||||
// Build a ReAct prompt from the current message history and discovered tools.
|
let prompt = self.build_react_prompt(&messages, &tools);
|
||||||
/*
|
let response = self.generate_llm_response(prompt).await?;
|
||||||
#[allow(dead_code)]
|
|
||||||
fn build_prompt(
|
|
||||||
&self,
|
|
||||||
history: &[Message],
|
|
||||||
tools: &[McpToolDescriptor],
|
|
||||||
) -> String {
|
|
||||||
// System prompt describing the format.
|
|
||||||
let system = "You are an intelligent agent following the ReAct pattern. Use the following sections:\nTHOUGHT: your reasoning\nACTION: the tool name you want to call (or "final_answer")\nACTION_INPUT: JSON arguments for the tool.\nIf ACTION is "final_answer", provide the final answer in the next line after the ACTION_INPUT.\n";
|
|
||||||
|
|
||||||
let mut prompt = format!("System: {}\n", system);
|
match self.parse_response(&response)? {
|
||||||
// Append conversation history.
|
|
||||||
for msg in history {
|
|
||||||
let role = match msg.role {
|
|
||||||
Role::User => "User",
|
|
||||||
Role::Assistant => "Assistant",
|
|
||||||
Role::System => "System",
|
|
||||||
Role::Tool => "Tool",
|
|
||||||
};
|
|
||||||
prompt.push_str(&format!("{}: {}\n", role, msg.content));
|
|
||||||
}
|
|
||||||
// Append tool descriptions.
|
|
||||||
if !tools.is_empty() {
|
|
||||||
let tools_json = json!(tools);
|
|
||||||
prompt.push_str(&format!("Available tools (JSON schema): {}\n", tools_json));
|
|
||||||
}
|
|
||||||
prompt
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// build_prompt removed; not used in current implementation
|
|
||||||
|
|
||||||
/// Parse raw LLM text into a structured `LlmResponse`.
|
|
||||||
pub fn parse_response(&self, text: &str) -> std::result::Result<LlmResponse, AgentError> {
|
|
||||||
// Normalise line endings.
|
|
||||||
let txt = text.trim();
|
|
||||||
// Regex patterns for parsing ReAct format.
|
|
||||||
// THOUGHT and ACTION capture up to the next newline.
|
|
||||||
// ACTION_INPUT captures everything remaining (including multiline JSON).
|
|
||||||
let thought_re = Regex::new(r"(?s)THOUGHT:\s*(?P<thought>.+?)(?:\n|$)").unwrap();
|
|
||||||
let action_re = Regex::new(r"(?s)ACTION:\s*(?P<action>.+?)(?:\n|$)").unwrap();
|
|
||||||
// ACTION_INPUT captures rest of text (multiline-friendly)
|
|
||||||
let input_re = Regex::new(r"(?s)ACTION_INPUT:\s*(?P<input>.+)").unwrap();
|
|
||||||
|
|
||||||
let thought = thought_re
|
|
||||||
.captures(txt)
|
|
||||||
.and_then(|c| c.name("thought"))
|
|
||||||
.map(|m| m.as_str().trim().to_string())
|
|
||||||
.ok_or(AgentError::Parse)?;
|
|
||||||
let action = action_re
|
|
||||||
.captures(txt)
|
|
||||||
.and_then(|c| c.name("action"))
|
|
||||||
.map(|m| m.as_str().trim().to_string())
|
|
||||||
.ok_or(AgentError::Parse)?;
|
|
||||||
let input = input_re
|
|
||||||
.captures(txt)
|
|
||||||
.and_then(|c| c.name("input"))
|
|
||||||
.map(|m| m.as_str().trim().to_string())
|
|
||||||
.ok_or(AgentError::Parse)?;
|
|
||||||
|
|
||||||
if action.eq_ignore_ascii_case("final_answer") {
|
|
||||||
Ok(LlmResponse::FinalAnswer {
|
|
||||||
thought,
|
|
||||||
answer: input,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// Parse arguments as JSON, falling back to a string if invalid.
|
|
||||||
let args = serde_json::from_str(&input).unwrap_or_else(|_| json!(input));
|
|
||||||
Ok(LlmResponse::ToolCall {
|
|
||||||
thought,
|
|
||||||
tool_name: action,
|
|
||||||
arguments: args,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute a single tool call via the MCP client.
|
|
||||||
async fn execute_tool(
|
|
||||||
&self,
|
|
||||||
name: &str,
|
|
||||||
arguments: serde_json::Value,
|
|
||||||
) -> CoreResult<McpToolResponse> {
|
|
||||||
// For potentially unsafe tools (write/delete) ask for UI confirmation
|
|
||||||
// if a controller is available.
|
|
||||||
let dangerous = name.contains("write") || name.contains("delete");
|
|
||||||
if dangerous {
|
|
||||||
if let Some(controller) = &self.ui_controller {
|
|
||||||
let prompt = format!(
|
|
||||||
"Confirm execution of potentially unsafe tool '{}' with args {}?",
|
|
||||||
name, arguments
|
|
||||||
);
|
|
||||||
if !controller.confirm(&prompt).await {
|
|
||||||
return Err(Error::PermissionDenied(format!(
|
|
||||||
"Tool '{}' denied by user",
|
|
||||||
name
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let call = McpToolCall {
|
|
||||||
name: name.to_string(),
|
|
||||||
arguments,
|
|
||||||
};
|
|
||||||
self.tool_client.call_tool(call).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the full ReAct loop and return the final answer.
|
|
||||||
pub async fn run(&self, query: String) -> std::result::Result<String, AgentError> {
|
|
||||||
let tools = self.discover_tools().await.map_err(AgentError::Mcp)?;
|
|
||||||
|
|
||||||
// Build system prompt with ReAct format instructions
|
|
||||||
let tools_desc = tools
|
|
||||||
.iter()
|
|
||||||
.map(|t| {
|
|
||||||
let schema_str = serde_json::to_string_pretty(&t.input_schema)
|
|
||||||
.unwrap_or_else(|_| "{}".to_string());
|
|
||||||
format!(
|
|
||||||
"- {}: {}\n Input schema: {}",
|
|
||||||
t.name, t.description, schema_str
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
let system_prompt = format!(
|
|
||||||
"You are an AI assistant that uses the ReAct (Reasoning + Acting) pattern to solve tasks.\n\n\
|
|
||||||
You must ALWAYS respond in this exact format:\n\n\
|
|
||||||
THOUGHT: <your reasoning about what to do next>\n\
|
|
||||||
ACTION: <tool_name or \"final_answer\">\n\
|
|
||||||
ACTION_INPUT: <JSON arguments for the tool, or the final answer text>\n\n\
|
|
||||||
Available tools:\n{}\n\n\
|
|
||||||
HOW IT WORKS:\n\
|
|
||||||
1. When you call a tool, you will receive its output in the next message\n\
|
|
||||||
2. After receiving the tool output, analyze it and either:\n\
|
|
||||||
a) Use the information to provide a final answer\n\
|
|
||||||
b) Call another tool if you need more information\n\
|
|
||||||
3. When you have the information needed to answer the user's question, provide a final answer\n\n\
|
|
||||||
To provide a final answer:\n\
|
|
||||||
THOUGHT: <summary of what you learned>\n\
|
|
||||||
ACTION: final_answer\n\
|
|
||||||
ACTION_INPUT: <your complete answer using the information from the tools>\n\n\
|
|
||||||
IMPORTANT: You MUST follow this format exactly. Do not deviate from it.\n\
|
|
||||||
IMPORTANT: Only use the tools listed above. Do not try to use tools that are not listed.\n\
|
|
||||||
IMPORTANT: When providing the final answer, include the actual information you learned, not just the tool arguments.",
|
|
||||||
tools_desc
|
|
||||||
);
|
|
||||||
|
|
||||||
// Initialize conversation with system prompt and user query
|
|
||||||
let mut messages = vec![Message::system(system_prompt.clone()), Message::user(query)];
|
|
||||||
|
|
||||||
// Cancellation flag set when Ctrl+C is received.
|
|
||||||
let cancelled = Arc::new(AtomicBool::new(false));
|
|
||||||
let cancel_flag = cancelled.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// Wait for Ctrl+C signal.
|
|
||||||
let _ = signal::ctrl_c().await;
|
|
||||||
cancel_flag.store(true, Ordering::SeqCst);
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut tool_calls = 0usize;
|
|
||||||
for _ in 0..self.config.max_iterations {
|
|
||||||
if cancelled.load(Ordering::SeqCst) {
|
|
||||||
return Err(AgentError::Cancelled);
|
|
||||||
}
|
|
||||||
// Build a ChatRequest for the provider.
|
|
||||||
let chat_req = ChatRequest {
|
|
||||||
model: self.config.model.clone(),
|
|
||||||
messages: messages.clone(),
|
|
||||||
parameters: crate::types::ChatParameters {
|
|
||||||
temperature: self.config.temperature,
|
|
||||||
max_tokens: self.config.max_tokens,
|
|
||||||
stream: false,
|
|
||||||
extra: Default::default(),
|
|
||||||
},
|
|
||||||
tools: Some(tools.clone()),
|
|
||||||
};
|
|
||||||
let raw_resp = self
|
|
||||||
.llm_client
|
|
||||||
.chat(chat_req)
|
|
||||||
.await
|
|
||||||
.map_err(AgentError::Provider)?;
|
|
||||||
let parsed = self
|
|
||||||
.parse_response(&raw_resp.message.content)
|
|
||||||
.map_err(|e| {
|
|
||||||
eprintln!("\n=== PARSE ERROR ===");
|
|
||||||
eprintln!("Error: {:?}", e);
|
|
||||||
eprintln!("LLM Response:\n{}", raw_resp.message.content);
|
|
||||||
eprintln!("=== END ===\n");
|
|
||||||
e
|
|
||||||
})?;
|
|
||||||
match parsed {
|
|
||||||
LlmResponse::Reasoning { thought } => {
|
|
||||||
// Append the reasoning as an assistant message.
|
|
||||||
messages.push(Message::assistant(thought));
|
|
||||||
}
|
|
||||||
LlmResponse::ToolCall {
|
LlmResponse::ToolCall {
|
||||||
thought,
|
thought,
|
||||||
tool_name,
|
tool_name,
|
||||||
arguments,
|
arguments,
|
||||||
} => {
|
} => {
|
||||||
// Record the thought.
|
// Add assistant's reasoning
|
||||||
messages.push(Message::assistant(thought));
|
messages.push(Message::assistant(format!(
|
||||||
// Enforce tool call budget.
|
"THOUGHT: {}\nACTION: {}\nACTION_INPUT: {}",
|
||||||
tool_calls += 1;
|
thought,
|
||||||
if tool_calls > self.config.max_tool_calls {
|
tool_name,
|
||||||
return Err(AgentError::MaxIterationsReached(self.config.max_iterations));
|
serde_json::to_string_pretty(&arguments).unwrap_or_default()
|
||||||
}
|
)));
|
||||||
// Execute tool.
|
|
||||||
let args_clone = arguments.clone();
|
// Execute the tool
|
||||||
let tool_resp = self
|
let result = self.execute_tool(&tool_name, arguments).await?;
|
||||||
.execute_tool(&tool_name, args_clone.clone())
|
|
||||||
.await
|
// Add observation
|
||||||
.map_err(AgentError::Mcp)?;
|
messages.push(Message::tool(
|
||||||
// Convert tool output to a string for the message.
|
tool_name.clone(),
|
||||||
let output_str = tool_resp
|
format!(
|
||||||
.output
|
"OBSERVATION: {}",
|
||||||
.as_str()
|
serde_json::to_string_pretty(&result.output).unwrap_or_default()
|
||||||
.map(|s| s.to_string())
|
),
|
||||||
.unwrap_or_else(|| tool_resp.output.to_string());
|
));
|
||||||
// Audit log the tool execution.
|
|
||||||
if let Some(config_dir) = dirs::config_dir() {
|
|
||||||
let log_path = config_dir.join("owlen/logs/tool_execution.log");
|
|
||||||
if let Some(parent) = log_path.parent() {
|
|
||||||
let _ = std::fs::create_dir_all(parent);
|
|
||||||
}
|
|
||||||
if let Ok(mut file) =
|
|
||||||
OpenOptions::new().create(true).append(true).open(&log_path)
|
|
||||||
{
|
|
||||||
let ts = SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs();
|
|
||||||
let _ = writeln!(
|
|
||||||
file,
|
|
||||||
"{} | tool: {} | args: {} | output: {}",
|
|
||||||
ts, tool_name, args_clone, output_str
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messages.push(Message::tool(tool_name, output_str));
|
|
||||||
}
|
}
|
||||||
LlmResponse::FinalAnswer { thought, answer } => {
|
LlmResponse::FinalAnswer { thought, answer } => {
|
||||||
// Append final thought and answer, then return.
|
messages.push(Message::assistant(format!(
|
||||||
messages.push(Message::assistant(thought));
|
"THOUGHT: {}\nFINAL_ANSWER: {}",
|
||||||
// The final answer should be a single assistant message.
|
thought, answer
|
||||||
messages.push(Message::assistant(answer.clone()));
|
)));
|
||||||
return Ok(answer);
|
return Ok(AgentResult {
|
||||||
|
answer,
|
||||||
|
iterations: iteration + 1,
|
||||||
|
messages,
|
||||||
|
success: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
LlmResponse::Reasoning { thought } => {
|
||||||
|
messages.push(Message::assistant(format!("THOUGHT: {}", thought)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(AgentError::MaxIterationsReached(self.config.max_iterations))
|
|
||||||
|
// Max iterations reached
|
||||||
|
Ok(AgentResult {
|
||||||
|
answer: "Maximum iterations reached without finding a final answer".to_string(),
|
||||||
|
iterations: self.config.max_iterations,
|
||||||
|
messages,
|
||||||
|
success: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Discover available tools from the MCP client
|
||||||
|
async fn discover_tools(&self) -> Result<Vec<McpToolDescriptor>> {
|
||||||
|
self.tool_client.list_tools().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a ReAct-formatted prompt with available tools
|
||||||
|
fn build_react_prompt(
|
||||||
|
&self,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: &[McpToolDescriptor],
|
||||||
|
) -> Vec<Message> {
|
||||||
|
let mut prompt_messages = Vec::new();
|
||||||
|
|
||||||
|
// System prompt with ReAct instructions
|
||||||
|
let system_prompt = self.build_system_prompt(tools);
|
||||||
|
prompt_messages.push(Message::system(system_prompt));
|
||||||
|
|
||||||
|
// Add conversation history
|
||||||
|
prompt_messages.extend_from_slice(messages);
|
||||||
|
|
||||||
|
prompt_messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the system prompt with ReAct format and tool descriptions
|
||||||
|
fn build_system_prompt(&self, tools: &[McpToolDescriptor]) -> String {
|
||||||
|
let mut prompt = String::from(
|
||||||
|
"You are an AI assistant that uses the ReAct (Reasoning and Acting) pattern to solve tasks.\n\n\
|
||||||
|
You have access to the following tools:\n\n",
|
||||||
|
);
|
||||||
|
|
||||||
|
for tool in tools {
|
||||||
|
prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.push_str(
|
||||||
|
"\nUse the following format:\n\n\
|
||||||
|
THOUGHT: Your reasoning about what to do next\n\
|
||||||
|
ACTION: tool_name\n\
|
||||||
|
ACTION_INPUT: {\"param\": \"value\"}\n\n\
|
||||||
|
You will receive:\n\
|
||||||
|
OBSERVATION: The result of the tool execution\n\n\
|
||||||
|
Continue this process until you have enough information, then provide:\n\
|
||||||
|
THOUGHT: Final reasoning\n\
|
||||||
|
FINAL_ANSWER: Your comprehensive answer\n\n\
|
||||||
|
Important:\n\
|
||||||
|
- Always start with THOUGHT to explain your reasoning\n\
|
||||||
|
- ACTION must be one of the available tools\n\
|
||||||
|
- ACTION_INPUT must be valid JSON\n\
|
||||||
|
- Use FINAL_ANSWER only when you have sufficient information\n",
|
||||||
|
);
|
||||||
|
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate an LLM response
|
||||||
|
async fn generate_llm_response(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages,
|
||||||
|
parameters: ChatParameters {
|
||||||
|
temperature: self.config.temperature,
|
||||||
|
max_tokens: self.config.max_tokens,
|
||||||
|
stream: false,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self.llm_client.send_prompt(request).await?;
|
||||||
|
Ok(response.message.content)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse LLM response into structured format
|
||||||
|
pub fn parse_response(&self, text: &str) -> Result<LlmResponse> {
|
||||||
|
let lines: Vec<&str> = text.lines().collect();
|
||||||
|
let mut thought = String::new();
|
||||||
|
let mut action = String::new();
|
||||||
|
let mut action_input = String::new();
|
||||||
|
let mut final_answer = String::new();
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < lines.len() {
|
||||||
|
let line = lines[i].trim();
|
||||||
|
|
||||||
|
if line.starts_with("THOUGHT:") {
|
||||||
|
thought = line
|
||||||
|
.strip_prefix("THOUGHT:")
|
||||||
|
.unwrap_or("")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
// Collect multi-line thoughts
|
||||||
|
i += 1;
|
||||||
|
while i < lines.len()
|
||||||
|
&& !lines[i].trim().starts_with("ACTION")
|
||||||
|
&& !lines[i].trim().starts_with("FINAL_ANSWER")
|
||||||
|
{
|
||||||
|
if !lines[i].trim().is_empty() {
|
||||||
|
thought.push(' ');
|
||||||
|
thought.push_str(lines[i].trim());
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("ACTION:") {
|
||||||
|
action = line
|
||||||
|
.strip_prefix("ACTION:")
|
||||||
|
.unwrap_or("")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
i += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("ACTION_INPUT:") {
|
||||||
|
action_input = line
|
||||||
|
.strip_prefix("ACTION_INPUT:")
|
||||||
|
.unwrap_or("")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
// Collect multi-line JSON
|
||||||
|
i += 1;
|
||||||
|
while i < lines.len()
|
||||||
|
&& !lines[i].trim().starts_with("THOUGHT")
|
||||||
|
&& !lines[i].trim().starts_with("ACTION")
|
||||||
|
{
|
||||||
|
action_input.push(' ');
|
||||||
|
action_input.push_str(lines[i].trim());
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("FINAL_ANSWER:") {
|
||||||
|
final_answer = line
|
||||||
|
.strip_prefix("FINAL_ANSWER:")
|
||||||
|
.unwrap_or("")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
// Collect multi-line answer
|
||||||
|
i += 1;
|
||||||
|
while i < lines.len() {
|
||||||
|
if !lines[i].trim().is_empty() {
|
||||||
|
final_answer.push(' ');
|
||||||
|
final_answer.push_str(lines[i].trim());
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine response type
|
||||||
|
if !final_answer.is_empty() {
|
||||||
|
return Ok(LlmResponse::FinalAnswer {
|
||||||
|
thought,
|
||||||
|
answer: final_answer,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !action.is_empty() {
|
||||||
|
let arguments = if action_input.is_empty() {
|
||||||
|
serde_json::json!({})
|
||||||
|
} else {
|
||||||
|
serde_json::from_str(&action_input)
|
||||||
|
.map_err(|e| Error::Agent(ParseError::InvalidJson(e.to_string()).to_string()))?
|
||||||
|
};
|
||||||
|
|
||||||
|
return Ok(LlmResponse::ToolCall {
|
||||||
|
thought,
|
||||||
|
tool_name: action,
|
||||||
|
arguments,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !thought.is_empty() {
|
||||||
|
return Ok(LlmResponse::Reasoning { thought });
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::Agent(ParseError::NoPattern.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a tool call
|
||||||
|
async fn execute_tool(
|
||||||
|
&self,
|
||||||
|
tool_name: &str,
|
||||||
|
arguments: serde_json::Value,
|
||||||
|
) -> Result<McpToolResponse> {
|
||||||
|
let call = McpToolCall {
|
||||||
|
name: tool_name.to_string(),
|
||||||
|
arguments,
|
||||||
|
};
|
||||||
|
self.tool_client.call_tool(call).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm::test_utils::MockProvider;
|
||||||
|
use crate::mcp::test_utils::MockMcpClient;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_tool_call() {
|
||||||
|
let executor = AgentExecutor {
|
||||||
|
llm_client: Arc::new(MockProvider::default()),
|
||||||
|
tool_client: Arc::new(MockMcpClient),
|
||||||
|
config: AgentConfig::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = r#"
|
||||||
|
THOUGHT: I need to search for information about Rust
|
||||||
|
ACTION: web_search
|
||||||
|
ACTION_INPUT: {"query": "Rust programming language"}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let result = executor.parse_response(text).unwrap();
|
||||||
|
match result {
|
||||||
|
LlmResponse::ToolCall {
|
||||||
|
thought,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
} => {
|
||||||
|
assert!(thought.contains("search for information"));
|
||||||
|
assert_eq!(tool_name, "web_search");
|
||||||
|
assert_eq!(arguments["query"], "Rust programming language");
|
||||||
|
}
|
||||||
|
_ => panic!("Expected ToolCall"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_final_answer() {
|
||||||
|
let executor = AgentExecutor {
|
||||||
|
llm_client: Arc::new(MockProvider::default()),
|
||||||
|
tool_client: Arc::new(MockMcpClient),
|
||||||
|
config: AgentConfig::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = r#"
|
||||||
|
THOUGHT: I now have enough information to answer
|
||||||
|
FINAL_ANSWER: Rust is a systems programming language focused on safety and performance.
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let result = executor.parse_response(text).unwrap();
|
||||||
|
match result {
|
||||||
|
LlmResponse::FinalAnswer { thought, answer } => {
|
||||||
|
assert!(thought.contains("enough information"));
|
||||||
|
assert!(answer.contains("Rust is a systems programming language"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected FinalAnswer"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
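The rewrite boils the executor down to three pieces: a Provider for reasoning, an McpClient for tool execution, and an AgentConfig, with run() returning an AgentResult instead of a bare string. A usage sketch against that new surface follows; the module path owlen_core::agent and the calling context are assumptions, while the constructor arity, the run() signature, and the AgentResult fields come from the diff above.

use std::sync::Arc;

use owlen_core::agent::{AgentConfig, AgentExecutor}; // module path assumed
use owlen_core::mcp::McpClient;
use owlen_core::{Provider, Result};

async fn answer(
    question: &str,
    provider: Arc<dyn Provider>,
    tools: Arc<dyn McpClient>,
) -> Result<String> {
    // The new constructor takes exactly three arguments; the old UI-controller
    // parameter is gone, as the updated tests above show.
    let executor = AgentExecutor::new(provider, tools, AgentConfig::default());

    // run() drives the ReAct loop (THOUGHT -> ACTION -> OBSERVATION) until the
    // model emits FINAL_ANSWER or max_iterations is exhausted.
    let result = executor.run(question.to_string()).await?;
    if !result.success {
        eprintln!("agent gave up after {} iterations", result.iterations);
    }
    Ok(result.answer)
}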
File diff suppressed because it is too large
@@ -58,17 +58,21 @@ impl ConsentManager {
     /// Load consent records from vault storage
     pub fn from_vault(vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Self {
         let guard = vault.lock().expect("Vault mutex poisoned");
-        if let Some(consent_data) = guard.settings().get("consent_records") {
-            if let Ok(permanent_records) =
-                serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
-            {
-                return Self {
-                    permanent_records,
-                    session_records: HashMap::new(),
-                    once_records: HashMap::new(),
-                    pending_requests: HashMap::new(),
-                };
-            }
+        if let Some(permanent_records) =
+            guard
+                .settings()
+                .get("consent_records")
+                .and_then(|consent_data| {
+                    serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
+                        .ok()
+                })
+        {
+            return Self {
+                permanent_records,
+                session_records: HashMap::new(),
+                once_records: HashMap::new(),
+                pending_requests: HashMap::new(),
+            };
         }
         Self::default()
     }
@@ -91,17 +95,21 @@ impl ConsentManager {
         endpoints: Vec<String>,
     ) -> Result<ConsentScope> {
         // Check if already granted permanently
-        if let Some(existing) = self.permanent_records.get(tool_name) {
-            if existing.scope == ConsentScope::Permanent {
-                return Ok(ConsentScope::Permanent);
-            }
+        if self
+            .permanent_records
+            .get(tool_name)
+            .is_some_and(|existing| existing.scope == ConsentScope::Permanent)
+        {
+            return Ok(ConsentScope::Permanent);
         }

         // Check if granted for session
-        if let Some(existing) = self.session_records.get(tool_name) {
-            if existing.scope == ConsentScope::Session {
-                return Ok(ConsentScope::Session);
-            }
+        if self
+            .session_records
+            .get(tool_name)
+            .is_some_and(|existing| existing.scope == ConsentScope::Session)
+        {
+            return Ok(ConsentScope::Session);
         }

         // Check if request is already pending (prevent duplicate prompts)
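Both consent checks now collapse a nested if let plus inner if into a single Option::is_some_and call, the pattern the crate-level clippy::collapsible_if allowance (added in lib.rs below) is meant to phase out once let-chains are usable. A standalone sketch of the same shape, with a toy record type standing in for ConsentRecord:

use std::collections::HashMap;

#[derive(PartialEq)]
enum Scope {
    Once,
    Session,
    Permanent,
}

struct Record {
    scope: Scope,
}

fn already_permanent(records: &HashMap<String, Record>, tool: &str) -> bool {
    // Equivalent to the old nested form:
    // if let Some(r) = records.get(tool) { if r.scope == Scope::Permanent { return true; } }
    records
        .get(tool)
        .is_some_and(|existing| existing.scope == Scope::Permanent)
}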
@@ -1,6 +1,6 @@
+use crate::Result;
 use crate::storage::StorageManager;
 use crate::types::{Conversation, Message};
-use crate::Result;
 use serde_json::{Number, Value};
 use std::collections::{HashMap, VecDeque};
 use std::time::{Duration, Instant};
@@ -213,6 +213,34 @@ impl ConversationManager {
         Ok(())
     }

+    pub fn cancel_stream(&mut self, message_id: Uuid, notice: impl Into<String>) -> Result<()> {
+        let index = self
+            .message_index
+            .get(&message_id)
+            .copied()
+            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
+
+        if let Some(message) = self.active_mut().messages.get_mut(index) {
+            message.content = notice.into();
+            message.timestamp = std::time::SystemTime::now();
+            message
+                .metadata
+                .insert(STREAMING_FLAG.to_string(), Value::Bool(false));
+            message.metadata.remove(PLACEHOLDER_FLAG);
+            let millis = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_millis() as u64;
+            message.metadata.insert(
+                LAST_CHUNK_TS.to_string(),
+                Value::Number(Number::from(millis)),
+            );
+        }
+
+        self.streaming.remove(&message_id);
+        Ok(())
+    }
+
     /// Set tool calls on a streaming message
     pub fn set_tool_calls_on_message(
         &mut self,
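The new cancel_stream method gives callers a one-call way to turn a half-streamed message into a cancellation notice. A hypothetical caller might look like the following; the function name and the surrounding UI wiring are assumptions, and only the cancel_stream signature comes from the hunk above.

use owlen_core::conversation::ConversationManager; // path assumed from `pub mod conversation`
use uuid::Uuid;

fn on_user_cancel(manager: &mut ConversationManager, streaming_id: Uuid) {
    // Replaces the half-streamed body with a notice, clears the streaming and
    // placeholder flags, and stamps the last-chunk timestamp, per the hunk above.
    if let Err(err) = manager.cancel_stream(streaming_id, "Generation cancelled by user") {
        eprintln!("failed to cancel stream: {err}");
    }
}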
@@ -2,7 +2,7 @@ use std::sync::Arc;

 use serde::{Deserialize, Serialize};

-use crate::{storage::StorageManager, Error, Result};
+use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};

 #[derive(Serialize, Deserialize, Debug)]
 pub struct ApiCredentials {
@@ -10,6 +10,8 @@ pub struct ApiCredentials {
     pub endpoint: String,
 }

+pub const OLLAMA_CLOUD_CREDENTIAL_ID: &str = "provider_ollama_cloud";
+
 pub struct CredentialManager {
     storage: Arc<StorageManager>,
     master_key: Arc<Vec<u8>>,
@@ -29,6 +31,10 @@ impl CredentialManager {
         format!("{}_{}", self.namespace, tool_name)
     }

+    fn oauth_storage_key(&self, resource: &str) -> String {
+        self.namespaced_key(&format!("oauth_{resource}"))
+    }
+
     pub async fn store_credentials(
         &self,
         tool_name: &str,
@@ -66,4 +72,37 @@ impl CredentialManager {
         let key = self.namespaced_key(tool_name);
         self.storage.delete_secure_item(&key).await
     }
+
+    pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
+        let key = self.oauth_storage_key(resource);
+        let payload = serde_json::to_vec(token).map_err(|err| {
+            Error::Storage(format!(
+                "Failed to serialize OAuth token for secure storage: {err}"
+            ))
+        })?;
+        self.storage
+            .store_secure_item(&key, &payload, &self.master_key)
+            .await
+    }
+
+    pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
+        let key = self.oauth_storage_key(resource);
+        let raw = self
+            .storage
+            .load_secure_item(&key, &self.master_key)
+            .await?;
+        if let Some(bytes) = raw {
+            let token = serde_json::from_slice(&bytes).map_err(|err| {
+                Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
+            })?;
+            Ok(Some(token))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
+        let key = self.oauth_storage_key(resource);
+        self.storage.delete_secure_item(&key).await
+    }
 }
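The added methods give CredentialManager a full OAuth token round trip on top of the existing secure-item storage. A sketch of how a caller might persist and reload a token follows; obtaining a CredentialManager and an OAuthToken is assumed to happen elsewhere, and only the three *_oauth_token methods and the OLLAMA_CLOUD_CREDENTIAL_ID constant are taken from the diff.

use owlen_core::Result;
use owlen_core::credentials::{CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
use owlen_core::oauth::OAuthToken;

async fn refresh_cloud_token(manager: &CredentialManager, fresh: OAuthToken) -> Result<()> {
    // Persist the new token under the provider-specific resource key.
    manager
        .store_oauth_token(OLLAMA_CLOUD_CREDENTIAL_ID, &fresh)
        .await?;

    // load_oauth_token returns Ok(None) when nothing has been stored yet.
    if let Some(stored) = manager.load_oauth_token(OLLAMA_CLOUD_CREDENTIAL_ID).await? {
        // Use `stored` to authenticate requests; dropping it here keeps the sketch short.
        drop(stored);
    }
    Ok(())
}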
@@ -3,10 +3,10 @@ use std::fs;
 use std::path::PathBuf;

 use aes_gcm::{
-    aead::{Aead, KeyInit},
     Aes256Gcm, Nonce,
+    aead::{Aead, KeyInit},
 };
-use anyhow::{bail, Context, Result};
+use anyhow::{Context, Result, bail};
 use ring::digest;
 use ring::rand::{SecureRandom, SystemRandom};
 use serde::{Deserialize, Serialize};
32  crates/owlen-core/src/facade/llm_client.rs  Normal file
@@ -0,0 +1,32 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use crate::{
+    Result,
+    llm::ChatStream,
+    mcp::{McpToolCall, McpToolDescriptor, McpToolResponse},
+    types::{ChatRequest, ChatResponse, ModelInfo},
+};
+
+/// Object-safe facade for interacting with LLM backends.
+#[async_trait]
+pub trait LlmClient: Send + Sync {
+    /// List the models exposed by this client.
+    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
+
+    /// Issue a one-shot chat request and wait for the complete response.
+    async fn send_chat(&self, request: ChatRequest) -> Result<ChatResponse>;
+
+    /// Stream chat responses incrementally.
+    async fn stream_chat(&self, request: ChatRequest) -> Result<ChatStream>;
+
+    /// Enumerate tools exposed by the backing provider.
+    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>>;
+
+    /// Invoke a tool exposed by the provider.
+    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;
+}
+
+/// Convenience alias for trait-object clients.
+pub type DynLlmClient = Arc<dyn LlmClient>;

1  crates/owlen-core/src/facade/mod.rs  Normal file
@@ -0,0 +1 @@
+pub mod llm_client;
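With the facade in place, call sites can hold a DynLlmClient and stay agnostic about the concrete backend. A minimal sketch of a one-shot helper follows; the crate::types re-export path and leaning on ChatParameters::default() are assumptions, while the ChatRequest fields and the send_chat signature mirror the code above.

use owlen_core::types::{ChatRequest, Message}; // types path assumed to be public
use owlen_core::{DynLlmClient, Result};

async fn one_shot(client: DynLlmClient, prompt: &str) -> Result<String> {
    let request = ChatRequest {
        model: "llama3.2:latest".to_string(),
        messages: vec![Message::user(prompt.to_string())],
        // ChatParameters appears with ..Default::default() elsewhere in the diff,
        // so a default-constructed parameter block is assumed to be valid here.
        parameters: Default::default(),
        tools: None,
    };
    // send_chat waits for the full response; stream_chat would yield chunks instead.
    Ok(client.send_chat(request).await?.message.content)
}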
@@ -1,19 +1,20 @@
 use crate::types::Message;
+use crate::ui::RoleLabelDisplay;

 /// Formats messages for display across different clients.
 #[derive(Debug, Clone)]
 pub struct MessageFormatter {
     wrap_width: usize,
-    show_role_labels: bool,
+    role_label_mode: RoleLabelDisplay,
     preserve_empty_lines: bool,
 }

 impl MessageFormatter {
     /// Create a new formatter
-    pub fn new(wrap_width: usize, show_role_labels: bool) -> Self {
+    pub fn new(wrap_width: usize, role_label_mode: RoleLabelDisplay) -> Self {
         Self {
             wrap_width: wrap_width.max(20),
-            show_role_labels,
+            role_label_mode,
             preserve_empty_lines: false,
         }
     }
@@ -29,9 +30,19 @@ impl MessageFormatter {
         self.wrap_width = width.max(20);
     }

-    /// Whether role labels should be shown alongside messages
+    /// The configured role label layout preference.
+    pub fn role_label_mode(&self) -> RoleLabelDisplay {
+        self.role_label_mode
+    }
+
+    /// Whether any role label should be shown alongside messages.
     pub fn show_role_labels(&self) -> bool {
-        self.show_role_labels
+        !matches!(self.role_label_mode, RoleLabelDisplay::None)
+    }
+
+    /// Update the role label layout preference.
+    pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) {
+        self.role_label_mode = mode;
     }

     pub fn format_message(&self, message: &Message) -> Vec<String> {
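MessageFormatter::new now takes a RoleLabelDisplay instead of a bool, so existing call sites need a small mechanical change. A sketch, assuming crate::ui is publicly reachable; RoleLabelDisplay::None is the only variant visible in this hunk and stands in for the old show_role_labels = false.

use owlen_core::formatting::MessageFormatter;
use owlen_core::ui::RoleLabelDisplay; // module path assumed to be public

fn plain_formatter() -> MessageFormatter {
    // The second argument is now a layout preference rather than a bool;
    // RoleLabelDisplay::None reproduces the old show_role_labels = false.
    let formatter = MessageFormatter::new(80, RoleLabelDisplay::None);
    debug_assert!(!formatter.show_role_labels());
    formatter
}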
@@ -191,6 +191,12 @@ impl InputBuffer {
             self.history.pop_back();
         }
     }
+
+    /// Clear saved input history entries.
+    pub fn clear_history(&mut self) {
+        self.history.clear();
+        self.history_index = None;
+    }
 }

 fn prev_char_boundary(buffer: &str, cursor: usize) -> usize {
@@ -1,3 +1,5 @@
+#![allow(clippy::collapsible_if)] // TODO: Remove once we can rely on Rust 2024 let-chains
+
 //! Core traits and types for OWLEN LLM client
 //!
 //! This crate provides the foundational abstractions for building
@@ -9,14 +11,20 @@ pub mod consent;
 pub mod conversation;
 pub mod credentials;
 pub mod encryption;
+pub mod facade;
 pub mod formatting;
 pub mod input;
+pub mod llm;
 pub mod mcp;
+pub mod mode;
 pub mod model;
+pub mod oauth;
 pub mod provider;
+pub mod providers;
 pub mod router;
 pub mod sandbox;
 pub mod session;
+pub mod state;
 pub mod storage;
 pub mod theme;
 pub mod tools;
@@ -33,12 +41,24 @@ pub use credentials::*;
 pub use encryption::*;
 pub use formatting::*;
 pub use input::*;
-pub use mcp::*;
+pub use oauth::*;
+// Export MCP types but exclude test_utils to avoid ambiguity
+pub use facade::llm_client::*;
+pub use llm::{
+    ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,
+};
+pub use mcp::{
+    LocalMcpClient, McpServer, McpToolCall, McpToolDescriptor, McpToolResponse, client, factory,
+    failover, permission, protocol, remote_client,
+};
+pub use mode::*;
 pub use model::*;
 pub use provider::*;
+pub use providers::*;
 pub use router::*;
 pub use sandbox::*;
 pub use session::*;
+pub use state::*;
 pub use theme::*;
 pub use tools::*;
 pub use validation::*;
@@ -84,4 +104,7 @@ pub enum Error {

     #[error("Permission denied: {0}")]
     PermissionDenied(String),
+
+    #[error("Agent execution error: {0}")]
+    Agent(String),
 }
337
crates/owlen-core/src/llm/mod.rs
Normal file
337
crates/owlen-core/src/llm/mod.rs
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
//! LLM provider abstractions and registry.
|
||||||
|
//!
|
||||||
|
//! This module defines the provider trait hierarchy along with helpers that
|
||||||
|
//! make it easy to register concrete LLM backends and access them through
|
||||||
|
//! dynamic dispatch when wiring the application together.
|
||||||
|
|
||||||
|
use crate::{Error, Result, types::*};
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use futures::{Stream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::any::Any;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// A boxed stream of chat responses produced by a provider.
|
||||||
|
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
|
||||||
|
|
||||||
|
/// Trait implemented by every LLM backend Owlen can speak to.
|
||||||
|
///
|
||||||
|
/// Providers expose both one-shot and streaming prompt APIs. Concrete
|
||||||
|
/// implementations typically live in `crate::providers`.
|
||||||
|
pub trait LlmProvider: Send + Sync + 'static + Any + Sized {
|
||||||
|
/// Stream type returned by [`Self::stream_prompt`].
|
||||||
|
type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;
|
||||||
|
|
||||||
|
type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type SendPromptFuture<'a>: Future<Output = Result<ChatResponse>> + Send
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type StreamPromptFuture<'a>: Future<Output = Result<Self::Stream>> + Send
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
/// Human-readable provider identifier.
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Return metadata on all models exposed by this provider.
|
||||||
|
fn list_models(&self) -> Self::ListModelsFuture<'_>;
|
||||||
|
|
||||||
|
/// Issue a prompt and wait for the provider to return the full response.
|
||||||
|
fn send_prompt(&self, request: ChatRequest) -> Self::SendPromptFuture<'_>;
|
||||||
|
|
||||||
|
/// Issue a prompt and receive responses incrementally as a stream.
|
||||||
|
fn stream_prompt(&self, request: ChatRequest) -> Self::StreamPromptFuture<'_>;
|
||||||
|
|
||||||
|
/// Perform a lightweight health check.
|
||||||
|
fn health_check(&self) -> Self::HealthCheckFuture<'_>;
|
||||||
|
|
||||||
|
/// Provider-specific configuration schema (optional).
|
||||||
|
fn config_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Access the provider as an `Any` for downcasting.
|
||||||
|
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper that requests a streamed generation and yields the first chunk as a
|
||||||
|
/// regular response. This is handy for providers that only implement the
|
||||||
|
/// streaming API.
|
||||||
|
pub async fn send_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
|
||||||
|
where
|
||||||
|
P: LlmProvider + 'a,
|
||||||
|
{
|
||||||
|
let stream = provider.stream_prompt(request).await?;
|
||||||
|
let mut boxed: ChatStream = Box::pin(stream);
|
||||||
|
match boxed.next().await {
|
||||||
|
Some(Ok(response)) => Ok(response),
|
||||||
|
Some(Err(err)) => Err(err),
|
||||||
|
None => Err(Error::Provider(anyhow!(
|
||||||
|
"Empty chat stream from provider {}",
|
||||||
|
provider.name()
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Object-safe wrapper around [`LlmProvider`] for dynamic dispatch scenarios.
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
pub trait Provider: Send + Sync {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||||
|
|
||||||
|
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse>;
|
||||||
|
|
||||||
|
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream>;
|
||||||
|
|
||||||
|
async fn health_check(&self) -> Result<()>;
|
||||||
|
|
||||||
|
fn config_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_any(&self) -> &(dyn Any + Send + Sync);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<T> Provider for T
|
||||||
|
where
|
||||||
|
T: LlmProvider,
|
||||||
|
{
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
LlmProvider::name(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
LlmProvider::list_models(self).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse> {
|
||||||
|
LlmProvider::send_prompt(self, request).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream> {
|
||||||
|
let stream = LlmProvider::stream_prompt(self, request).await?;
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> Result<()> {
|
||||||
|
LlmProvider::health_check(self).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_schema(&self) -> serde_json::Value {
|
||||||
|
LlmProvider::config_schema(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||||
|
LlmProvider::as_any(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runtime configuration for a provider instance.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
/// Whether this provider should be activated.
|
||||||
|
#[serde(default = "ProviderConfig::default_enabled")]
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Provider type identifier used to resolve implementations.
|
||||||
|
#[serde(default)]
|
||||||
|
pub provider_type: String,
|
||||||
|
/// Base URL for API calls.
|
||||||
|
#[serde(default)]
|
||||||
|
pub base_url: Option<String>,
|
||||||
|
/// API key or token material.
|
||||||
|
#[serde(default)]
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
/// Environment variable holding the API key.
|
||||||
|
#[serde(default)]
|
||||||
|
pub api_key_env: Option<String>,
|
||||||
|
/// Additional provider-specific configuration.
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: HashMap<String, Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderConfig {
|
||||||
|
const fn default_enabled() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Merge the current configuration with overrides from `other`.
|
||||||
|
pub fn merge_from(&mut self, mut other: ProviderConfig) {
|
||||||
|
self.enabled = other.enabled;
|
||||||
|
|
||||||
|
if !other.provider_type.is_empty() {
|
||||||
|
self.provider_type = other.provider_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(base_url) = other.base_url.take() {
|
||||||
|
self.base_url = Some(base_url);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(api_key) = other.api_key.take() {
|
||||||
|
self.api_key = Some(api_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(api_key_env) = other.api_key_env.take() {
|
||||||
|
self.api_key_env = Some(api_key_env);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !other.extra.is_empty() {
|
||||||
|
self.extra.extend(other.extra);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Static registry of providers available to the application.
|
||||||
|
pub struct ProviderRegistry {
|
||||||
|
providers: HashMap<String, Arc<dyn Provider>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
providers: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register<P: LlmProvider + 'static>(&mut self, provider: P) {
|
||||||
|
self.register_arc(Arc::new(provider));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
|
||||||
|
let name = provider.name().to_string();
|
||||||
|
self.providers.insert(name, provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||||
|
self.providers.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_providers(&self) -> Vec<String> {
|
||||||
|
self.providers.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
let mut all_models = Vec::new();
|
||||||
|
|
||||||
|
for provider in self.providers.values() {
|
||||||
|
match provider.list_models().await {
|
||||||
|
Ok(mut models) => all_models.append(&mut models),
|
||||||
|
Err(_) => {
|
||||||
|
// Ignore failing providers and continue.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(all_models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ProviderRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test utilities for constructing mock providers.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod test_utils {
|
||||||
|
use super::*;
|
||||||
|
use futures::stream;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
|
/// Simple provider stub that always returns the same response.
|
||||||
|
pub struct MockProvider {
|
||||||
|
name: String,
|
||||||
|
response: ChatResponse,
|
||||||
|
call_count: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
pub fn new(name: impl Into<String>, response: ChatResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.into(),
|
||||||
|
response,
|
||||||
|
call_count: AtomicUsize::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_count(&self) -> usize {
|
||||||
|
self.call_count.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MockProvider {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(
|
||||||
|
"mock-provider",
|
||||||
|
ChatResponse {
|
||||||
|
message: Message::assistant("mock response".to_string()),
|
||||||
|
usage: None,
|
||||||
|
is_streaming: false,
|
||||||
|
is_final: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmProvider for MockProvider {
|
||||||
|
type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
|
||||||
|
|
||||||
|
type ListModelsFuture<'a>
|
||||||
|
= futures::future::Ready<Result<Vec<ModelInfo>>>
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type SendPromptFuture<'a>
|
||||||
|
= futures::future::Ready<Result<ChatResponse>>
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type StreamPromptFuture<'a>
|
||||||
|
= futures::future::Ready<Result<Self::Stream>>
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
type HealthCheckFuture<'a>
|
||||||
|
= futures::future::Ready<Result<()>>
|
||||||
|
where
|
||||||
|
Self: 'a;
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
||||||
|
futures::future::ready(Ok(vec![]))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_prompt(&self, _request: ChatRequest) -> Self::SendPromptFuture<'_> {
|
||||||
|
self.call_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
futures::future::ready(Ok(self.response.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_prompt(&self, _request: ChatRequest) -> Self::StreamPromptFuture<'_> {
|
||||||
|
self.call_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let response = self.response.clone();
|
||||||
|
futures::future::ready(Ok(stream::iter(vec![Ok(response)])))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
||||||
|
futures::future::ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
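The new llm module introduces LlmProvider, an object-safe Provider wrapper, and a ProviderRegistry keyed by provider name. A small sketch of resolving a provider and listing its models through the registry; the re-export paths come from the lib.rs hunk earlier, and the types module being public is an assumption.

use owlen_core::types::ModelInfo; // types path assumed to be public
use owlen_core::{Provider, ProviderRegistry};

async fn models_for(registry: &ProviderRegistry, provider_name: &str) -> Vec<ModelInfo> {
    // get() hands back an Arc<dyn Provider>, so callers never name a concrete backend;
    // list_all_models (also in the new module) does the same across every provider.
    match registry.get(provider_name) {
        Some(provider) => provider.list_models().await.unwrap_or_default(),
        None => Vec::new(),
    }
}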
@@ -1,6 +1,7 @@
+use crate::Result;
+use crate::mode::Mode;
 use crate::tools::registry::ToolRegistry;
 use crate::validation::SchemaValidator;
-use crate::Result;
 use async_trait::async_trait;
 pub use client::McpClient;
 use serde::{Deserialize, Serialize};
@@ -11,6 +12,7 @@ use std::time::Duration;
 
 pub mod client;
 pub mod factory;
+pub mod failover;
 pub mod permission;
 pub mod protocol;
 pub mod remote_client;
@@ -46,6 +48,7 @@ pub struct McpToolResponse {
 pub struct McpServer {
     registry: Arc<ToolRegistry>,
     validator: Arc<SchemaValidator>,
+    mode: Arc<tokio::sync::RwLock<Mode>>,
 }
 
 impl McpServer {
@@ -53,14 +56,29 @@ impl McpServer {
         Self {
             registry,
             validator,
+            mode: Arc::new(tokio::sync::RwLock::new(Mode::default())),
         }
     }
 
+    /// Set the current operating mode
+    pub async fn set_mode(&self, mode: Mode) {
+        *self.mode.write().await = mode;
+    }
+
+    /// Get the current operating mode
+    pub async fn get_mode(&self) -> Mode {
+        *self.mode.read().await
+    }
+
     /// Enumerate the registered tools as MCP descriptors
-    pub fn list_tools(&self) -> Vec<McpToolDescriptor> {
+    pub async fn list_tools(&self) -> Vec<McpToolDescriptor> {
+        let mode = self.get_mode().await;
+        let available_tools = self.registry.available_tools(mode).await;
+
         self.registry
             .all()
             .into_iter()
+            .filter(|tool| available_tools.contains(&tool.name().to_string()))
             .map(|tool| McpToolDescriptor {
                 name: tool.name().to_string(),
                 description: tool.description().to_string(),
@@ -74,7 +92,11 @@ impl McpServer {
     /// Execute a tool call after validating inputs against the registered schema
     pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
         self.validator.validate(&call.name, &call.arguments)?;
-        let result = self.registry.execute(&call.name, call.arguments).await?;
+        let mode = self.get_mode().await;
+        let result = self
+            .registry
+            .execute(&call.name, call.arguments, mode)
+            .await?;
         Ok(McpToolResponse {
             name: call.name,
             success: result.success,
@@ -99,15 +121,67 @@ impl LocalMcpClient {
             server: McpServer::new(registry, validator),
         }
     }
+
+    /// Set the current operating mode
+    pub async fn set_mode(&self, mode: Mode) {
+        self.server.set_mode(mode).await;
+    }
+
+    /// Get the current operating mode
+    pub async fn get_mode(&self) -> Mode {
+        self.server.get_mode().await
+    }
 }
 
 #[async_trait]
 impl McpClient for LocalMcpClient {
     async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
-        Ok(self.server.list_tools())
+        Ok(self.server.list_tools().await)
     }
 
     async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
         self.server.call_tool(call).await
     }
+
+    async fn set_mode(&self, mode: Mode) -> Result<()> {
+        self.server.set_mode(mode).await;
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+pub mod test_utils {
+    use super::*;
+
+    /// Mock MCP client for testing
+    #[derive(Default)]
+    pub struct MockMcpClient;
+
+    #[async_trait]
+    impl McpClient for MockMcpClient {
+        async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
+            Ok(vec![McpToolDescriptor {
+                name: "mock_tool".to_string(),
+                description: "A mock tool for testing".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "query": {"type": "string"}
+                    }
+                }),
+                requires_network: false,
+                requires_filesystem: vec![],
+            }])
+        }
+
+        async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
+            Ok(McpToolResponse {
+                name: call.name,
+                success: true,
+                output: serde_json::json!({"result": "mock result"}),
+                metadata: HashMap::new(),
+                duration_ms: 10,
+            })
+        }
+    }
 }
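For reference, the mode plumbing added here might be exercised roughly like this (illustrative only; `Mode::default()` is the only variant this diff guarantees, so any other variant name would be an assumption):

    async fn demo(client: &LocalMcpClient) -> Result<()> {
        client.set_mode(Mode::default()).await;
        // list_tools() now consults ToolRegistry::available_tools(mode).
        let tools = McpClient::list_tools(client).await?;
        for tool in &tools {
            println!("available tool: {}", tool.name);
        }
        Ok(())
    }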
@@ -1,5 +1,5 @@
 use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
-use crate::{Error, Result};
+use crate::{Result, mode::Mode};
 use async_trait::async_trait;
 
 /// Trait for a client that can interact with an MCP server
@@ -10,42 +10,12 @@ pub trait McpClient: Send + Sync {
 
     /// Call a tool on the server
     async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;
-}
 
-/// Placeholder for a client that connects to a remote MCP server.
-pub struct RemoteMcpClient;
-
-impl RemoteMcpClient {
-    pub fn new() -> Result<Self> {
-        // Attempt to spawn the MCP server binary located at ./target/debug/owlen-mcp-server
-        // The server runs over STDIO and will be managed by the client instance.
-        // For now we just verify that the binary exists; the actual process handling
-        // is performed lazily in the async methods.
-        let path = "./target/debug/owlen-mcp-server";
-        if std::path::Path::new(path).exists() {
-            Ok(Self)
-        } else {
-            Err(Error::NotImplemented(format!(
-                "Remote MCP server binary not found at {}",
-                path
-            )))
-        }
-    }
-}
-
-#[async_trait]
-impl McpClient for RemoteMcpClient {
-    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
-        // TODO: Implement remote call
-        Err(Error::NotImplemented(
-            "Remote MCP client is not implemented".to_string(),
-        ))
-    }
-
-    async fn call_tool(&self, _call: McpToolCall) -> Result<McpToolResponse> {
-        // TODO: Implement remote call
-        Err(Error::NotImplemented(
-            "Remote MCP client is not implemented".to_string(),
-        ))
-    }
+    /// Update the server with the active operating mode.
+    async fn set_mode(&self, _mode: Mode) -> Result<()> {
+        Ok(())
+    }
 }
+
+// Re-export the concrete implementation that supports stdio and HTTP transports.
+pub use super::remote_client::RemoteMcpClient;
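With set_mode given a default body, a custom client only has to supply the two required methods. A hedged sketch (not from the diff; `NullClient` is a made-up example type):

    struct NullClient;

    #[async_trait]
    impl McpClient for NullClient {
        async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
            Ok(Vec::new())
        }

        async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
            Err(crate::Error::NotImplemented(format!(
                "no handler for '{}'",
                call.name
            )))
        }
        // set_mode falls back to the default no-op declared in the trait.
    }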
@@ -3,11 +3,15 @@
 /// Provides a unified interface for creating MCP clients based on configuration.
 /// Supports switching between local (in-process) and remote (STDIO) execution modes.
 use super::client::McpClient;
-use super::{remote_client::RemoteMcpClient, LocalMcpClient};
+use super::{
+    LocalMcpClient,
+    remote_client::{McpRuntimeSecrets, RemoteMcpClient},
+};
 use crate::config::{Config, McpMode};
 use crate::tools::registry::ToolRegistry;
 use crate::validation::SchemaValidator;
-use crate::Result;
+use crate::{Error, Result};
+use log::{info, warn};
 use std::sync::Arc;
 
 /// Factory for creating MCP clients based on configuration
@@ -30,27 +34,78 @@ impl McpClientFactory {
         }
     }
 
-    /// Create an MCP client based on the current configuration
+    /// Create an MCP client based on the current configuration.
     pub fn create(&self) -> Result<Box<dyn McpClient>> {
+        self.create_with_secrets(None)
+    }
+
+    /// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
+    pub fn create_with_secrets(
+        &self,
+        runtime: Option<McpRuntimeSecrets>,
+    ) -> Result<Box<dyn McpClient>> {
         match self.config.mcp.mode {
-            McpMode::Legacy => {
-                // Use local in-process client
+            McpMode::Disabled => Err(Error::Config(
+                "MCP mode is set to 'disabled'; tooling cannot function in this configuration."
+                    .to_string(),
+            )),
+            McpMode::LocalOnly | McpMode::Legacy => {
+                if matches!(self.config.mcp.mode, McpMode::Legacy) {
+                    warn!("Using deprecated MCP legacy mode; consider switching to 'local_only'.");
+                }
                 Ok(Box::new(LocalMcpClient::new(
                     self.registry.clone(),
                     self.validator.clone(),
                 )))
             }
-            McpMode::Enabled => {
-                // Attempt to use remote client, fall back to local if unavailable
-                match RemoteMcpClient::new() {
-                    Ok(client) => Ok(Box::new(client)),
-                    Err(e) => {
-                        eprintln!("Warning: Failed to start remote MCP client: {}. Falling back to local mode.", e);
-                        Ok(Box::new(LocalMcpClient::new(
-                            self.registry.clone(),
-                            self.validator.clone(),
-                        )))
-                    }
-                }
+            McpMode::RemoteOnly => {
+                let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
+                    Error::Config(
+                        "MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
+                            .to_string(),
+                    )
+                })?;
+
+                RemoteMcpClient::new_with_runtime(server_cfg, runtime)
+                    .map(|client| Box::new(client) as Box<dyn McpClient>)
+                    .map_err(|e| {
+                        Error::Config(format!(
+                            "Failed to start remote MCP client '{}': {e}",
+                            server_cfg.name
+                        ))
+                    })
+            }
+            McpMode::RemotePreferred => {
+                if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
+                    match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) {
+                        Ok(client) => {
+                            info!(
+                                "Connected to remote MCP server '{}' via {} transport.",
+                                server_cfg.name, server_cfg.transport
+                            );
+                            Ok(Box::new(client) as Box<dyn McpClient>)
+                        }
+                        Err(e) if self.config.mcp.allow_fallback => {
+                            warn!(
+                                "Failed to start remote MCP client '{}': {}. Falling back to local tooling.",
+                                server_cfg.name, e
+                            );
+                            Ok(Box::new(LocalMcpClient::new(
+                                self.registry.clone(),
+                                self.validator.clone(),
+                            )))
+                        }
+                        Err(e) => Err(Error::Config(format!(
+                            "Failed to start remote MCP client '{}': {e}. To allow fallback, set [mcp].allow_fallback = true.",
+                            server_cfg.name
+                        ))),
+                    }
+                } else {
+                    warn!("No MCP servers configured; using local MCP tooling.");
+                    Ok(Box::new(LocalMcpClient::new(
+                        self.registry.clone(),
+                        self.validator.clone(),
+                    )))
+                }
             }
         }
     }
@@ -65,12 +120,10 @@ impl McpClientFactory {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::Error;
+    use crate::config::McpServerConfig;
 
-    #[test]
-    fn test_factory_creates_local_client_in_legacy_mode() {
-        let mut config = Config::default();
-        config.mcp.mode = McpMode::Legacy;
-
+    fn build_factory(config: Config) -> McpClientFactory {
         let ui = Arc::new(crate::ui::NoOpUiController);
         let registry = Arc::new(ToolRegistry::new(
             Arc::new(tokio::sync::Mutex::new(config.clone())),
@@ -78,9 +131,61 @@ mod tests {
         ));
         let validator = Arc::new(SchemaValidator::new());
 
-        let factory = McpClientFactory::new(Arc::new(config), registry, validator);
+        McpClientFactory::new(Arc::new(config), registry, validator)
+    }
 
-        // Should create without error
+    #[test]
+    fn test_factory_creates_local_client_when_no_servers_configured() {
+        let mut config = Config::default();
+        config.refresh_mcp_servers(None).unwrap();
+
+        let factory = build_factory(config);
+
+        // Should create without error and fall back to local client
+        let result = factory.create();
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_remote_only_without_servers_errors() {
+        let mut config = Config::default();
+        config.mcp.mode = McpMode::RemoteOnly;
+        config.mcp_servers.clear();
+        config.refresh_mcp_servers(None).unwrap();
+
+        let factory = build_factory(config);
+        let result = factory.create();
+        assert!(matches!(result, Err(Error::Config(_))));
+    }
+
+    #[test]
+    fn test_remote_preferred_without_fallback_propagates_remote_error() {
+        let mut config = Config::default();
+        config.mcp.mode = McpMode::RemotePreferred;
+        config.mcp.allow_fallback = false;
+        config.mcp_servers = vec![McpServerConfig {
+            name: "invalid".to_string(),
+            command: "nonexistent-mcp-server-binary".to_string(),
+            args: Vec::new(),
+            transport: "stdio".to_string(),
+            env: std::collections::HashMap::new(),
+            oauth: None,
+        }];
+        config.refresh_mcp_servers(None).unwrap();
+
+        let factory = build_factory(config);
+        let result = factory.create();
+        assert!(
+            matches!(result, Err(Error::Config(message)) if message.contains("Failed to start remote MCP client"))
+        );
+    }
+
+    #[test]
+    fn test_legacy_mode_uses_local_client() {
+        let mut config = Config::default();
+        config.mcp.mode = McpMode::Legacy;
+
+        let factory = build_factory(config);
         let result = factory.create();
         assert!(result.is_ok());
     }
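A sketch of how the new entry point might be called (not part of the diff; the header name and env key are placeholders, and the field names follow the McpRuntimeSecrets definition later in this diff):

    use std::collections::HashMap;

    fn make_client(factory: &McpClientFactory) -> Result<Box<dyn McpClient>> {
        let mut env_overrides = HashMap::new();
        env_overrides.insert("MCP_API_KEY".to_string(), "<redacted>".to_string());
        let secrets = McpRuntimeSecrets {
            env_overrides,
            http_header: Some(("Authorization".to_string(), "Bearer <token>".to_string())),
        };
        factory.create_with_secrets(Some(secrets))
    }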
crates/owlen-core/src/mcp/failover.rs (new file, 323 lines)
@@ -0,0 +1,323 @@
//! Failover and redundancy support for MCP clients
//!
//! Provides automatic failover between multiple MCP servers with:
//! - Health checking
//! - Priority-based selection
//! - Automatic retry with exponential backoff
//! - Circuit breaker pattern

use super::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::{Error, Result};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Server health status
#[derive(Debug, Clone, PartialEq)]
pub enum ServerHealth {
    /// Server is healthy and available
    Healthy,
    /// Server is experiencing issues but may recover
    Degraded { since: Instant },
    /// Server is down
    Down { since: Instant },
}

/// Server configuration with priority
#[derive(Clone)]
pub struct ServerEntry {
    /// Name for logging
    pub name: String,
    /// MCP client instance
    pub client: Arc<dyn McpClient>,
    /// Priority (lower = higher priority)
    pub priority: u32,
    /// Health status
    health: Arc<RwLock<ServerHealth>>,
    /// Last health check time
    last_check: Arc<RwLock<Option<Instant>>>,
}

impl ServerEntry {
    pub fn new(name: String, client: Arc<dyn McpClient>, priority: u32) -> Self {
        Self {
            name,
            client,
            priority,
            health: Arc::new(RwLock::new(ServerHealth::Healthy)),
            last_check: Arc::new(RwLock::new(None)),
        }
    }

    /// Check if server is available
    pub async fn is_available(&self) -> bool {
        let health = self.health.read().await;
        matches!(*health, ServerHealth::Healthy)
    }

    /// Mark server as healthy
    pub async fn mark_healthy(&self) {
        let mut health = self.health.write().await;
        *health = ServerHealth::Healthy;
        let mut last_check = self.last_check.write().await;
        *last_check = Some(Instant::now());
    }

    /// Mark server as down
    pub async fn mark_down(&self) {
        let mut health = self.health.write().await;
        *health = ServerHealth::Down {
            since: Instant::now(),
        };
    }

    /// Mark server as degraded
    pub async fn mark_degraded(&self) {
        let mut health = self.health.write().await;
        if matches!(*health, ServerHealth::Healthy) {
            *health = ServerHealth::Degraded {
                since: Instant::now(),
            };
        }
    }

    /// Get current health status
    pub async fn get_health(&self) -> ServerHealth {
        self.health.read().await.clone()
    }
}

/// Failover configuration
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Maximum number of retry attempts
    pub max_retries: usize,
    /// Base retry delay (will be exponentially increased)
    pub base_retry_delay: Duration,
    /// Health check interval
    pub health_check_interval: Duration,
    /// Timeout for health checks
    pub health_check_timeout: Duration,
    /// Circuit breaker threshold (failures before opening circuit)
    pub circuit_breaker_threshold: usize,
}

impl Default for FailoverConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_retry_delay: Duration::from_millis(100),
            health_check_interval: Duration::from_secs(30),
            health_check_timeout: Duration::from_secs(5),
            circuit_breaker_threshold: 5,
        }
    }
}

/// MCP client with failover support
pub struct FailoverMcpClient {
    servers: Arc<RwLock<Vec<ServerEntry>>>,
    config: FailoverConfig,
    consecutive_failures: Arc<RwLock<usize>>,
}

impl FailoverMcpClient {
    /// Create a new failover client with multiple servers
    pub fn new(servers: Vec<ServerEntry>, config: FailoverConfig) -> Self {
        // Sort servers by priority
        let mut sorted_servers = servers;
        sorted_servers.sort_by_key(|s| s.priority);

        Self {
            servers: Arc::new(RwLock::new(sorted_servers)),
            config,
            consecutive_failures: Arc::new(RwLock::new(0)),
        }
    }

    /// Create with default configuration
    pub fn with_servers(servers: Vec<ServerEntry>) -> Self {
        Self::new(servers, FailoverConfig::default())
    }

    /// Get the first available server
    async fn get_available_server(&self) -> Option<ServerEntry> {
        let servers = self.servers.read().await;
        for server in servers.iter() {
            if server.is_available().await {
                return Some(server.clone());
            }
        }
        None
    }

    /// Execute an operation with automatic failover
    async fn with_failover<F, T>(&self, operation: F) -> Result<T>
    where
        F: Fn(Arc<dyn McpClient>) -> futures::future::BoxFuture<'static, Result<T>>,
        T: Send + 'static,
    {
        let mut attempt = 0;
        let mut last_error = None;

        while attempt < self.config.max_retries {
            // Get available server
            let server = match self.get_available_server().await {
                Some(s) => s,
                None => {
                    // No healthy servers, try all servers anyway
                    let servers = self.servers.read().await;
                    if let Some(first) = servers.first() {
                        first.clone()
                    } else {
                        return Err(Error::Network("No servers configured".to_string()));
                    }
                }
            };

            // Execute operation
            match operation(server.client.clone()).await {
                Ok(result) => {
                    server.mark_healthy().await;
                    let mut failures = self.consecutive_failures.write().await;
                    *failures = 0;
                    return Ok(result);
                }
                Err(e) => {
                    log::warn!("Server '{}' failed: {}", server.name, e);
                    server.mark_degraded().await;
                    last_error = Some(e);

                    let mut failures = self.consecutive_failures.write().await;
                    *failures += 1;

                    if *failures >= self.config.circuit_breaker_threshold {
                        server.mark_down().await;
                    }
                }
            }

            // Exponential backoff
            if attempt < self.config.max_retries - 1 {
                let delay = self.config.base_retry_delay * 2_u32.pow(attempt as u32);
                tokio::time::sleep(delay).await;
            }

            attempt += 1;
        }

        Err(last_error.unwrap_or_else(|| Error::Network("All servers failed".to_string())))
    }

    /// Perform health check on all servers
    pub async fn health_check_all(&self) {
        let servers = self.servers.read().await;
        for server in servers.iter() {
            let client = server.client.clone();
            let server_clone = server.clone();

            tokio::spawn(async move {
                match tokio::time::timeout(
                    Duration::from_secs(5),
                    // Use a simple list_tools call as health check
                    async { client.list_tools().await },
                )
                .await
                {
                    Ok(Ok(_)) => server_clone.mark_healthy().await,
                    Ok(Err(e)) => {
                        log::warn!("Health check failed for '{}': {}", server_clone.name, e);
                        server_clone.mark_down().await;
                    }
                    Err(_) => {
                        log::warn!("Health check timeout for '{}'", server_clone.name);
                        server_clone.mark_down().await;
                    }
                }
            });
        }
    }

    /// Start background health checking
    pub fn start_health_checks(&self) -> tokio::task::JoinHandle<()> {
        let client = self.clone_ref();
        let interval = self.config.health_check_interval;

        tokio::spawn(async move {
            let mut interval_timer = tokio::time::interval(interval);
            loop {
                interval_timer.tick().await;
                client.health_check_all().await;
            }
        })
    }

    /// Clone the client (returns new handle to same underlying data)
    fn clone_ref(&self) -> Self {
        Self {
            servers: self.servers.clone(),
            config: self.config.clone(),
            consecutive_failures: self.consecutive_failures.clone(),
        }
    }

    /// Get status of all servers
    pub async fn get_server_status(&self) -> Vec<(String, ServerHealth)> {
        let servers = self.servers.read().await;
        let mut status = Vec::new();
        for server in servers.iter() {
            status.push((server.name.clone(), server.get_health().await));
        }
        status
    }
}

#[async_trait]
impl McpClient for FailoverMcpClient {
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        self.with_failover(|client| Box::pin(async move { client.list_tools().await }))
            .await
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        self.with_failover(|client| {
            let call_clone = call.clone();
            Box::pin(async move { client.call_tool(call_clone).await })
        })
        .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_server_entry_health() {
        use crate::mcp::remote_client::RemoteMcpClient;

        // This would need a mock client in practice
        // Just demonstrating the API
        let config = crate::config::McpServerConfig {
            name: "test".to_string(),
            command: "test".to_string(),
            args: vec![],
            transport: "http".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
        };

        if let Ok(client) = RemoteMcpClient::new_with_config(&config) {
            let entry = ServerEntry::new("test".to_string(), Arc::new(client), 1);

            assert!(entry.is_available().await);

            entry.mark_down().await;
            assert!(!entry.is_available().await);

            entry.mark_healthy().await;
            assert!(entry.is_available().await);
        }
    }
}
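A rough wiring example for the new failover client (illustrative only; it assumes the two McpServerConfig values come from the user's configuration):

    use crate::mcp::remote_client::RemoteMcpClient;

    fn build_failover(
        primary_cfg: &crate::config::McpServerConfig,
        backup_cfg: &crate::config::McpServerConfig,
    ) -> Result<FailoverMcpClient> {
        let primary = Arc::new(RemoteMcpClient::new_with_config(primary_cfg)?);
        let backup = Arc::new(RemoteMcpClient::new_with_config(backup_cfg)?);
        let client = FailoverMcpClient::with_servers(vec![
            ServerEntry::new("primary".to_string(), primary, 0),
            ServerEntry::new("backup".to_string(), backup, 1),
        ]);
        // Background health checks keep the priority list fresh.
        let _handle = client.start_health_checks();
        Ok(client)
    }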
@@ -4,8 +4,8 @@
 /// It wraps MCP clients to filter/whitelist tool calls, log invocations, and prompt for consent.
 use super::client::McpClient;
 use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
-use crate::config::Config;
 use crate::{Error, Result};
+use crate::{config::Config, mode::Mode};
 use async_trait::async_trait;
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -145,6 +145,10 @@ impl McpClient for PermissionLayer {
 
         result
     }
+
+    async fn set_mode(&self, mode: Mode) -> Result<()> {
+        self.inner.set_mode(mode).await
+    }
 }
 
 #[cfg(test)]
@@ -152,13 +156,14 @@ mod tests {
     use super::*;
    use crate::mcp::LocalMcpClient;
     use crate::tools::registry::ToolRegistry;
+    use crate::ui::NoOpUiController;
     use crate::validation::SchemaValidator;
     use std::sync::atomic::{AtomicBool, Ordering};
 
     #[tokio::test]
     async fn test_permission_layer_filters_dangerous_tools() {
         let config = Arc::new(Config::default());
-        let ui = Arc::new(crate::ui::NoOpUiController);
+        let ui = Arc::new(NoOpUiController);
         let registry = Arc::new(ToolRegistry::new(
             Arc::new(tokio::sync::Mutex::new((*config).clone())),
             ui,
@@ -182,7 +187,7 @@ mod tests {
     #[tokio::test]
     async fn test_consent_callback_is_invoked() {
         let config = Arc::new(Config::default());
-        let ui = Arc::new(crate::ui::NoOpUiController);
+        let ui = Arc::new(NoOpUiController);
         let registry = Arc::new(ToolRegistry::new(
             Arc::new(tokio::sync::Mutex::new((*config).clone())),
             ui,
@@ -1,128 +1,353 @@
 use super::protocol::methods;
-use super::protocol::{RequestId, RpcErrorResponse, RpcRequest, RpcResponse, PROTOCOL_VERSION};
+use super::protocol::{
+    PROTOCOL_VERSION, RequestId, RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse,
+};
 use super::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
+use crate::consent::{ConsentManager, ConsentScope};
+use crate::tools::{Tool, WebScrapeTool, WebSearchTool};
 use crate::types::ModelInfo;
-use crate::{Error, Provider, Result};
-use async_trait::async_trait;
+use crate::types::{ChatResponse, Message, Role};
+use crate::{
+    ChatStream, Error, LlmProvider, Result, facade::llm_client::LlmClient, mode::Mode,
+    send_via_stream,
+};
+use anyhow::anyhow;
+use futures::{StreamExt, future::BoxFuture, stream};
+use reqwest::Client as HttpClient;
 use serde_json::json;
-use std::sync::atomic::{AtomicU64, Ordering};
+use std::collections::HashMap;
+use std::path::Path;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::time::Duration;
 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
 use tokio::process::{Child, Command};
 use tokio::sync::Mutex;
-// Provider trait is already imported via the earlier use statement.
-use crate::types::{ChatResponse, Message, Role};
-use futures::stream;
-use futures::StreamExt;
+use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
+use tungstenite::protocol::Message as WsMessage;
 
-/// Client that talks to the external `owlen-mcp-server` over STDIO.
+/// Client that talks to the external `owlen-mcp-server` over STDIO, HTTP, or WebSocket.
 pub struct RemoteMcpClient {
     // Child process handling the server (kept alive for the duration of the client).
     #[allow(dead_code)]
-    child: Arc<Mutex<Child>>, // guarded for mutable access across calls
-    // Writer to server stdin.
-    stdin: Arc<Mutex<tokio::process::ChildStdin>>, // async write
-    // Reader for server stdout.
-    stdout: Arc<Mutex<BufReader<tokio::process::ChildStdout>>>,
+    // For stdio transport, we keep the child process handles.
+    child: Option<Arc<Mutex<Child>>>,
+    stdin: Option<Arc<Mutex<tokio::process::ChildStdin>>>, // async write
+    stdout: Option<Arc<Mutex<BufReader<tokio::process::ChildStdout>>>>,
+    // For HTTP transport we keep a reusable client and base URL.
+    http_client: Option<HttpClient>,
+    http_endpoint: Option<String>,
+    // For WebSocket transport we keep a WebSocket stream.
+    ws_stream: Option<Arc<Mutex<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>>>,
+    #[allow(dead_code)] // Useful for debugging/logging
+    ws_endpoint: Option<String>,
     // Incrementing request identifier.
     next_id: AtomicU64,
+    // Optional HTTP header (name, value) injected into every request.
+    http_header: Option<(String, String)>,
+}
+
+/// Runtime secrets provided when constructing an MCP client.
+#[derive(Debug, Default, Clone)]
+pub struct McpRuntimeSecrets {
+    pub env_overrides: HashMap<String, String>,
+    pub http_header: Option<(String, String)>,
 }
 
 impl RemoteMcpClient {
     /// Spawn the MCP server binary and prepare communication channels.
+    /// Spawn an MCP server based on a configuration entry.
+    /// The `transport` field must be "stdio" (the only supported mode).
+    /// Spawn an external MCP server based on a configuration entry.
+    /// The server must communicate over STDIO (the only supported transport).
+    pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
+        Self::new_with_runtime(config, None)
+    }
+
+    pub fn new_with_runtime(
+        config: &crate::config::McpServerConfig,
+        runtime: Option<McpRuntimeSecrets>,
+    ) -> Result<Self> {
+        let mut runtime = runtime.unwrap_or_default();
+        let transport = config.transport.to_lowercase();
+        match transport.as_str() {
+            "stdio" => {
+                // Build the command using the provided binary and arguments.
+                let mut cmd = Command::new(config.command.clone());
+                if !config.args.is_empty() {
+                    cmd.args(config.args.clone());
+                }
+                cmd.stdin(std::process::Stdio::piped())
+                    .stdout(std::process::Stdio::piped())
+                    .stderr(std::process::Stdio::inherit());
+
+                // Apply environment variables defined in the configuration.
+                for (k, v) in config.env.iter() {
+                    cmd.env(k, v);
+                }
+                for (k, v) in runtime.env_overrides.drain() {
+                    cmd.env(k, v);
+                }
+
+                let mut child = cmd.spawn().map_err(|e| {
+                    Error::Io(std::io::Error::new(
+                        e.kind(),
+                        format!("Failed to spawn MCP server '{}': {}", config.name, e),
+                    ))
+                })?;
+
+                let stdin = child.stdin.take().ok_or_else(|| {
+                    Error::Io(std::io::Error::other(
+                        "Failed to capture stdin of MCP server",
+                    ))
+                })?;
+                let stdout = child.stdout.take().ok_or_else(|| {
+                    Error::Io(std::io::Error::other(
+                        "Failed to capture stdout of MCP server",
+                    ))
+                })?;
+
+                Ok(Self {
+                    child: Some(Arc::new(Mutex::new(child))),
+                    stdin: Some(Arc::new(Mutex::new(stdin))),
+                    stdout: Some(Arc::new(Mutex::new(BufReader::new(stdout)))),
+                    http_client: None,
+                    http_endpoint: None,
+                    ws_stream: None,
+                    ws_endpoint: None,
+                    next_id: AtomicU64::new(1),
+                    http_header: None,
+                })
+            }
+            "http" => {
+                // For HTTP we treat `command` as the base URL.
+                let client = HttpClient::builder()
+                    .timeout(Duration::from_secs(30))
+                    .build()
+                    .map_err(|e| Error::Network(e.to_string()))?;
+                Ok(Self {
+                    child: None,
+                    stdin: None,
+                    stdout: None,
+                    http_client: Some(client),
+                    http_endpoint: Some(config.command.clone()),
+                    ws_stream: None,
+                    ws_endpoint: None,
+                    next_id: AtomicU64::new(1),
+                    http_header: runtime.http_header.take(),
+                })
+            }
+            "websocket" => {
+                // For WebSocket, the `command` field contains the WebSocket URL.
+                // We need to use a blocking task to establish the connection.
+                let ws_url = config.command.clone();
+                let (ws_stream, _response) = tokio::task::block_in_place(|| {
+                    tokio::runtime::Handle::current().block_on(async {
+                        connect_async(&ws_url).await.map_err(|e| {
+                            Error::Network(format!("WebSocket connection failed: {}", e))
+                        })
+                    })
+                })?;
+
+                Ok(Self {
+                    child: None,
+                    stdin: None,
+                    stdout: None,
+                    http_client: None,
+                    http_endpoint: None,
+                    ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
+                    ws_endpoint: Some(ws_url),
+                    next_id: AtomicU64::new(1),
+                    http_header: runtime.http_header.take(),
+                })
+            }
+            other => Err(Error::NotImplemented(format!(
+                "Transport '{}' not supported",
+                other
+            ))),
+        }
+    }
+
+    /// Legacy constructor kept for compatibility; attempts to locate a binary.
     pub fn new() -> Result<Self> {
-        // Locate the binary – it is built by Cargo into target/debug.
-        // The test binary runs inside the crate directory, so we check a couple of relative locations.
-        // Attempt to locate the server binary; if unavailable we will fall back to launching via `cargo run`.
-        let _ = ();
-        // Resolve absolute path based on workspace root to avoid cwd dependence.
-        // The MCP server binary lives in the workspace's `target/debug` directory.
-        // Historically the binary was named `owlen-mcp-server`, but it has been
-        // renamed to `owlen-mcp-llm-server`. We attempt to locate the new name
-        // first and fall back to the legacy name for compatibility.
+        // Fall back to searching for a binary as before, then delegate to new_with_config.
         let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
             .join("../..")
             .canonicalize()
             .map_err(Error::Io)?;
+        // Prefer the LLM server binary as it provides both LLM and resource tools.
+        // The generic file-server is kept as a fallback for testing.
         let candidates = [
             "target/debug/owlen-mcp-llm-server",
+            "target/release/owlen-mcp-llm-server",
             "target/debug/owlen-mcp-server",
         ];
-        let mut binary_path = None;
-        for rel in &candidates {
-            let p = workspace_root.join(rel);
-            if p.exists() {
-                binary_path = Some(p);
-                break;
-            }
-        }
-        let binary_path = binary_path.ok_or_else(|| {
-            Error::NotImplemented(format!(
-                "owlen-mcp server binary not found; checked {} and {}",
-                candidates[0], candidates[1]
-            ))
-        })?;
-        if !binary_path.exists() {
-            return Err(Error::NotImplemented(format!(
-                "owlen-mcp-server binary not found at {}",
-                binary_path.display()
-            )));
-        }
-        // Launch the already‑built server binary directly.
-        let mut child = Command::new(&binary_path)
-            .stdin(std::process::Stdio::piped())
-            .stdout(std::process::Stdio::piped())
-            .stderr(std::process::Stdio::inherit())
-            .spawn()
-            .map_err(Error::Io)?;
-
-        let stdin = child.stdin.take().ok_or_else(|| {
-            Error::Io(std::io::Error::other(
-                "Failed to capture stdin of MCP server",
-            ))
-        })?;
-        let stdout = child.stdout.take().ok_or_else(|| {
-            Error::Io(std::io::Error::other(
-                "Failed to capture stdout of MCP server",
-            ))
-        })?;
-
-        Ok(Self {
-            child: Arc::new(Mutex::new(child)),
-            stdin: Arc::new(Mutex::new(stdin)),
-            stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
-            next_id: AtomicU64::new(1),
-        })
+        let binary_path = candidates
+            .iter()
+            .map(|rel| workspace_root.join(rel))
+            .find(|p| p.exists())
+            .ok_or_else(|| {
+                Error::NotImplemented(format!(
+                    "owlen-mcp server binary not found; checked {}, {}, and {}",
+                    candidates[0], candidates[1], candidates[2]
+                ))
+            })?;
+        let config = crate::config::McpServerConfig {
+            name: "default".to_string(),
+            command: binary_path.to_string_lossy().into_owned(),
+            args: Vec::new(),
+            transport: "stdio".to_string(),
+            env: std::collections::HashMap::new(),
+            oauth: None,
+        };
+        Self::new_with_config(&config)
     }
 
     async fn send_rpc(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
         let id = RequestId::Number(self.next_id.fetch_add(1, Ordering::Relaxed));
         let request = RpcRequest::new(id.clone(), method, Some(params));
         let req_str = serde_json::to_string(&request)? + "\n";
-        {
-            let mut stdin = self.stdin.lock().await;
+        // For stdio transport we forward the request to the child process.
+        if let Some(stdin_arc) = &self.stdin {
+            let mut stdin = stdin_arc.lock().await;
             stdin.write_all(req_str.as_bytes()).await?;
             stdin.flush().await?;
         }
-        // Read a single line response
-        let mut line = String::new();
-        {
-            let mut stdout = self.stdout.lock().await;
-            stdout.read_line(&mut line).await?;
-        }
-        // Try to parse successful response first
-        if let Ok(resp) = serde_json::from_str::<RpcResponse>(&line) {
-            if resp.id == id {
-                return Ok(resp.result);
-            }
-        }
-        // Fallback to error response
-        let err_resp: RpcErrorResponse =
-            serde_json::from_str(&line).map_err(Error::Serialization)?;
-        Err(Error::Network(format!(
-            "MCP server error {}: {}",
-            err_resp.error.code, err_resp.error.message
-        )))
+        // Handle based on selected transport.
+        if let Some(client) = &self.http_client {
+            // HTTP: POST JSON body to endpoint.
+            let endpoint = self
+                .http_endpoint
+                .as_ref()
+                .ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
+            let mut builder = client.post(endpoint);
+            if let Some((ref header_name, ref header_value)) = self.http_header {
+                builder = builder.header(header_name, header_value);
+            }
+            let resp = builder
+                .json(&request)
+                .send()
+                .await
+                .map_err(|e| Error::Network(e.to_string()))?;
+            let text = resp
+                .text()
+                .await
+                .map_err(|e| Error::Network(e.to_string()))?;
+            // Try to parse as success then error.
+            if let Ok(r) = serde_json::from_str::<RpcResponse>(&text)
+                && r.id == id
+            {
+                return Ok(r.result);
+            }
+            let err_resp: RpcErrorResponse =
+                serde_json::from_str(&text).map_err(Error::Serialization)?;
+            return Err(Error::Network(format!(
+                "MCP server error {}: {}",
+                err_resp.error.code, err_resp.error.message
+            )));
+        }
+
+        // WebSocket path.
+        if let Some(ws_arc) = &self.ws_stream {
+            use futures::SinkExt;
+
+            let mut ws = ws_arc.lock().await;
+
+            // Send request as text message
+            let req_json = serde_json::to_string(&request)?;
+            ws.send(WsMessage::Text(req_json))
+                .await
+                .map_err(|e| Error::Network(format!("WebSocket send failed: {}", e)))?;
+
+            // Read response
+            let response_msg = ws
+                .next()
+                .await
+                .ok_or_else(|| Error::Network("WebSocket stream closed".into()))?
+                .map_err(|e| Error::Network(format!("WebSocket receive failed: {}", e)))?;
+
+            let response_text = match response_msg {
+                WsMessage::Text(text) => text,
+                WsMessage::Binary(data) => String::from_utf8(data).map_err(|e| {
+                    Error::Network(format!("Invalid UTF-8 in binary message: {}", e))
+                })?,
+                WsMessage::Close(_) => {
+                    return Err(Error::Network(
+                        "WebSocket connection closed by server".into(),
+                    ));
+                }
+                _ => return Err(Error::Network("Unexpected WebSocket message type".into())),
+            };
+
+            // Try to parse as success then error.
+            if let Ok(r) = serde_json::from_str::<RpcResponse>(&response_text)
+                && r.id == id
+            {
+                return Ok(r.result);
+            }
+            let err_resp: RpcErrorResponse =
+                serde_json::from_str(&response_text).map_err(Error::Serialization)?;
+            return Err(Error::Network(format!(
+                "MCP server error {}: {}",
+                err_resp.error.code, err_resp.error.message
+            )));
+        }
+
+        // STDIO path (default).
+        // Loop to skip notifications and find the response with matching ID.
+        loop {
+            let mut line = String::new();
+            {
+                let mut stdout = self
+                    .stdout
+                    .as_ref()
+                    .ok_or_else(|| Error::Network("STDIO stdout not available".into()))?
+                    .lock()
+                    .await;
+                stdout.read_line(&mut line).await?;
+            }
+
+            // Try to parse as notification first (has no id field)
+            if let Ok(_notif) = serde_json::from_str::<RpcNotification>(&line) {
+                // Skip notifications and continue reading
+                continue;
+            }
+
+            // Try to parse successful response
+            if let Ok(resp) = serde_json::from_str::<RpcResponse>(&line) {
+                if resp.id == id {
+                    return Ok(resp.result);
+                }
+                // If ID doesn't match, continue (though this shouldn't happen)
+                continue;
+            }
+
+            // Fallback to error response
+            if let Ok(err_resp) = serde_json::from_str::<RpcErrorResponse>(&line) {
+                return Err(Error::Network(format!(
+                    "MCP server error {}: {}",
+                    err_resp.error.code, err_resp.error.message
+                )));
+            }
+
+            // If we can't parse as any known type, return error
+            return Err(Error::Network(format!(
+                "Unable to parse server response: {}",
+                line.trim()
+            )));
+        }
+    }
+}
+
+impl RemoteMcpClient {
+    /// Convenience wrapper delegating to the `McpClient` trait methods.
+    pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
+        <Self as McpClient>::list_tools(self).await
+    }
+
+    pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
+        <Self as McpClient>::call_tool(self, call).await
     }
 }
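For orientation, the wire format send_rpc produces is newline-delimited JSON-RPC. A request resembles the following sketch (illustrative only; the exact field set is whatever RpcRequest serializes to, and the "jsonrpc" member is the conventional assumption):

    let request = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": { "name": "web_search", "arguments": { "query": "example" } }
    });
    // One request per line, flushed to the transport; one line is read back
    // and matched against `id` (notifications without an id are skipped).
    let line = serde_json::to_string(&request).unwrap() + "\n";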
@@ -175,18 +400,103 @@ impl McpClient for RemoteMcpClient {
                 duration_ms: 0,
             });
         }
+        // Handle write and delete resources locally as well.
+        if call.name.starts_with("resources/write") {
+            let path = call
+                .arguments
+                .get("path")
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| Error::InvalidInput("path missing".into()))?;
+            // Simple path‑traversal protection: reject any path containing ".." or absolute paths.
+            if path.contains("..") || Path::new(path).is_absolute() {
+                return Err(Error::InvalidInput("path traversal".into()));
+            }
+            let content = call
+                .arguments
+                .get("content")
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| Error::InvalidInput("content missing".into()))?;
+            std::fs::write(path, content).map_err(Error::Io)?;
+            return Ok(McpToolResponse {
+                name: call.name,
+                success: true,
+                output: serde_json::json!(null),
+                metadata: std::collections::HashMap::new(),
+                duration_ms: 0,
+            });
+        }
+        if call.name.starts_with("resources/delete") {
+            let path = call
+                .arguments
+                .get("path")
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| Error::InvalidInput("path missing".into()))?;
+            if path.contains("..") || Path::new(path).is_absolute() {
+                return Err(Error::InvalidInput("path traversal".into()));
+            }
+            std::fs::remove_file(path).map_err(Error::Io)?;
+            return Ok(McpToolResponse {
+                name: call.name,
+                success: true,
+                output: serde_json::json!(null),
+                metadata: std::collections::HashMap::new(),
+                duration_ms: 0,
+            });
+        }
+        // Local handling for web tools to avoid needing an external MCP server.
+        if call.name == "web_search" {
+            // Auto‑grant consent for the web_search tool (permanent for this process).
+            let consent_manager = std::sync::Arc::new(std::sync::Mutex::new(ConsentManager::new()));
+            {
+                let mut cm = consent_manager
+                    .lock()
+                    .map_err(|_| Error::Provider(anyhow!("Consent manager mutex poisoned")))?;
+                cm.grant_consent_with_scope(
+                    "web_search",
+                    Vec::new(),
+                    Vec::new(),
+                    ConsentScope::Permanent,
+                );
+            }
+            let tool = WebSearchTool::new(consent_manager.clone(), None, None);
+            let result = tool
+                .execute(call.arguments.clone())
+                .await
+                .map_err(|e| Error::Provider(e.into()))?;
+            return Ok(McpToolResponse {
+                name: call.name,
+                success: true,
+                output: result.output,
+                metadata: std::collections::HashMap::new(),
+                duration_ms: result.duration.as_millis() as u128,
+            });
+        }
+        if call.name == "web_scrape" {
+            let tool = WebScrapeTool::new();
+            let result = tool
+                .execute(call.arguments.clone())
+                .await
+                .map_err(|e| Error::Provider(e.into()))?;
+            return Ok(McpToolResponse {
+                name: call.name,
+                success: true,
+                output: result.output,
+                metadata: std::collections::HashMap::new(),
+                duration_ms: result.duration.as_millis() as u128,
+            });
+        }
         // MCP server expects a generic "tools/call" method with a payload containing the
         // specific tool name and its arguments. Wrap the incoming call accordingly.
         let payload = serde_json::to_value(&call)?;
         let result = self.send_rpc(methods::TOOLS_CALL, payload).await?;
-        // The server returns the tool's output directly; construct a matching response.
-        Ok(McpToolResponse {
-            name: call.name,
-            success: true,
-            output: result,
-            metadata: std::collections::HashMap::new(),
-            duration_ms: 0,
-        })
+        // The server returns an McpToolResponse; deserialize it.
+        let response: McpToolResponse = serde_json::from_value(result)?;
+        Ok(response)
+    }
+
+    async fn set_mode(&self, _mode: Mode) -> Result<()> {
+        // Remote servers manage their own mode settings; treat as best-effort no-op.
+        Ok(())
     }
 }
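An example of a call that takes the locally handled web_search branch above (the "query" key is an assumption; the authoritative schema lives in WebSearchTool):

    let call = McpToolCall {
        name: "web_search".to_string(),
        arguments: serde_json::json!({ "query": "rust async traits" }),
    };
    let response = client.call_tool(call).await?;
    assert!(response.success);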
@@ -194,61 +504,90 @@ impl McpClient for RemoteMcpClient {
 // Provider implementation – forwards chat requests to the generate_text tool.
 // ---------------------------------------------------------------------------
 
-#[async_trait]
-impl Provider for RemoteMcpClient {
+impl LlmProvider for RemoteMcpClient {
+    type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
+    type ListModelsFuture<'a> = BoxFuture<'a, Result<Vec<ModelInfo>>>;
+    type SendPromptFuture<'a> = BoxFuture<'a, Result<ChatResponse>>;
+    type StreamPromptFuture<'a> = BoxFuture<'a, Result<Self::Stream>>;
+    type HealthCheckFuture<'a> = BoxFuture<'a, Result<()>>;
+
     fn name(&self) -> &str {
         "mcp-llm-server"
     }
 
+    fn list_models(&self) -> Self::ListModelsFuture<'_> {
+        Box::pin(async move {
+            let result = self.send_rpc(methods::MODELS_LIST, json!(null)).await?;
+            let models: Vec<ModelInfo> = serde_json::from_value(result)?;
+            Ok(models)
+        })
+    }
+
+    fn send_prompt(&self, request: crate::types::ChatRequest) -> Self::SendPromptFuture<'_> {
+        Box::pin(send_via_stream(self, request))
+    }
+
+    fn stream_prompt(&self, request: crate::types::ChatRequest) -> Self::StreamPromptFuture<'_> {
+        Box::pin(async move {
+            let args = serde_json::json!({
+                "messages": request.messages,
+                "temperature": request.parameters.temperature,
+                "max_tokens": request.parameters.max_tokens,
+                "model": request.model,
+                "stream": request.parameters.stream,
+            });
+            let call = McpToolCall {
+                name: "generate_text".to_string(),
+                arguments: args,
+            };
+            let resp = self.call_tool(call).await?;
+            let content = resp.output.as_str().unwrap_or("").to_string();
+            let message = Message::new(Role::Assistant, content);
+            let chat_resp = ChatResponse {
+                message,
+                usage: None,
+                is_streaming: false,
+                is_final: true,
+            };
+            Ok(stream::iter(vec![Ok(chat_resp)]))
+        })
+    }
+
+    fn health_check(&self) -> Self::HealthCheckFuture<'_> {
+        Box::pin(async move {
+            let params = serde_json::json!({
+                "protocol_version": PROTOCOL_VERSION,
+                "client_info": {
+                    "name": "owlen",
+                    "version": env!("CARGO_PKG_VERSION"),
+                },
+                "capabilities": {}
+            });
+            self.send_rpc(methods::INITIALIZE, params).await.map(|_| ())
+        })
+    }
+}
+
+#[async_trait::async_trait]
+impl LlmClient for RemoteMcpClient {
     async fn list_models(&self) -> Result<Vec<ModelInfo>> {
-        let result = self.send_rpc(methods::MODELS_LIST, json!(null)).await?;
-        let models: Vec<ModelInfo> = serde_json::from_value(result)?;
-        Ok(models)
+        <Self as LlmProvider>::list_models(self).await
     }
 
-    async fn chat(&self, request: crate::types::ChatRequest) -> Result<ChatResponse> {
-        // Use the streaming implementation and take the first response.
-        let mut stream = self.chat_stream(request).await?;
-        match stream.next().await {
-            Some(Ok(resp)) => Ok(resp),
-            Some(Err(e)) => Err(e),
-            None => Err(Error::Provider(anyhow::anyhow!("Empty chat stream"))),
-        }
+    async fn send_chat(&self, request: crate::types::ChatRequest) -> Result<ChatResponse> {
+        <Self as LlmProvider>::send_prompt(self, request).await
     }
 
-    async fn chat_stream(
-        &self,
-        request: crate::types::ChatRequest,
-    ) -> Result<crate::provider::ChatStream> {
-        // Build arguments matching the generate_text schema.
-        let args = serde_json::json!({
-            "messages": request.messages,
-            "temperature": request.parameters.temperature,
-            "max_tokens": request.parameters.max_tokens,
-            "model": request.model,
-            "stream": request.parameters.stream,
-        });
-        let call = McpToolCall {
-            name: "generate_text".to_string(),
-            arguments: args,
-        };
-        let resp = self.call_tool(call).await?;
-        // Build a ChatResponse from the tool output (assumed to be a string).
-        let content = resp.output.as_str().unwrap_or("").to_string();
-        let message = Message::new(Role::Assistant, content);
-        let chat_resp = ChatResponse {
-            message,
-            usage: None,
-            is_streaming: false,
-            is_final: true,
-        };
-        let stream = stream::once(async move { Ok(chat_resp) });
+    async fn stream_chat(&self, request: crate::types::ChatRequest) -> Result<ChatStream> {
+        let stream = <Self as LlmProvider>::stream_prompt(self, request).await?;
         Ok(Box::pin(stream))
     }
 
-    async fn health_check(&self) -> Result<()> {
-        // Simple ping using initialize method.
-        let params = serde_json::json!({"protocol_version": PROTOCOL_VERSION});
-        self.send_rpc("initialize", params).await.map(|_| ())
+    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
+        <Self as McpClient>::list_tools(self).await
+    }
+
+    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
+        <Self as McpClient>::call_tool(self, call).await
     }
 }
|
|||||||
182  crates/owlen-core/src/mode.rs  Normal file
@@ -0,0 +1,182 @@
//! Operating modes for Owlen
//!
//! Defines the different modes in which Owlen can operate and their associated
//! tool availability policies.

use serde::{Deserialize, Serialize};
use std::str::FromStr;

/// Operating mode for Owlen
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Mode {
    /// Chat mode - limited tool access, safe for general conversation
    #[default]
    Chat,
    /// Code mode - full tool access for development tasks
    Code,
}

impl Mode {
    /// Get the display name for this mode
    pub fn display_name(&self) -> &'static str {
        match self {
            Mode::Chat => "chat",
            Mode::Code => "code",
        }
    }
}

impl std::fmt::Display for Mode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.display_name())
    }
}

impl FromStr for Mode {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "chat" => Ok(Mode::Chat),
            "code" => Ok(Mode::Code),
            _ => Err(format!(
                "Invalid mode: '{}'. Valid modes are 'chat' or 'code'",
                s
            )),
        }
    }
}

/// Configuration for tool availability in different modes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModeConfig {
    /// Tools allowed in chat mode
    #[serde(default = "ModeConfig::default_chat_tools")]
    pub chat: ModeToolConfig,
    /// Tools allowed in code mode
    #[serde(default = "ModeConfig::default_code_tools")]
    pub code: ModeToolConfig,
}

impl Default for ModeConfig {
    fn default() -> Self {
        Self {
            chat: Self::default_chat_tools(),
            code: Self::default_code_tools(),
        }
    }
}

impl ModeConfig {
    fn default_chat_tools() -> ModeToolConfig {
        ModeToolConfig {
            allowed_tools: vec!["web_search".to_string()],
        }
    }

    fn default_code_tools() -> ModeToolConfig {
        ModeToolConfig {
            allowed_tools: vec!["*".to_string()], // All tools allowed
        }
    }

    /// Check if a tool is allowed in the given mode
    pub fn is_tool_allowed(&self, mode: Mode, tool_name: &str) -> bool {
        let config = match mode {
            Mode::Chat => &self.chat,
            Mode::Code => &self.code,
        };

        config.is_tool_allowed(tool_name)
    }
}

/// Tool configuration for a specific mode
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModeToolConfig {
    /// List of allowed tools. Use "*" to allow all tools.
    pub allowed_tools: Vec<String>,
}

impl ModeToolConfig {
    /// Check if a tool is allowed in this mode
    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
        // Check for wildcard
        if self.allowed_tools.iter().any(|t| t == "*") {
            return true;
        }

        // Check if tool is explicitly listed
        self.allowed_tools.iter().any(|t| t == tool_name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mode_display() {
        assert_eq!(Mode::Chat.to_string(), "chat");
        assert_eq!(Mode::Code.to_string(), "code");
    }

    #[test]
    fn test_mode_from_str() {
        assert_eq!("chat".parse::<Mode>(), Ok(Mode::Chat));
        assert_eq!("code".parse::<Mode>(), Ok(Mode::Code));
        assert_eq!("CHAT".parse::<Mode>(), Ok(Mode::Chat));
        assert_eq!("CODE".parse::<Mode>(), Ok(Mode::Code));
        assert!("invalid".parse::<Mode>().is_err());
    }

    #[test]
    fn test_default_mode() {
        assert_eq!(Mode::default(), Mode::Chat);
    }

    #[test]
    fn test_chat_mode_restrictions() {
        let config = ModeConfig::default();

        // Web search should be allowed in chat mode
        assert!(config.is_tool_allowed(Mode::Chat, "web_search"));

        // Code exec should not be allowed in chat mode
        assert!(!config.is_tool_allowed(Mode::Chat, "code_exec"));
        assert!(!config.is_tool_allowed(Mode::Chat, "file_write"));
    }

    #[test]
    fn test_code_mode_allows_all() {
        let config = ModeConfig::default();

        // All tools should be allowed in code mode
        assert!(config.is_tool_allowed(Mode::Code, "web_search"));
        assert!(config.is_tool_allowed(Mode::Code, "code_exec"));
        assert!(config.is_tool_allowed(Mode::Code, "file_write"));
        assert!(config.is_tool_allowed(Mode::Code, "anything"));
    }

    #[test]
    fn test_wildcard_tool_config() {
        let config = ModeToolConfig {
            allowed_tools: vec!["*".to_string()],
        };

        assert!(config.is_tool_allowed("any_tool"));
        assert!(config.is_tool_allowed("another_tool"));
    }

    #[test]
    fn test_explicit_tool_list() {
        let config = ModeToolConfig {
            allowed_tools: vec!["tool1".to_string(), "tool2".to_string()],
        };

        assert!(config.is_tool_allowed("tool1"));
        assert!(config.is_tool_allowed("tool2"));
        assert!(!config.is_tool_allowed("tool3"));
    }
}
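A short usage sketch for the mode policy added above. It is not part of the diff; the owlen_core::mode import path and the standalone main() wrapper are assumptions for illustration.

use owlen_core::mode::{Mode, ModeConfig};

fn main() {
    // Parse a mode the way a CLI flag or config value would supply it; falls back to Chat.
    let mode: Mode = "code".parse().unwrap_or_default();

    // Default policy: chat mode only allows web_search, code mode allows "*" (everything).
    let policy = ModeConfig::default();
    for tool in ["web_search", "file_write"] {
        println!("{tool} allowed in {mode} mode: {}", policy.is_tool_allowed(mode, tool));
    }
}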
@@ -1,5 +1,10 @@
-use crate::types::ModelInfo;
+pub mod details;
+
+pub use details::{DetailedModelInfo, ModelInfoRetrievalError};
+
 use crate::Result;
+use crate::types::ModelInfo;
+use std::collections::HashMap;
 use std::future::Future;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
@@ -37,10 +42,8 @@ impl ModelManager {
         F: FnOnce() -> Fut,
         Fut: Future<Output = Result<Vec<ModelInfo>>>,
     {
-        if !force_refresh {
-            if let Some(models) = self.cached_if_fresh().await {
-                return Ok(models);
-            }
+        if let (false, Some(models)) = (force_refresh, self.cached_if_fresh().await) {
+            return Ok(models);
         }
 
         let models = fetcher().await?;
@@ -82,3 +85,125 @@ impl ModelManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
struct ModelDetailsCacheInner {
|
||||||
|
by_key: HashMap<String, DetailedModelInfo>,
|
||||||
|
name_to_key: HashMap<String, String>,
|
||||||
|
fetched_at: HashMap<String, Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cache for rich model details, indexed by digest when available.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ModelDetailsCache {
|
||||||
|
inner: Arc<RwLock<ModelDetailsCacheInner>>,
|
||||||
|
ttl: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelDetailsCache {
|
||||||
|
/// Create a new details cache with the provided TTL.
|
||||||
|
pub fn new(ttl: Duration) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(RwLock::new(ModelDetailsCacheInner::default())),
|
||||||
|
ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to read cached details for the provided model name.
|
||||||
|
pub async fn get(&self, name: &str) -> Option<DetailedModelInfo> {
|
||||||
|
let mut inner = self.inner.write().await;
|
||||||
|
let key = inner.name_to_key.get(name).cloned()?;
|
||||||
|
let stale = inner
|
||||||
|
.fetched_at
|
||||||
|
.get(&key)
|
||||||
|
.is_some_and(|ts| ts.elapsed() >= self.ttl);
|
||||||
|
if stale {
|
||||||
|
inner.by_key.remove(&key);
|
||||||
|
inner.name_to_key.remove(name);
|
||||||
|
inner.fetched_at.remove(&key);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
inner.by_key.get(&key).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cache the provided details, overwriting existing entries.
|
||||||
|
pub async fn insert(&self, info: DetailedModelInfo) {
|
||||||
|
let key = info.digest.clone().unwrap_or_else(|| info.name.clone());
|
||||||
|
let mut inner = self.inner.write().await;
|
||||||
|
|
||||||
|
// Remove prior mappings for this model name (possibly different digest).
|
||||||
|
if let Some(previous_key) = inner.name_to_key.get(&info.name).cloned()
|
||||||
|
&& previous_key != key
|
||||||
|
{
|
||||||
|
inner.by_key.remove(&previous_key);
|
||||||
|
inner.fetched_at.remove(&previous_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
inner.fetched_at.insert(key.clone(), Instant::now());
|
||||||
|
inner.name_to_key.insert(info.name.clone(), key.clone());
|
||||||
|
inner.by_key.insert(key, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a specific model from the cache.
|
||||||
|
pub async fn invalidate(&self, name: &str) {
|
||||||
|
let mut inner = self.inner.write().await;
|
||||||
|
if let Some(key) = inner.name_to_key.remove(name) {
|
||||||
|
inner.by_key.remove(&key);
|
||||||
|
inner.fetched_at.remove(&key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the entire cache.
|
||||||
|
pub async fn invalidate_all(&self) {
|
||||||
|
let mut inner = self.inner.write().await;
|
||||||
|
inner.by_key.clear();
|
||||||
|
inner.name_to_key.clear();
|
||||||
|
inner.fetched_at.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return all cached values regardless of freshness.
|
||||||
|
pub async fn cached(&self) -> Vec<DetailedModelInfo> {
|
||||||
|
let inner = self.inner.read().await;
|
||||||
|
inner.by_key.values().cloned().collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
|
fn sample_details(name: &str) -> DetailedModelInfo {
|
||||||
|
DetailedModelInfo {
|
||||||
|
name: name.to_string(),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn model_details_cache_returns_cached_entry() {
|
||||||
|
let cache = ModelDetailsCache::new(Duration::from_millis(50));
|
||||||
|
let info = sample_details("llama");
|
||||||
|
cache.insert(info.clone()).await;
|
||||||
|
let cached = cache.get("llama").await;
|
||||||
|
assert!(cached.is_some());
|
||||||
|
assert_eq!(cached.unwrap().name, "llama");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn model_details_cache_expires_based_on_ttl() {
|
||||||
|
let cache = ModelDetailsCache::new(Duration::from_millis(10));
|
||||||
|
cache.insert(sample_details("phi")).await;
|
||||||
|
sleep(Duration::from_millis(30)).await;
|
||||||
|
assert!(cache.get("phi").await.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn model_details_cache_invalidate_removes_entry() {
|
||||||
|
let cache = ModelDetailsCache::new(Duration::from_secs(1));
|
||||||
|
cache.insert(sample_details("mistral")).await;
|
||||||
|
cache.invalidate("mistral").await;
|
||||||
|
assert!(cache.get("mistral").await.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
105
crates/owlen-core/src/model/details.rs
Normal file
105
crates/owlen-core/src/model/details.rs
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
//! Detailed model metadata for provider inspection features.
|
||||||
|
//!
|
||||||
|
//! These types capture richer information about locally available models
|
||||||
|
//! than the lightweight [`crate::types::ModelInfo`] listing and back the
|
||||||
|
//! higher-level inspection UI exposed in the Owlen TUI.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Rich metadata about an Ollama model.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct DetailedModelInfo {
|
||||||
|
/// Canonical model name (including tag).
|
||||||
|
pub name: String,
|
||||||
|
/// Reported architecture or model format.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub architecture: Option<String>,
|
||||||
|
/// Human-readable parameter / quantisation summary.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parameters: Option<String>,
|
||||||
|
/// Context window length, if provided.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub context_length: Option<u64>,
|
||||||
|
/// Embedding vector length for embedding-capable models.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub embedding_length: Option<u64>,
|
||||||
|
/// Quantisation level (e.g., Q4_0, Q5_K_M).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub quantization: Option<String>,
|
||||||
|
/// Primary family identifier (e.g., llama3).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub family: Option<String>,
|
||||||
|
/// Additional family tags reported by Ollama.
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub families: Vec<String>,
|
||||||
|
/// Verbose parameter size description (e.g., 70B parameters).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parameter_size: Option<String>,
|
||||||
|
/// Default prompt template packaged with the model.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub template: Option<String>,
|
||||||
|
/// Default system prompt packaged with the model.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system: Option<String>,
|
||||||
|
/// License string provided by the model.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub license: Option<String>,
|
||||||
|
/// Raw modelfile contents (if available).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub modelfile: Option<String>,
|
||||||
|
/// Modification timestamp (ISO-8601) if reported.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub modified_at: Option<String>,
|
||||||
|
/// Approximate model size in bytes.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub size: Option<u64>,
|
||||||
|
/// Digest / checksum used by Ollama (sha256).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub digest: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DetailedModelInfo {
|
||||||
|
/// Convenience helper that normalises empty strings to `None`.
|
||||||
|
pub fn with_normalised_strings(mut self) -> Self {
|
||||||
|
if self.architecture.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.architecture = None;
|
||||||
|
}
|
||||||
|
if self.parameters.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.parameters = None;
|
||||||
|
}
|
||||||
|
if self.quantization.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.quantization = None;
|
||||||
|
}
|
||||||
|
if self.family.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.family = None;
|
||||||
|
}
|
||||||
|
if self.parameter_size.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.parameter_size = None;
|
||||||
|
}
|
||||||
|
if self.template.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.template = None;
|
||||||
|
}
|
||||||
|
if self.system.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.system = None;
|
||||||
|
}
|
||||||
|
if self.license.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.license = None;
|
||||||
|
}
|
||||||
|
if self.modelfile.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.modelfile = None;
|
||||||
|
}
|
||||||
|
if self.digest.as_ref().is_some_and(String::is_empty) {
|
||||||
|
self.digest = None;
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error payload returned when model inspection fails for a specific model.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelInfoRetrievalError {
|
||||||
|
/// Model that failed to resolve.
|
||||||
|
pub model_name: String,
|
||||||
|
/// Human-readable description of the failure.
|
||||||
|
pub error_message: String,
|
||||||
|
}
|
||||||
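DetailedModelInfo::with_normalised_strings above maps empty strings to None so callers only see populated fields. A short sketch of the call pattern (the import path is assumed; not from the diff):

use owlen_core::model::DetailedModelInfo;

fn normalise_example() {
    let raw = DetailedModelInfo {
        name: "phi3:mini".to_string(),
        architecture: Some(String::new()),      // empty string as sometimes returned by the backend
        quantization: Some("Q4_0".to_string()), // real values are preserved
        ..Default::default()
    };

    let clean = raw.with_normalised_strings();
    assert_eq!(clean.architecture, None);
    assert_eq!(clean.quantization.as_deref(), Some("Q4_0"));
}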
507  crates/owlen-core/src/oauth.rs  Normal file
@@ -0,0 +1,507 @@
|
|||||||
|
use std::time::Duration as StdDuration;
|
||||||
|
|
||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::{Error, Result, config::McpOAuthConfig};
|
||||||
|
|
||||||
|
/// Persisted OAuth token set for MCP servers and providers.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
pub struct OAuthToken {
|
||||||
|
/// Bearer access token returned by the authorization server.
|
||||||
|
pub access_token: String,
|
||||||
|
/// Optional refresh token if the provider issues one.
|
||||||
|
#[serde(default)]
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
/// Absolute UTC expiration timestamp for the access token.
|
||||||
|
#[serde(default)]
|
||||||
|
pub expires_at: Option<DateTime<Utc>>,
|
||||||
|
/// Optional space-delimited scope string supplied by the provider.
|
||||||
|
#[serde(default)]
|
||||||
|
pub scope: Option<String>,
|
||||||
|
/// Token type reported by the provider (typically `Bearer`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub token_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthToken {
|
||||||
|
/// Returns `true` if the access token has expired at the provided instant.
|
||||||
|
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||||
|
matches!(self.expires_at, Some(expiry) if now >= expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the token will expire within the supplied duration window.
|
||||||
|
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
|
||||||
|
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Active device-authorization session details returned by the authorization server.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DeviceAuthorization {
|
||||||
|
pub device_code: String,
|
||||||
|
pub user_code: String,
|
||||||
|
pub verification_uri: String,
|
||||||
|
pub verification_uri_complete: Option<String>,
|
||||||
|
pub expires_at: DateTime<Utc>,
|
||||||
|
pub interval: StdDuration,
|
||||||
|
pub message: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeviceAuthorization {
|
||||||
|
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||||
|
now >= self.expires_at
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of polling the token endpoint during a device-authorization flow.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum DevicePollState {
|
||||||
|
Pending { retry_in: StdDuration },
|
||||||
|
Complete(OAuthToken),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OAuthClient {
|
||||||
|
http: Client,
|
||||||
|
config: McpOAuthConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthClient {
|
||||||
|
pub fn new(config: McpOAuthConfig) -> Result<Self> {
|
||||||
|
let http = Client::builder()
|
||||||
|
.user_agent("OwlenOAuth/1.0")
|
||||||
|
.build()
|
||||||
|
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
|
||||||
|
Ok(Self { http, config })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scope_value(&self) -> Option<String> {
|
||||||
|
if self.config.scopes.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.config.scopes.join(" "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_request_base(&self) -> Vec<(String, String)> {
|
||||||
|
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
|
||||||
|
if let Some(secret) = &self.config.client_secret {
|
||||||
|
params.push(("client_secret".to_string(), secret.clone()));
|
||||||
|
}
|
||||||
|
params
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
|
||||||
|
let device_url = self
|
||||||
|
.config
|
||||||
|
.device_authorization_url
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
Error::Config("Device authorization endpoint is not configured.".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut params = self.token_request_base();
|
||||||
|
if let Some(scope) = self.scope_value() {
|
||||||
|
params.push(("scope".to_string(), scope));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(device_url)
|
||||||
|
.form(¶ms)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_http_error("start device authorization", err))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let payload = response
|
||||||
|
.json::<DeviceAuthorizationResponse>()
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
Error::Auth(format!(
|
||||||
|
"Failed to parse device authorization response (status {status}): {err}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let expires_at =
|
||||||
|
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
|
||||||
|
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
|
||||||
|
|
||||||
|
Ok(DeviceAuthorization {
|
||||||
|
device_code: payload.device_code,
|
||||||
|
user_code: payload.user_code,
|
||||||
|
verification_uri: payload.verification_uri,
|
||||||
|
verification_uri_complete: payload.verification_uri_complete,
|
||||||
|
expires_at,
|
||||||
|
interval,
|
||||||
|
message: payload.message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
|
||||||
|
let mut params = self.token_request_base();
|
||||||
|
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
|
||||||
|
params.push(("device_code".to_string(), auth.device_code.clone()));
|
||||||
|
if let Some(scope) = self.scope_value() {
|
||||||
|
params.push(("scope".to_string(), scope));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&self.config.token_url)
|
||||||
|
.form(¶ms)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_http_error("poll device token", err))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_http_error("read token response", err))?;
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||||
|
Error::Auth(format!(
|
||||||
|
"Failed to parse OAuth token response: {err}; body: {text}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
return Ok(DevicePollState::Complete(oauth_token_from_response(
|
||||||
|
payload,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||||
|
OAuthErrorResponse {
|
||||||
|
error: "unknown_error".to_string(),
|
||||||
|
error_description: Some(text.clone()),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
match error.error.as_str() {
|
||||||
|
"authorization_pending" => Ok(DevicePollState::Pending {
|
||||||
|
retry_in: auth.interval,
|
||||||
|
}),
|
||||||
|
"slow_down" => Ok(DevicePollState::Pending {
|
||||||
|
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
|
||||||
|
}),
|
||||||
|
"access_denied" => {
|
||||||
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||||
|
"User declined authorization".to_string()
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
"expired_token" | "expired_device_code" => {
|
||||||
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||||
|
"Device authorization expired".to_string()
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
other => Err(Error::Auth(
|
||||||
|
error
|
||||||
|
.error_description
|
||||||
|
.unwrap_or_else(|| format!("OAuth error: {other}")),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
|
||||||
|
let mut params = self.token_request_base();
|
||||||
|
params.push(("grant_type".to_string(), "refresh_token".to_string()));
|
||||||
|
params.push(("refresh_token".to_string(), refresh_token.to_string()));
|
||||||
|
if let Some(scope) = self.scope_value() {
|
||||||
|
params.push(("scope".to_string(), scope));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&self.config.token_url)
|
||||||
|
.form(¶ms)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_http_error("refresh OAuth token", err))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_http_error("read refresh response", err))?;
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||||
|
Error::Auth(format!(
|
||||||
|
"Failed to parse OAuth refresh response: {err}; body: {text}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
Ok(oauth_token_from_response(payload))
|
||||||
|
} else {
|
||||||
|
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||||
|
OAuthErrorResponse {
|
||||||
|
error: "unknown_error".to_string(),
|
||||||
|
error_description: Some(text.clone()),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||||
|
format!("OAuth token refresh failed: {}", error.error)
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeviceAuthorizationResponse {
|
||||||
|
device_code: String,
|
||||||
|
user_code: String,
|
||||||
|
verification_uri: String,
|
||||||
|
#[serde(default)]
|
||||||
|
verification_uri_complete: Option<String>,
|
||||||
|
expires_in: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
interval: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
message: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct TokenResponse {
|
||||||
|
access_token: String,
|
||||||
|
#[serde(default)]
|
||||||
|
refresh_token: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
expires_in: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
scope: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
token_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OAuthErrorResponse {
|
||||||
|
error: String,
|
||||||
|
#[serde(default)]
|
||||||
|
error_description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
|
||||||
|
let expires_at = payload
|
||||||
|
.expires_in
|
||||||
|
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
|
||||||
|
.map(|seconds| Utc::now() + Duration::seconds(seconds));
|
||||||
|
|
||||||
|
OAuthToken {
|
||||||
|
access_token: payload.access_token,
|
||||||
|
refresh_token: payload.refresh_token,
|
||||||
|
expires_at,
|
||||||
|
scope: payload.scope,
|
||||||
|
token_type: payload.token_type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
|
||||||
|
if err.is_timeout() {
|
||||||
|
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
|
||||||
|
} else if err.is_connect() {
|
||||||
|
Error::Network(format!("OAuth {action} connection error: {err}"))
|
||||||
|
} else {
|
||||||
|
Error::Network(format!("OAuth {action} request failed: {err}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use httpmock::prelude::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn config_for(server: &MockServer) -> McpOAuthConfig {
|
||||||
|
McpOAuthConfig {
|
||||||
|
client_id: "test-client".to_string(),
|
||||||
|
client_secret: None,
|
||||||
|
authorize_url: server.url("/authorize"),
|
||||||
|
token_url: server.url("/token"),
|
||||||
|
device_authorization_url: Some(server.url("/device")),
|
||||||
|
redirect_url: None,
|
||||||
|
scopes: vec!["repo".to_string(), "user".to_string()],
|
||||||
|
token_env: None,
|
||||||
|
header: None,
|
||||||
|
header_prefix: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_device_authorization() -> DeviceAuthorization {
|
||||||
|
DeviceAuthorization {
|
||||||
|
device_code: "device-123".to_string(),
|
||||||
|
user_code: "ABCD-EFGH".to_string(),
|
||||||
|
verification_uri: "https://example.test/activate".to_string(),
|
||||||
|
verification_uri_complete: Some(
|
||||||
|
"https://example.test/activate?user_code=ABCD-EFGH".to_string(),
|
||||||
|
),
|
||||||
|
expires_at: Utc::now() + Duration::minutes(10),
|
||||||
|
interval: StdDuration::from_secs(5),
|
||||||
|
message: Some("Open the verification URL and enter the code.".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn start_device_authorization_returns_payload() {
|
||||||
|
let server = MockServer::start_async().await;
|
||||||
|
let device_mock = server
|
||||||
|
.mock_async(|when, then| {
|
||||||
|
when.method(POST).path("/device");
|
||||||
|
then.status(200)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json_body(json!({
|
||||||
|
"device_code": "device-123",
|
||||||
|
"user_code": "ABCD-EFGH",
|
||||||
|
"verification_uri": "https://example.test/activate",
|
||||||
|
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
|
||||||
|
"expires_in": 600,
|
||||||
|
"interval": 7,
|
||||||
|
"message": "Open the verification URL and enter the code."
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = OAuthClient::new(config_for(&server)).expect("client");
|
||||||
|
let auth = client
|
||||||
|
.start_device_authorization()
|
||||||
|
.await
|
||||||
|
.expect("device authorization payload");
|
||||||
|
|
||||||
|
assert_eq!(auth.user_code, "ABCD-EFGH");
|
||||||
|
assert_eq!(auth.interval, StdDuration::from_secs(7));
|
||||||
|
assert!(auth.expires_at > Utc::now());
|
||||||
|
device_mock.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn poll_device_token_reports_pending() {
|
||||||
|
let server = MockServer::start_async().await;
|
||||||
|
let pending = server
|
||||||
|
.mock_async(|when, then| {
|
||||||
|
when.method(POST)
|
||||||
|
.path("/token")
|
||||||
|
.body_contains(
|
||||||
|
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
|
||||||
|
)
|
||||||
|
.body_contains("device_code=device-123");
|
||||||
|
then.status(400)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json_body(json!({
|
||||||
|
"error": "authorization_pending"
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = config_for(&server);
|
||||||
|
let client = OAuthClient::new(config).expect("client");
|
||||||
|
let auth = sample_device_authorization();
|
||||||
|
|
||||||
|
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||||
|
match result {
|
||||||
|
DevicePollState::Pending { retry_in } => {
|
||||||
|
assert_eq!(retry_in, StdDuration::from_secs(5));
|
||||||
|
}
|
||||||
|
other => panic!("expected pending state, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
pending.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn poll_device_token_applies_slow_down_backoff() {
|
||||||
|
let server = MockServer::start_async().await;
|
||||||
|
let slow = server
|
||||||
|
.mock_async(|when, then| {
|
||||||
|
when.method(POST).path("/token");
|
||||||
|
then.status(400)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json_body(json!({
|
||||||
|
"error": "slow_down"
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = config_for(&server);
|
||||||
|
let client = OAuthClient::new(config).expect("client");
|
||||||
|
let auth = sample_device_authorization();
|
||||||
|
|
||||||
|
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||||
|
match result {
|
||||||
|
DevicePollState::Pending { retry_in } => {
|
||||||
|
assert_eq!(retry_in, StdDuration::from_secs(10));
|
||||||
|
}
|
||||||
|
other => panic!("expected pending state, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
slow.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn poll_device_token_returns_token_when_authorized() {
|
||||||
|
let server = MockServer::start_async().await;
|
||||||
|
let token = server
|
||||||
|
.mock_async(|when, then| {
|
||||||
|
when.method(POST).path("/token");
|
||||||
|
then.status(200)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json_body(json!({
|
||||||
|
"access_token": "token-abc",
|
||||||
|
"refresh_token": "refresh-xyz",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "repo user"
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = config_for(&server);
|
||||||
|
let client = OAuthClient::new(config).expect("client");
|
||||||
|
let auth = sample_device_authorization();
|
||||||
|
|
||||||
|
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||||
|
let token_info = match result {
|
||||||
|
DevicePollState::Complete(token) => token,
|
||||||
|
other => panic!("expected completion, got {other:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(token_info.access_token, "token-abc");
|
||||||
|
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
|
||||||
|
assert!(token_info.expires_at.is_some());
|
||||||
|
token.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn refresh_token_roundtrip() {
|
||||||
|
let server = MockServer::start_async().await;
|
||||||
|
let refresh = server
|
||||||
|
.mock_async(|when, then| {
|
||||||
|
when.method(POST)
|
||||||
|
.path("/token")
|
||||||
|
.body_contains("grant_type=refresh_token")
|
||||||
|
.body_contains("refresh_token=old-refresh");
|
||||||
|
then.status(200)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json_body(json!({
|
||||||
|
"access_token": "token-new",
|
||||||
|
"refresh_token": "refresh-new",
|
||||||
|
"expires_in": 1200,
|
||||||
|
"token_type": "Bearer"
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = config_for(&server);
|
||||||
|
let client = OAuthClient::new(config).expect("client");
|
||||||
|
let token = client
|
||||||
|
.refresh_token("old-refresh")
|
||||||
|
.await
|
||||||
|
.expect("refresh response");
|
||||||
|
|
||||||
|
assert_eq!(token.access_token, "token-new");
|
||||||
|
assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
|
||||||
|
assert!(token.expires_at.is_some());
|
||||||
|
refresh.assert_async().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
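oauth.rs above implements an OAuth device-authorization (device code) flow plus token refresh. A hedged sketch of how a frontend might drive it end to end; the owlen_core paths, the Error::Auth construction and the tokio sleep call are assumptions layered on top of the API shown in the diff.

use chrono::Utc;
use owlen_core::config::McpOAuthConfig;
use owlen_core::oauth::{DevicePollState, OAuthClient};

async fn device_login(config: McpOAuthConfig) -> owlen_core::Result<()> {
    let client = OAuthClient::new(config)?;

    // Step 1: request a device/user code pair and show it to the user.
    let auth = client.start_device_authorization().await?;
    println!("Open {} and enter the code {}", auth.verification_uri, auth.user_code);

    // Step 2: poll the token endpoint until the user approves, backing off as instructed.
    loop {
        if auth.is_expired(Utc::now()) {
            break Err(owlen_core::Error::Auth("device authorization expired".to_string()));
        }
        match client.poll_device_token(&auth).await? {
            DevicePollState::Pending { retry_in } => tokio::time::sleep(retry_in).await,
            DevicePollState::Complete(token) => {
                println!("access token expires at {:?}", token.expires_at);
                break Ok(());
            }
        }
    }
}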
@@ -1,176 +0,0 @@
|
|||||||
//! Provider trait and related types
|
|
||||||
|
|
||||||
use crate::{types::*, Result};
|
|
||||||
use futures::Stream;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// A stream of chat responses
|
|
||||||
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
|
|
||||||
|
|
||||||
/// Trait for LLM providers (Ollama, OpenAI, Anthropic, etc.)
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use std::pin::Pin;
|
|
||||||
/// use std::sync::Arc;
|
|
||||||
/// use futures::Stream;
|
|
||||||
/// use owlen_core::provider::{Provider, ProviderRegistry, ChatStream};
|
|
||||||
/// use owlen_core::types::{ChatRequest, ChatResponse, ModelInfo, Message, Role, ChatParameters};
|
|
||||||
/// use owlen_core::Result;
|
|
||||||
///
|
|
||||||
/// // 1. Create a mock provider
|
|
||||||
/// struct MockProvider;
|
|
||||||
///
|
|
||||||
/// #[async_trait::async_trait]
|
|
||||||
/// impl Provider for MockProvider {
|
|
||||||
/// fn name(&self) -> &str {
|
|
||||||
/// "mock"
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
/// Ok(vec![ModelInfo {
|
|
||||||
/// id: "mock-model".to_string(),
|
|
||||||
/// provider: "mock".to_string(),
|
|
||||||
/// name: "mock-model".to_string(),
|
|
||||||
/// description: None,
|
|
||||||
/// context_window: None,
|
|
||||||
/// capabilities: vec![],
|
|
||||||
/// supports_tools: false,
|
|
||||||
/// }])
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
|
||||||
/// let content = format!("Response to: {}", request.messages.last().unwrap().content);
|
|
||||||
/// Ok(ChatResponse {
|
|
||||||
/// message: Message::new(Role::Assistant, content),
|
|
||||||
/// usage: None,
|
|
||||||
/// is_streaming: false,
|
|
||||||
/// is_final: true,
|
|
||||||
/// })
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
|
||||||
/// unimplemented!();
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn health_check(&self) -> Result<()> {
|
|
||||||
/// Ok(())
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// // 2. Use the provider with a registry
|
|
||||||
/// #[tokio::main]
|
|
||||||
/// async fn main() {
|
|
||||||
/// let mut registry = ProviderRegistry::new();
|
|
||||||
/// registry.register(MockProvider);
|
|
||||||
///
|
|
||||||
/// let provider = registry.get("mock").unwrap();
|
|
||||||
/// let models = provider.list_models().await.unwrap();
|
|
||||||
/// assert_eq!(models[0].name, "mock-model");
|
|
||||||
///
|
|
||||||
/// let request = ChatRequest {
|
|
||||||
/// model: "mock-model".to_string(),
|
|
||||||
/// messages: vec![Message::new(Role::User, "Hello".to_string())],
|
|
||||||
/// parameters: ChatParameters::default(),
|
|
||||||
/// tools: None,
|
|
||||||
/// };
|
|
||||||
///
|
|
||||||
/// let response = provider.chat(request).await.unwrap();
|
|
||||||
/// assert_eq!(response.message.content, "Response to: Hello");
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
#[async_trait::async_trait]
|
|
||||||
pub trait Provider: Send + Sync {
|
|
||||||
/// Get the name of this provider
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
/// List available models from this provider
|
|
||||||
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
|
||||||
|
|
||||||
/// Send a chat completion request
|
|
||||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
|
|
||||||
|
|
||||||
/// Send a streaming chat completion request
|
|
||||||
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream>;
|
|
||||||
|
|
||||||
/// Check if the provider is available/healthy
|
|
||||||
async fn health_check(&self) -> Result<()>;
|
|
||||||
|
|
||||||
/// Get provider-specific configuration schema
|
|
||||||
fn config_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configuration for a provider
|
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
||||||
pub struct ProviderConfig {
|
|
||||||
/// Provider type identifier
|
|
||||||
pub provider_type: String,
|
|
||||||
/// Base URL for API calls
|
|
||||||
pub base_url: Option<String>,
|
|
||||||
/// API key or token
|
|
||||||
pub api_key: Option<String>,
|
|
||||||
/// Additional provider-specific configuration
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub extra: std::collections::HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A registry of providers
|
|
||||||
pub struct ProviderRegistry {
|
|
||||||
providers: std::collections::HashMap<String, Arc<dyn Provider>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderRegistry {
|
|
||||||
/// Create a new provider registry
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
providers: std::collections::HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register a provider
|
|
||||||
pub fn register<P: Provider + 'static>(&mut self, provider: P) {
|
|
||||||
self.register_arc(Arc::new(provider));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register an already wrapped provider
|
|
||||||
pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
|
|
||||||
let name = provider.name().to_string();
|
|
||||||
self.providers.insert(name, provider);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a provider by name
|
|
||||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
|
||||||
self.providers.get(name).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List all registered provider names
|
|
||||||
pub fn list_providers(&self) -> Vec<String> {
|
|
||||||
self.providers.keys().cloned().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get all models from all providers
|
|
||||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
let mut all_models = Vec::new();
|
|
||||||
|
|
||||||
for provider in self.providers.values() {
|
|
||||||
match provider.list_models().await {
|
|
||||||
Ok(mut models) => all_models.append(&mut models),
|
|
||||||
Err(_) => {
|
|
||||||
// Continue with other providers
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(all_models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ProviderRegistry {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
227  crates/owlen-core/src/provider/manager.rs  Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures::stream::{FuturesUnordered, StreamExt};
|
||||||
|
use log::{debug, warn};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::{Error, Result};
|
||||||
|
|
||||||
|
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
|
||||||
|
|
||||||
|
/// Model information annotated with the originating provider metadata.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AnnotatedModelInfo {
|
||||||
|
pub provider_id: String,
|
||||||
|
pub provider_status: ProviderStatus,
|
||||||
|
pub model: ModelInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
|
||||||
|
/// health state.
|
||||||
|
pub struct ProviderManager {
|
||||||
|
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
|
||||||
|
status_cache: RwLock<HashMap<String, ProviderStatus>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderManager {
|
||||||
|
/// Construct a new manager using the supplied configuration. Providers
|
||||||
|
/// defined in the configuration start with a `RequiresSetup` status so
|
||||||
|
/// that frontends can surface incomplete configuration to users.
|
||||||
|
pub fn new(config: &Config) -> Self {
|
||||||
|
let mut status_cache = HashMap::new();
|
||||||
|
for provider_id in config.providers.keys() {
|
||||||
|
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
providers: RwLock::new(HashMap::new()),
|
||||||
|
status_cache: RwLock::new(status_cache),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a provider instance with the manager.
|
||||||
|
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
|
||||||
|
let provider_id = provider.metadata().id.clone();
|
||||||
|
debug!("registering provider {}", provider_id);
|
||||||
|
|
||||||
|
self.providers
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.clone(), provider);
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id, ProviderStatus::Unavailable);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a stream by routing the request to the designated provider.
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
provider_id: &str,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<GenerateStream> {
|
||||||
|
let provider = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.get(provider_id).cloned()
|
||||||
|
}
|
||||||
|
.ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;
|
||||||
|
|
||||||
|
match provider.generate_stream(request).await {
|
||||||
|
Ok(stream) => {
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.to_string(), ProviderStatus::Available);
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.to_string(), ProviderStatus::Unavailable);
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List models across all providers, updating provider status along the way.
|
||||||
|
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
|
||||||
|
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard
|
||||||
|
.iter()
|
||||||
|
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tasks = FuturesUnordered::new();
|
||||||
|
|
||||||
|
for (provider_id, provider) in providers {
|
||||||
|
tasks.push(async move {
|
||||||
|
let log_id = provider_id.clone();
|
||||||
|
let mut status = ProviderStatus::Unavailable;
|
||||||
|
let mut models = Vec::new();
|
||||||
|
|
||||||
|
match provider.health_check().await {
|
||||||
|
Ok(health) => {
|
||||||
|
status = health;
|
||||||
|
if matches!(status, ProviderStatus::Available) {
|
||||||
|
match provider.list_models().await {
|
||||||
|
Ok(list) => {
|
||||||
|
models = list;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
status = ProviderStatus::Unavailable;
|
||||||
|
warn!("listing models failed for provider {}: {}", log_id, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("health check failed for provider {}: {}", log_id, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(provider_id, status, models)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut annotated = Vec::new();
|
||||||
|
let mut status_updates = HashMap::new();
|
||||||
|
|
||||||
|
while let Some((provider_id, status, models)) = tasks.next().await {
|
||||||
|
status_updates.insert(provider_id.clone(), status);
|
||||||
|
for model in models {
|
||||||
|
annotated.push(AnnotatedModelInfo {
|
||||||
|
provider_id: provider_id.clone(),
|
||||||
|
provider_status: status,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = self.status_cache.write().await;
|
||||||
|
for (provider_id, status) in status_updates {
|
||||||
|
guard.insert(provider_id, status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(annotated)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refresh the health of all registered providers in parallel, returning
|
||||||
|
/// the latest status snapshot.
|
||||||
|
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
|
||||||
|
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard
|
||||||
|
.iter()
|
||||||
|
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tasks = FuturesUnordered::new();
|
||||||
|
for (provider_id, provider) in providers {
|
||||||
|
tasks.push(async move {
|
||||||
|
let status = match provider.health_check().await {
|
||||||
|
Ok(status) => status,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("health check failed for provider {}: {}", provider_id, err);
|
||||||
|
ProviderStatus::Unavailable
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(provider_id, status)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut updates = HashMap::new();
|
||||||
|
while let Some((provider_id, status)) = tasks.next().await {
|
||||||
|
updates.insert(provider_id, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = self.status_cache.write().await;
|
||||||
|
for (provider_id, status) in &updates {
|
||||||
|
guard.insert(provider_id.clone(), *status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updates
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the provider instance for an identifier.
|
||||||
|
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.get(provider_id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List the registered provider identifiers.
|
||||||
|
pub async fn provider_ids(&self) -> Vec<String> {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve the last known status for a provider.
|
||||||
|
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
|
||||||
|
let guard = self.status_cache.read().await;
|
||||||
|
guard.get(provider_id).copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Snapshot the currently cached statuses.
|
||||||
|
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
|
||||||
|
let guard = self.status_cache.read().await;
|
||||||
|
guard.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ProviderManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
providers: RwLock::new(HashMap::new()),
|
||||||
|
status_cache: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
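The ProviderManager above tracks per-provider health while fanning out model listing and routing generation requests. A rough startup sketch; the owlen_core paths, the "ollama" provider id and the model name are invented for illustration.

use std::sync::Arc;

use owlen_core::config::Config;
use owlen_core::provider::{GenerateRequest, ModelProvider, ProviderManager};

async fn startup(config: &Config, ollama: Arc<dyn ModelProvider>) -> owlen_core::Result<()> {
    let manager = ProviderManager::new(config);
    manager.register_provider(ollama).await;

    // Health checks run in parallel; providers that fail stay Unavailable.
    let statuses = manager.refresh_health().await;
    println!("provider statuses: {statuses:?}");

    // Models from every reachable provider, annotated with their origin and status.
    for entry in manager.list_all_models().await? {
        println!("{}: {} ({:?})", entry.provider_id, entry.model.name, entry.provider_status);
    }

    // Route a generation request to one provider by id.
    let _stream = manager
        .generate("ollama", GenerateRequest::new("llama3:8b"))
        .await?;
    Ok(())
}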
36  crates/owlen-core/src/provider/mod.rs  Normal file
@@ -0,0 +1,36 @@
//! Unified provider abstraction layer.
//!
//! This module defines the async [`ModelProvider`] trait that all model
//! backends implement, together with a small suite of shared data structures
//! used for model discovery and streaming generation. The [`ProviderManager`]
//! orchestrates multiple providers and coordinates their health state.

mod manager;
mod types;

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

pub use self::{manager::*, types::*};

use crate::Result;

/// Convenience alias for the stream type yielded by [`ModelProvider::generate_stream`].
pub type GenerateStream = Pin<Box<dyn Stream<Item = Result<GenerateChunk>> + Send + 'static>>;

#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Returns descriptive metadata about the provider.
    fn metadata(&self) -> &ProviderMetadata;

    /// Check the current health state for the provider.
    async fn health_check(&self) -> Result<ProviderStatus>;

    /// List all models available through the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Acquire a streaming response for a generation request.
    async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream>;
}
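A minimal ModelProvider implementation against the trait above, useful as a mental model or a test double. Everything here (the EchoProvider name, the single model entry) is invented for illustration; only the trait surface and the shared types come from the diff.

use std::collections::HashMap;

use async_trait::async_trait;
use futures::stream;
use owlen_core::Result;
use owlen_core::provider::{
    GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
    ProviderStatus, ProviderType,
};

struct EchoProvider {
    metadata: ProviderMetadata,
}

impl EchoProvider {
    fn new() -> Self {
        Self {
            metadata: ProviderMetadata::new("echo", "Echo", ProviderType::Local, false),
        }
    }
}

#[async_trait]
impl ModelProvider for EchoProvider {
    fn metadata(&self) -> &ProviderMetadata {
        &self.metadata
    }

    async fn health_check(&self) -> Result<ProviderStatus> {
        // A local test double is always reachable.
        Ok(ProviderStatus::Available)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            name: "echo-1".to_string(),
            size_bytes: None,
            capabilities: vec!["chat".to_string()],
            description: Some("echoes the prompt back".to_string()),
            provider: self.metadata.clone(),
            metadata: HashMap::new(),
        }])
    }

    async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream> {
        // Echo the prompt as a single text chunk, then terminate the stream.
        let text = request.prompt.unwrap_or_default();
        let chunks = vec![Ok(GenerateChunk::from_text(text)), Ok(GenerateChunk::final_chunk())];
        Ok(Box::pin(stream::iter(chunks)))
    }
}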
124  crates/owlen-core/src/provider/types.rs  Normal file
@@ -0,0 +1,124 @@
//! Shared types used by the unified provider abstraction layer.

use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Categorises providers so the UI can distinguish between local and hosted
/// backends.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProviderType {
    Local,
    Cloud,
}

/// Represents the current availability state for a provider.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProviderStatus {
    Available,
    Unavailable,
    RequiresSetup,
}

/// Describes core metadata for a provider implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderMetadata {
    pub id: String,
    pub name: String,
    pub provider_type: ProviderType,
    pub requires_auth: bool,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}

impl ProviderMetadata {
    /// Construct a new metadata instance for a provider.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        provider_type: ProviderType,
        requires_auth: bool,
    ) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            provider_type,
            requires_auth,
            metadata: HashMap::new(),
        }
    }
}

/// Information about a model that can be displayed to users.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelInfo {
    pub name: String,
    #[serde(default)]
    pub size_bytes: Option<u64>,
    #[serde(default)]
    pub capabilities: Vec<String>,
    #[serde(default)]
    pub description: Option<String>,
    pub provider: ProviderMetadata,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}

/// Unified request for streaming text generation across providers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenerateRequest {
    pub model: String,
    #[serde(default)]
    pub prompt: Option<String>,
    #[serde(default)]
    pub context: Vec<String>,
    #[serde(default)]
    pub parameters: HashMap<String, Value>,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}

impl GenerateRequest {
    /// Helper for building a request from the minimum required fields.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            prompt: None,
            context: Vec::new(),
            parameters: HashMap::new(),
            metadata: HashMap::new(),
        }
    }
}

/// Streamed chunk of generation output from a model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenerateChunk {
    #[serde(default)]
    pub text: Option<String>,
    #[serde(default)]
    pub is_final: bool,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}

impl GenerateChunk {
    /// Construct a new chunk with the provided text payload.
    pub fn from_text(text: impl Into<String>) -> Self {
        Self {
            text: Some(text.into()),
            is_final: false,
            metadata: HashMap::new(),
        }
    }

    /// Mark the chunk as the terminal item in a stream.
    pub fn final_chunk() -> Self {
        Self {
            text: None,
            is_final: true,
            metadata: HashMap::new(),
        }
    }
}
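A short usage sketch of these types; the model name, prompt and parameter below are placeholders rather than values used by the project:

fn main() {
    let mut request = GenerateRequest::new("example-model");
    request.prompt = Some("Summarise the diff".to_string());
    request
        .parameters
        .insert("temperature".to_string(), serde_json::json!(0.2));

    // Chunks carry incremental text; a final chunk closes the stream.
    let chunk = GenerateChunk::from_text("partial output");
    assert!(!chunk.is_final);
    assert!(GenerateChunk::final_chunk().is_final);
}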
8  crates/owlen-core/src/providers/mod.rs  Normal file
@@ -0,0 +1,8 @@
//! Built-in LLM provider implementations.
//!
//! Each provider integration lives in its own module so that maintenance
//! stays focused and configuration remains clear.

pub mod ollama;

pub use ollama::OllamaProvider;
1754  crates/owlen-core/src/providers/ollama.rs  Normal file
File diff suppressed because it is too large.
@@ -1,6 +1,7 @@
 //! Router for managing multiple providers and routing requests

-use crate::{provider::*, types::*, Result};
+use crate::{Result, llm::*, types::*};
+use anyhow::anyhow;
 use std::sync::Arc;

 /// A router that can distribute requests across multiple providers
@@ -32,7 +33,7 @@ impl Router {
     }

     /// Register a provider with the router
-    pub fn register_provider<P: Provider + 'static>(&mut self, provider: P) {
+    pub fn register_provider<P: LlmProvider + 'static>(&mut self, provider: P) {
         self.registry.register(provider);
     }

@@ -52,13 +53,13 @@ impl Router {
     /// Route a request to the appropriate provider
     pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
         let provider = self.find_provider_for_model(&request.model)?;
-        provider.chat(request).await
+        provider.send_prompt(request).await
     }

     /// Route a streaming request to the appropriate provider
     pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
         let provider = self.find_provider_for_model(&request.model)?;
-        provider.chat_stream(request).await
+        provider.stream_prompt(request).await
     }

     /// List all available models from all providers
@@ -70,18 +71,21 @@ impl Router {
     fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
         // Check routing rules first
         for rule in &self.routing_rules {
-            if self.matches_pattern(&rule.model_pattern, model) {
-                if let Some(provider) = self.registry.get(&rule.provider) {
-                    return Ok(provider);
-                }
+            if !self.matches_pattern(&rule.model_pattern, model) {
+                continue;
+            }
+            if let Some(provider) = self.registry.get(&rule.provider) {
+                return Ok(provider);
             }
         }

         // Fall back to default provider
-        if let Some(default) = &self.default_provider {
-            if let Some(provider) = self.registry.get(default) {
-                return Ok(provider);
-            }
+        if let Some(provider) = self
+            .default_provider
+            .as_ref()
+            .and_then(|default| self.registry.get(default))
+        {
+            return Ok(provider);
         }

         // If no default, try to find any provider that has this model
@@ -92,7 +96,7 @@ impl Router {
             }
         }

-        Err(crate::Error::Provider(anyhow::anyhow!(
+        Err(crate::Error::Provider(anyhow!(
             "No provider found for model: {}",
             model
         )))
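A toy, self-contained sketch of the routing order the refactor above preserves: explicit rules are consulted first with an early continue, then the default provider. The glob matching and names here are illustrative and not the crate's implementation:

fn route(model: &str, rules: &[(&str, &str)], default: Option<&str>) -> Option<String> {
    for (pattern, provider) in rules {
        // Early continue mirrors the clippy-friendly shape used in the diff.
        if !model.starts_with(pattern.trim_end_matches('*')) {
            continue;
        }
        return Some((*provider).to_string());
    }
    default.map(str::to_string)
}

fn main() {
    let rules = [("llama*", "ollama")];
    assert_eq!(route("llama3:8b", &rules, Some("cloud")), Some("ollama".into()));
    assert_eq!(route("gpt-test", &rules, Some("cloud")), Some("cloud".into()));
}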
@@ -2,7 +2,7 @@ use std::path::PathBuf;
 use std::process::{Command, Stdio};
 use std::time::{Duration, Instant};

-use anyhow::{bail, Context, Result};
+use anyhow::{Context, Result, bail};
 use tempfile::TempDir;

 /// Configuration options for sandboxed process execution.
@@ -185,16 +185,20 @@ impl SandboxedProcess {
         if let Ok(output) = output {
             let version_str = String::from_utf8_lossy(&output.stdout);
             // Parse version like "bubblewrap 0.11.0" or "0.11.0"
-            if let Some(version_part) = version_str.split_whitespace().last() {
-                if let Some((major, rest)) = version_part.split_once('.') {
-                    if let Some((minor, _patch)) = rest.split_once('.') {
-                        if let (Ok(maj), Ok(min)) = (major.parse::<u32>(), minor.parse::<u32>()) {
-                            // --rlimit-as was added in 0.12.0
-                            return maj > 0 || (maj == 0 && min >= 12);
-                        }
-                    }
-                }
-            }
+            return version_str
+                .split_whitespace()
+                .last()
+                .and_then(|part| {
+                    part.split_once('.').and_then(|(major, rest)| {
+                        rest.split_once('.').and_then(|(minor, _)| {
+                            let maj = major.parse::<u32>().ok()?;
+                            let min = minor.parse::<u32>().ok()?;
+                            Some((maj, min))
+                        })
+                    })
+                })
+                .map(|(maj, min)| maj > 0 || (maj == 0 && min >= 12))
+                .unwrap_or(false);
         }

         // If we can't determine the version, assume it doesn't support it (safer default)
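A standalone sketch of the same check, so the 0.12 cutoff for bubblewrap's --rlimit-as flag is easy to verify against sample `bwrap --version` output; the function name is illustrative:

fn supports_rlimit_as(version_str: &str) -> bool {
    version_str
        .split_whitespace()
        .last()
        .and_then(|part| {
            let (major, rest) = part.split_once('.')?;
            let (minor, _) = rest.split_once('.')?;
            Some((major.parse::<u32>().ok()?, minor.parse::<u32>().ok()?))
        })
        .map(|(maj, min)| maj > 0 || (maj == 0 && min >= 12))
        .unwrap_or(false)
}

fn main() {
    assert!(!supports_rlimit_as("bubblewrap 0.11.0")); // too old
    assert!(supports_rlimit_as("0.12.1"));             // bare version string
    assert!(supports_rlimit_as("bubblewrap 1.0.0"));
    assert!(!supports_rlimit_as("unparseable"));       // safer default: false
}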
File diff suppressed because it is too large.
199  crates/owlen-core/src/state/mod.rs  Normal file
@@ -0,0 +1,199 @@
//! Shared application state types used across TUI frontends.

use std::fmt;

/// High-level application state reported by the UI loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AppState {
    Running,
    Quit,
}

/// Vim-style input modes supported by the TUI.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputMode {
    Normal,
    Editing,
    ProviderSelection,
    ModelSelection,
    Help,
    Visual,
    Command,
    SessionBrowser,
    ThemeBrowser,
    RepoSearch,
    SymbolSearch,
}

impl fmt::Display for InputMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            InputMode::Normal => "Normal",
            InputMode::Editing => "Editing",
            InputMode::ModelSelection => "Model",
            InputMode::ProviderSelection => "Provider",
            InputMode::Help => "Help",
            InputMode::Visual => "Visual",
            InputMode::Command => "Command",
            InputMode::SessionBrowser => "Sessions",
            InputMode::ThemeBrowser => "Themes",
            InputMode::RepoSearch => "Search",
            InputMode::SymbolSearch => "Symbols",
        };
        f.write_str(label)
    }
}

/// Represents which panel is currently focused in the TUI layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FocusedPanel {
    Files,
    Chat,
    Thinking,
    Input,
    Code,
}

/// Auto-scroll state manager for scrollable panels.
#[derive(Debug, Clone)]
pub struct AutoScroll {
    pub scroll: usize,
    pub content_len: usize,
    pub stick_to_bottom: bool,
}

impl Default for AutoScroll {
    fn default() -> Self {
        Self {
            scroll: 0,
            content_len: 0,
            stick_to_bottom: true,
        }
    }
}

impl AutoScroll {
    /// Update scroll position based on viewport height.
    pub fn on_viewport(&mut self, viewport_h: usize) {
        let max = self.content_len.saturating_sub(viewport_h);
        if self.stick_to_bottom {
            self.scroll = max;
        } else {
            self.scroll = self.scroll.min(max);
        }
    }

    /// Handle user scroll input.
    pub fn on_user_scroll(&mut self, delta: isize, viewport_h: usize) {
        let max = self.content_len.saturating_sub(viewport_h) as isize;
        let s = (self.scroll as isize + delta).clamp(0, max) as usize;
        self.scroll = s;
        self.stick_to_bottom = s as isize == max;
    }

    pub fn scroll_half_page_down(&mut self, viewport_h: usize) {
        let delta = (viewport_h / 2) as isize;
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_half_page_up(&mut self, viewport_h: usize) {
        let delta = -((viewport_h / 2) as isize);
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_full_page_down(&mut self, viewport_h: usize) {
        let delta = viewport_h as isize;
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_full_page_up(&mut self, viewport_h: usize) {
        let delta = -(viewport_h as isize);
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn jump_to_top(&mut self) {
        self.scroll = 0;
        self.stick_to_bottom = false;
    }

    pub fn jump_to_bottom(&mut self, viewport_h: usize) {
        self.stick_to_bottom = true;
        self.on_viewport(viewport_h);
    }
}

/// Visual selection state for text selection.
#[derive(Debug, Clone, Default)]
pub struct VisualSelection {
    pub start: Option<(usize, usize)>,
    pub end: Option<(usize, usize)>,
}

impl VisualSelection {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn start_at(&mut self, pos: (usize, usize)) {
        self.start = Some(pos);
        self.end = Some(pos);
    }

    pub fn extend_to(&mut self, pos: (usize, usize)) {
        self.end = Some(pos);
    }

    pub fn clear(&mut self) {
        self.start = None;
        self.end = None;
    }

    pub fn is_active(&self) -> bool {
        self.start.is_some() && self.end.is_some()
    }

    pub fn get_normalized(&self) -> Option<((usize, usize), (usize, usize))> {
        if let (Some(s), Some(e)) = (self.start, self.end) {
            if s.0 < e.0 || (s.0 == e.0 && s.1 <= e.1) {
                Some((s, e))
            } else {
                Some((e, s))
            }
        } else {
            None
        }
    }
}

/// Cursor position helper for navigating scrollable content.
#[derive(Debug, Clone, Copy, Default)]
pub struct CursorPosition {
    pub row: usize,
    pub col: usize,
}

impl CursorPosition {
    pub fn new(row: usize, col: usize) -> Self {
        Self { row, col }
    }

    pub fn move_up(&mut self, amount: usize) {
        self.row = self.row.saturating_sub(amount);
    }

    pub fn move_down(&mut self, amount: usize, max: usize) {
        self.row = (self.row + amount).min(max);
    }

    pub fn move_left(&mut self, amount: usize) {
        self.col = self.col.saturating_sub(amount);
    }

    pub fn move_right(&mut self, amount: usize, max: usize) {
        self.col = (self.col + amount).min(max);
    }

    pub fn as_tuple(&self) -> (usize, usize) {
        (self.row, self.col)
    }
}
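A small sketch of the AutoScroll contract, assuming the module above is in scope: a panel ten rows tall over fifty rows of content stays pinned to the bottom until the user scrolls away.

fn main() {
    let mut scroll = AutoScroll::default();
    scroll.content_len = 50;

    scroll.on_viewport(10);
    assert_eq!(scroll.scroll, 40); // pinned to the last page

    scroll.on_user_scroll(-5, 10); // user scrolls up five rows
    assert_eq!(scroll.scroll, 35);
    assert!(!scroll.stick_to_bottom);

    scroll.jump_to_bottom(10);
    assert_eq!(scroll.scroll, 40);
    assert!(scroll.stick_to_bottom);
}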
@@ -8,6 +8,8 @@ use std::collections::HashMap;
 use std::fs;
 use std::path::{Path, PathBuf};

+pub type ThemePalette = Theme;
+
 /// A complete theme definition for OWLEN TUI
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Theme {
@@ -34,6 +36,42 @@ pub struct Theme {
     #[serde(serialize_with = "serialize_color")]
     pub unfocused_panel_border: Color,

+    /// Foreground color for the active pane beacon (`▌`)
+    #[serde(default = "Theme::default_focus_beacon_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub focus_beacon_fg: Color,
+
+    /// Background color for the active pane beacon (`▌`)
+    #[serde(default = "Theme::default_focus_beacon_bg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub focus_beacon_bg: Color,
+
+    /// Foreground color for the inactive pane beacon (`▌`)
+    #[serde(default = "Theme::default_unfocused_beacon_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub unfocused_beacon_fg: Color,
+
+    /// Title color for active pane headers
+    #[serde(default = "Theme::default_pane_header_active")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub pane_header_active: Color,
+
+    /// Title color for inactive pane headers
+    #[serde(default = "Theme::default_pane_header_inactive")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub pane_header_inactive: Color,
+
+    /// Hint text color used within pane headers
+    #[serde(default = "Theme::default_pane_hint_text")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub pane_hint_text: Color,
+
     /// Color for user message role indicator
     #[serde(deserialize_with = "deserialize_color")]
     #[serde(serialize_with = "serialize_color")]
@@ -114,6 +152,42 @@ pub struct Theme {
     #[serde(serialize_with = "serialize_color")]
     pub cursor: Color,

+    /// Code block background color
+    #[serde(default = "Theme::default_code_block_background")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_background: Color,
+
+    /// Code block border color
+    #[serde(default = "Theme::default_code_block_border")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_border: Color,
+
+    /// Code block text color
+    #[serde(default = "Theme::default_code_block_text")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_text: Color,
+
+    /// Code block keyword color
+    #[serde(default = "Theme::default_code_block_keyword")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_keyword: Color,
+
+    /// Code block string literal color
+    #[serde(default = "Theme::default_code_block_string")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_string: Color,
+
+    /// Code block comment color
+    #[serde(default = "Theme::default_code_block_comment")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub code_block_comment: Color,
+
     /// Placeholder text color
     #[serde(deserialize_with = "deserialize_color")]
     #[serde(serialize_with = "serialize_color")]
@@ -128,6 +202,84 @@ pub struct Theme {
     #[serde(deserialize_with = "deserialize_color")]
     #[serde(serialize_with = "serialize_color")]
     pub info: Color,
+
+    /// Agent action coloring (ReAct THOUGHT)
+    #[serde(default = "Theme::default_agent_thought")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_thought: Color,
+
+    /// Agent action coloring (ReAct ACTION)
+    #[serde(default = "Theme::default_agent_action")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_action: Color,
+
+    /// Agent action coloring (ReAct ACTION_INPUT)
+    #[serde(default = "Theme::default_agent_action_input")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_action_input: Color,
+
+    /// Agent action coloring (ReAct OBSERVATION)
+    #[serde(default = "Theme::default_agent_observation")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_observation: Color,
+
+    /// Agent action coloring (ReAct FINAL_ANSWER)
+    #[serde(default = "Theme::default_agent_final_answer")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_final_answer: Color,
+
+    /// Status badge foreground when agent is running
+    #[serde(default = "Theme::default_agent_badge_running_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_badge_running_fg: Color,
+
+    /// Status badge background when agent is running
+    #[serde(default = "Theme::default_agent_badge_running_bg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_badge_running_bg: Color,
+
+    /// Status badge foreground when agent mode is idle
+    #[serde(default = "Theme::default_agent_badge_idle_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_badge_idle_fg: Color,
+
+    /// Status badge background when agent mode is idle
+    #[serde(default = "Theme::default_agent_badge_idle_bg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub agent_badge_idle_bg: Color,
+
+    /// Operating mode badge foreground (Chat)
+    #[serde(default = "Theme::default_operating_chat_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub operating_chat_fg: Color,
+
+    /// Operating mode badge background (Chat)
+    #[serde(default = "Theme::default_operating_chat_bg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub operating_chat_bg: Color,
+
+    /// Operating mode badge foreground (Code)
+    #[serde(default = "Theme::default_operating_code_fg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub operating_code_fg: Color,
+
+    /// Operating mode badge background (Code)
+    #[serde(default = "Theme::default_operating_code_bg")]
+    #[serde(deserialize_with = "deserialize_color")]
+    #[serde(serialize_with = "serialize_color")]
+    pub operating_code_bg: Color,
 }

 impl Default for Theme {
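A generic sketch of the `#[serde(default = "...")]` pattern used for every new field above: when a key is absent, serde calls the named function instead of failing, so existing theme files keep loading. The `Badge` type and the `toml` dependency here are illustrative only, not part of the crate:

use serde::Deserialize;

#[derive(Deserialize)]
struct Badge {
    label: String,
    #[serde(default = "Badge::default_fg")]
    fg: String,
}

impl Badge {
    fn default_fg() -> String {
        "white".to_string()
    }
}

fn main() {
    // The input omits `fg`, so serde falls back to Badge::default_fg().
    let badge: Badge = toml::from_str("label = \"AGENT\"").unwrap();
    assert_eq!(badge.fg, "white");
}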
@@ -136,6 +288,108 @@ impl Default for Theme {
     }
 }

+impl Theme {
+    const fn default_code_block_background() -> Color {
+        Color::Black
+    }
+
+    const fn default_code_block_border() -> Color {
+        Color::Gray
+    }
+
+    const fn default_code_block_text() -> Color {
+        Color::White
+    }
+
+    const fn default_code_block_keyword() -> Color {
+        Color::Yellow
+    }
+
+    const fn default_code_block_string() -> Color {
+        Color::LightGreen
+    }
+
+    const fn default_code_block_comment() -> Color {
+        Color::DarkGray
+    }
+
+    const fn default_agent_thought() -> Color {
+        Color::LightBlue
+    }
+
+    const fn default_agent_action() -> Color {
+        Color::Yellow
+    }
+
+    const fn default_agent_action_input() -> Color {
+        Color::LightCyan
+    }
+
+    const fn default_agent_observation() -> Color {
+        Color::LightGreen
+    }
+
+    const fn default_agent_final_answer() -> Color {
+        Color::Magenta
+    }
+
+    const fn default_agent_badge_running_fg() -> Color {
+        Color::Black
+    }
+
+    const fn default_agent_badge_running_bg() -> Color {
+        Color::Yellow
+    }
+
+    const fn default_agent_badge_idle_fg() -> Color {
+        Color::Black
+    }
+
+    const fn default_agent_badge_idle_bg() -> Color {
+        Color::Cyan
+    }
+
+    const fn default_focus_beacon_fg() -> Color {
+        Color::LightMagenta
+    }
+
+    const fn default_focus_beacon_bg() -> Color {
+        Color::Black
+    }
+
+    const fn default_unfocused_beacon_fg() -> Color {
+        Color::DarkGray
+    }
+
+    const fn default_pane_header_active() -> Color {
+        Color::White
+    }
+
+    const fn default_pane_header_inactive() -> Color {
+        Color::Gray
+    }
+
+    const fn default_pane_hint_text() -> Color {
+        Color::DarkGray
+    }
+
+    const fn default_operating_chat_fg() -> Color {
+        Color::Black
+    }
+
+    const fn default_operating_chat_bg() -> Color {
+        Color::Blue
+    }
+
+    const fn default_operating_code_fg() -> Color {
+        Color::Black
+    }
+
+    const fn default_operating_code_bg() -> Color {
+        Color::Magenta
+    }
+}
+
 /// Get the default themes directory path
 pub fn default_themes_dir() -> PathBuf {
     let config_dir = PathBuf::from(shellexpand::tilde(crate::config::DEFAULT_CONFIG_PATH).as_ref())
@@ -209,6 +463,14 @@ pub fn built_in_themes() -> HashMap<String, Theme> {
         "default_light",
         include_str!("../../../themes/default_light.toml"),
     ),
+    (
+        "ansi_basic",
+        include_str!("../../../themes/ansi-basic.toml"),
+    ),
+    (
+        "grayscale-high-contrast",
+        include_str!("../../../themes/grayscale-high-contrast.toml"),
+    ),
     ("gruvbox", include_str!("../../../themes/gruvbox.toml")),
     ("dracula", include_str!("../../../themes/dracula.toml")),
     ("solarized", include_str!("../../../themes/solarized.toml")),
@@ -259,6 +521,7 @@ fn get_fallback_theme(name: &str) -> Option<Theme> {
         "monokai" => Some(monokai()),
         "material-dark" => Some(material_dark()),
         "material-light" => Some(material_light()),
+        "grayscale-high-contrast" => Some(grayscale_high_contrast()),
         _ => None,
     }
 }
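A short sketch of looking up one of the newly bundled themes at runtime; it assumes built_in_themes() keys entries by the identifiers listed above and that the theme module's items are in scope:

// Fragment only; not a complete program.
let themes = built_in_themes();
if let Some(base) = themes.get("grayscale-high-contrast") {
    let mut theme = base.clone();
    theme.pane_hint_text = Color::DarkGray; // override a defaulted field after loading
}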
@@ -269,27 +532,52 @@ fn default_dark() -> Theme {
         name: "default_dark".to_string(),
         text: Color::White,
         background: Color::Black,
-        focused_panel_border: Color::LightMagenta,
-        unfocused_panel_border: Color::Rgb(95, 20, 135),
+        focused_panel_border: Color::Rgb(216, 160, 255),
+        unfocused_panel_border: Color::Rgb(137, 82, 204),
+        focus_beacon_fg: Color::Rgb(248, 229, 255),
+        focus_beacon_bg: Color::Rgb(38, 10, 58),
+        unfocused_beacon_fg: Color::Rgb(130, 130, 130),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Color::Rgb(210, 210, 210),
+        pane_hint_text: Color::Rgb(210, 210, 210),
         user_message_role: Color::LightBlue,
         assistant_message_role: Color::Yellow,
-        tool_output: Color::Gray,
-        thinking_panel_title: Color::LightMagenta,
-        command_bar_background: Color::Black,
-        status_background: Color::Black,
-        mode_normal: Color::LightBlue,
-        mode_editing: Color::LightGreen,
-        mode_model_selection: Color::LightYellow,
-        mode_provider_selection: Color::LightCyan,
-        mode_help: Color::LightMagenta,
-        mode_visual: Color::Magenta,
-        mode_command: Color::Yellow,
-        selection_bg: Color::LightBlue,
+        tool_output: Color::Rgb(200, 200, 200),
+        thinking_panel_title: Color::Rgb(234, 182, 255),
+        command_bar_background: Color::Rgb(10, 10, 10),
+        status_background: Color::Rgb(12, 12, 12),
+        mode_normal: Color::Rgb(117, 200, 255),
+        mode_editing: Color::Rgb(144, 242, 170),
+        mode_model_selection: Color::Rgb(255, 226, 140),
+        mode_provider_selection: Color::Rgb(164, 235, 255),
+        mode_help: Color::Rgb(234, 182, 255),
+        mode_visual: Color::Rgb(255, 170, 255),
+        mode_command: Color::Rgb(255, 220, 120),
+        selection_bg: Color::Rgb(56, 140, 240),
         selection_fg: Color::Black,
-        cursor: Color::Magenta,
-        placeholder: Color::DarkGray,
+        cursor: Color::Rgb(255, 196, 255),
+        code_block_background: Color::Rgb(25, 25, 25),
+        code_block_border: Color::Rgb(216, 160, 255),
+        code_block_text: Color::White,
+        code_block_keyword: Color::Rgb(255, 220, 120),
+        code_block_string: Color::Rgb(144, 242, 170),
+        code_block_comment: Color::Rgb(170, 170, 170),
+        placeholder: Color::Rgb(180, 180, 180),
         error: Color::Red,
-        info: Color::LightGreen,
+        info: Color::Rgb(144, 242, 170),
+        agent_thought: Color::Rgb(117, 200, 255),
+        agent_action: Color::Rgb(255, 220, 120),
+        agent_action_input: Color::Rgb(164, 235, 255),
+        agent_observation: Color::Rgb(144, 242, 170),
+        agent_final_answer: Color::Rgb(255, 170, 255),
+        agent_badge_running_fg: Color::Black,
+        agent_badge_running_bg: Color::Yellow,
+        agent_badge_idle_fg: Color::Black,
+        agent_badge_idle_bg: Color::Cyan,
+        operating_chat_fg: Color::Black,
+        operating_chat_bg: Color::Rgb(117, 200, 255),
+        operating_code_fg: Color::Black,
+        operating_code_bg: Color::Rgb(255, 170, 255),
     }
 }

@@ -301,6 +589,12 @@ fn default_light() -> Theme {
         background: Color::White,
         focused_panel_border: Color::Rgb(74, 144, 226),
         unfocused_panel_border: Color::Rgb(221, 221, 221),
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(0, 85, 164),
         assistant_message_role: Color::Rgb(142, 68, 173),
         tool_output: Color::Gray,
@@ -317,9 +611,28 @@ fn default_light() -> Theme {
         selection_bg: Color::Rgb(164, 200, 240),
         selection_fg: Color::Black,
         cursor: Color::Rgb(217, 95, 2),
+        code_block_background: Color::Rgb(245, 245, 245),
+        code_block_border: Color::Rgb(142, 68, 173),
+        code_block_text: Color::Black,
+        code_block_keyword: Color::Rgb(181, 137, 0),
+        code_block_string: Color::Rgb(46, 139, 87),
+        code_block_comment: Color::Gray,
         placeholder: Color::Gray,
         error: Color::Rgb(192, 57, 43),
         info: Color::Green,
+        agent_thought: Color::Rgb(0, 85, 164),
+        agent_action: Color::Rgb(181, 137, 0),
+        agent_action_input: Color::Rgb(0, 139, 139),
+        agent_observation: Color::Rgb(46, 139, 87),
+        agent_final_answer: Color::Rgb(142, 68, 173),
+        agent_badge_running_fg: Color::White,
+        agent_badge_running_bg: Color::Rgb(241, 196, 15),
+        agent_badge_idle_fg: Color::White,
+        agent_badge_idle_bg: Color::Rgb(0, 150, 136),
+        operating_chat_fg: Color::White,
+        operating_chat_bg: Color::Rgb(0, 85, 164),
+        operating_code_fg: Color::White,
+        operating_code_bg: Color::Rgb(142, 68, 173),
     }
 }

@@ -331,7 +644,13 @@ fn gruvbox() -> Theme {
         background: Color::Rgb(40, 40, 40), // #282828
         focused_panel_border: Color::Rgb(254, 128, 25), // #fe8019 (orange)
         unfocused_panel_border: Color::Rgb(124, 111, 100), // #7c6f64
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(184, 187, 38), // #b8bb26 (green)
         assistant_message_role: Color::Rgb(131, 165, 152), // #83a598 (blue)
         tool_output: Color::Rgb(146, 131, 116),
         thinking_panel_title: Color::Rgb(211, 134, 155), // #d3869b (purple)
@@ -347,9 +666,28 @@ fn gruvbox() -> Theme {
         selection_bg: Color::Rgb(80, 73, 69),
         selection_fg: Color::Rgb(235, 219, 178),
         cursor: Color::Rgb(254, 128, 25),
+        code_block_background: Color::Rgb(60, 56, 54),
+        code_block_border: Color::Rgb(124, 111, 100),
+        code_block_text: Color::Rgb(235, 219, 178),
+        code_block_keyword: Color::Rgb(250, 189, 47),
+        code_block_string: Color::Rgb(142, 192, 124),
+        code_block_comment: Color::Rgb(124, 111, 100),
         placeholder: Color::Rgb(102, 92, 84),
         error: Color::Rgb(251, 73, 52), // #fb4934
         info: Color::Rgb(184, 187, 38),
+        agent_thought: Color::Rgb(131, 165, 152),
+        agent_action: Color::Rgb(250, 189, 47),
+        agent_action_input: Color::Rgb(142, 192, 124),
+        agent_observation: Color::Rgb(184, 187, 38),
+        agent_final_answer: Color::Rgb(211, 134, 155),
+        agent_badge_running_fg: Color::Rgb(40, 40, 40),
+        agent_badge_running_bg: Color::Rgb(250, 189, 47),
+        agent_badge_idle_fg: Color::Rgb(40, 40, 40),
+        agent_badge_idle_bg: Color::Rgb(131, 165, 152),
+        operating_chat_fg: Color::Rgb(40, 40, 40),
+        operating_chat_bg: Color::Rgb(131, 165, 152),
+        operating_code_fg: Color::Rgb(40, 40, 40),
+        operating_code_bg: Color::Rgb(211, 134, 155),
     }
 }

@@ -357,11 +695,17 @@ fn gruvbox() -> Theme {
 fn dracula() -> Theme {
     Theme {
         name: "dracula".to_string(),
         text: Color::Rgb(248, 248, 242), // #f8f8f2
         background: Color::Rgb(40, 42, 54), // #282a36
         focused_panel_border: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
         unfocused_panel_border: Color::Rgb(68, 71, 90), // #44475a
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(139, 233, 253), // #8be9fd (cyan)
         assistant_message_role: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
         tool_output: Color::Rgb(98, 114, 164),
         thinking_panel_title: Color::Rgb(189, 147, 249), // #bd93f9 (purple)
@@ -377,9 +721,28 @@ fn dracula() -> Theme {
         selection_bg: Color::Rgb(68, 71, 90),
         selection_fg: Color::Rgb(248, 248, 242),
         cursor: Color::Rgb(255, 121, 198),
+        code_block_background: Color::Rgb(68, 71, 90),
+        code_block_border: Color::Rgb(189, 147, 249),
+        code_block_text: Color::Rgb(248, 248, 242),
+        code_block_keyword: Color::Rgb(255, 121, 198),
+        code_block_string: Color::Rgb(80, 250, 123),
+        code_block_comment: Color::Rgb(98, 114, 164),
         placeholder: Color::Rgb(98, 114, 164),
         error: Color::Rgb(255, 85, 85), // #ff5555
         info: Color::Rgb(80, 250, 123),
+        agent_thought: Color::Rgb(139, 233, 253),
+        agent_action: Color::Rgb(241, 250, 140),
+        agent_action_input: Color::Rgb(189, 147, 249),
+        agent_observation: Color::Rgb(80, 250, 123),
+        agent_final_answer: Color::Rgb(255, 121, 198),
+        agent_badge_running_fg: Color::Rgb(40, 42, 54),
+        agent_badge_running_bg: Color::Rgb(241, 250, 140),
+        agent_badge_idle_fg: Color::Rgb(40, 42, 54),
+        agent_badge_idle_bg: Color::Rgb(139, 233, 253),
+        operating_chat_fg: Color::Rgb(40, 42, 54),
+        operating_chat_bg: Color::Rgb(139, 233, 253),
+        operating_code_fg: Color::Rgb(40, 42, 54),
+        operating_code_bg: Color::Rgb(189, 147, 249),
     }
 }

@@ -391,6 +754,12 @@ fn solarized() -> Theme {
         background: Color::Rgb(0, 43, 54), // #002b36 (base03)
         focused_panel_border: Color::Rgb(38, 139, 210), // #268bd2 (blue)
         unfocused_panel_border: Color::Rgb(7, 54, 66), // #073642 (base02)
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(42, 161, 152), // #2aa198 (cyan)
         assistant_message_role: Color::Rgb(203, 75, 22), // #cb4b16 (orange)
         tool_output: Color::Rgb(101, 123, 131),
@@ -407,9 +776,28 @@ fn solarized() -> Theme {
         selection_bg: Color::Rgb(7, 54, 66),
         selection_fg: Color::Rgb(147, 161, 161),
         cursor: Color::Rgb(211, 54, 130),
+        code_block_background: Color::Rgb(7, 54, 66),
+        code_block_border: Color::Rgb(38, 139, 210),
+        code_block_text: Color::Rgb(147, 161, 161),
+        code_block_keyword: Color::Rgb(181, 137, 0),
+        code_block_string: Color::Rgb(133, 153, 0),
+        code_block_comment: Color::Rgb(88, 110, 117),
         placeholder: Color::Rgb(88, 110, 117),
         error: Color::Rgb(220, 50, 47), // #dc322f (red)
         info: Color::Rgb(133, 153, 0),
+        agent_thought: Color::Rgb(42, 161, 152),
+        agent_action: Color::Rgb(181, 137, 0),
+        agent_action_input: Color::Rgb(38, 139, 210),
+        agent_observation: Color::Rgb(133, 153, 0),
+        agent_final_answer: Color::Rgb(108, 113, 196),
+        agent_badge_running_fg: Color::Rgb(0, 43, 54),
+        agent_badge_running_bg: Color::Rgb(181, 137, 0),
+        agent_badge_idle_fg: Color::Rgb(0, 43, 54),
+        agent_badge_idle_bg: Color::Rgb(42, 161, 152),
+        operating_chat_fg: Color::Rgb(0, 43, 54),
+        operating_chat_bg: Color::Rgb(42, 161, 152),
+        operating_code_fg: Color::Rgb(0, 43, 54),
+        operating_code_bg: Color::Rgb(108, 113, 196),
     }
 }

@@ -421,6 +809,12 @@ fn midnight_ocean() -> Theme {
         background: Color::Rgb(13, 17, 23),
         focused_panel_border: Color::Rgb(88, 166, 255),
         unfocused_panel_border: Color::Rgb(48, 54, 61),
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(121, 192, 255),
         assistant_message_role: Color::Rgb(137, 221, 255),
         tool_output: Color::Rgb(84, 110, 122),
@@ -437,9 +831,28 @@ fn midnight_ocean() -> Theme {
         selection_bg: Color::Rgb(56, 139, 253),
         selection_fg: Color::Rgb(13, 17, 23),
         cursor: Color::Rgb(246, 140, 245),
+        code_block_background: Color::Rgb(22, 27, 34),
+        code_block_border: Color::Rgb(88, 166, 255),
+        code_block_text: Color::Rgb(192, 202, 245),
+        code_block_keyword: Color::Rgb(255, 212, 59),
+        code_block_string: Color::Rgb(158, 206, 106),
+        code_block_comment: Color::Rgb(110, 118, 129),
         placeholder: Color::Rgb(110, 118, 129),
         error: Color::Rgb(248, 81, 73),
         info: Color::Rgb(158, 206, 106),
+        agent_thought: Color::Rgb(121, 192, 255),
+        agent_action: Color::Rgb(255, 212, 59),
+        agent_action_input: Color::Rgb(137, 221, 255),
+        agent_observation: Color::Rgb(158, 206, 106),
+        agent_final_answer: Color::Rgb(246, 140, 245),
+        agent_badge_running_fg: Color::Rgb(13, 17, 23),
+        agent_badge_running_bg: Color::Rgb(255, 212, 59),
+        agent_badge_idle_fg: Color::Rgb(13, 17, 23),
+        agent_badge_idle_bg: Color::Rgb(137, 221, 255),
+        operating_chat_fg: Color::Rgb(13, 17, 23),
+        operating_chat_bg: Color::Rgb(121, 192, 255),
+        operating_code_fg: Color::Rgb(13, 17, 23),
+        operating_code_bg: Color::Rgb(246, 140, 245),
     }
 }

@@ -447,11 +860,17 @@ fn midnight_ocean() -> Theme {
 fn rose_pine() -> Theme {
     Theme {
         name: "rose-pine".to_string(),
         text: Color::Rgb(224, 222, 244), // #e0def4
         background: Color::Rgb(25, 23, 36), // #191724
         focused_panel_border: Color::Rgb(235, 111, 146), // #eb6f92 (love)
         unfocused_panel_border: Color::Rgb(38, 35, 58), // #26233a
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(49, 116, 143), // #31748f (foam)
         assistant_message_role: Color::Rgb(156, 207, 216), // #9ccfd8 (foam light)
         tool_output: Color::Rgb(110, 106, 134),
         thinking_panel_title: Color::Rgb(196, 167, 231), // #c4a7e7 (iris)
@@ -467,9 +886,28 @@ fn rose_pine() -> Theme {
         selection_bg: Color::Rgb(64, 61, 82),
         selection_fg: Color::Rgb(224, 222, 244),
         cursor: Color::Rgb(235, 111, 146),
+        code_block_background: Color::Rgb(38, 35, 58),
+        code_block_border: Color::Rgb(235, 111, 146),
+        code_block_text: Color::Rgb(224, 222, 244),
+        code_block_keyword: Color::Rgb(246, 193, 119),
+        code_block_string: Color::Rgb(156, 207, 216),
+        code_block_comment: Color::Rgb(110, 106, 134),
         placeholder: Color::Rgb(110, 106, 134),
         error: Color::Rgb(235, 111, 146),
         info: Color::Rgb(156, 207, 216),
+        agent_thought: Color::Rgb(156, 207, 216),
+        agent_action: Color::Rgb(246, 193, 119),
+        agent_action_input: Color::Rgb(196, 167, 231),
+        agent_observation: Color::Rgb(235, 188, 186),
+        agent_final_answer: Color::Rgb(235, 111, 146),
+        agent_badge_running_fg: Color::Rgb(25, 23, 36),
+        agent_badge_running_bg: Color::Rgb(246, 193, 119),
+        agent_badge_idle_fg: Color::Rgb(25, 23, 36),
+        agent_badge_idle_bg: Color::Rgb(156, 207, 216),
+        operating_chat_fg: Color::Rgb(25, 23, 36),
+        operating_chat_bg: Color::Rgb(156, 207, 216),
+        operating_code_fg: Color::Rgb(25, 23, 36),
+        operating_code_bg: Color::Rgb(196, 167, 231),
     }
 }

@@ -477,11 +915,17 @@ fn rose_pine() -> Theme {
 fn monokai() -> Theme {
     Theme {
         name: "monokai".to_string(),
         text: Color::Rgb(248, 248, 242), // #f8f8f2
         background: Color::Rgb(39, 40, 34), // #272822
         focused_panel_border: Color::Rgb(249, 38, 114), // #f92672 (pink)
         unfocused_panel_border: Color::Rgb(117, 113, 94), // #75715e
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(102, 217, 239), // #66d9ef (cyan)
         assistant_message_role: Color::Rgb(174, 129, 255), // #ae81ff (purple)
         tool_output: Color::Rgb(117, 113, 94),
         thinking_panel_title: Color::Rgb(230, 219, 116), // #e6db74 (yellow)
@@ -497,9 +941,28 @@ fn monokai() -> Theme {
         selection_bg: Color::Rgb(117, 113, 94),
         selection_fg: Color::Rgb(248, 248, 242),
         cursor: Color::Rgb(249, 38, 114),
+        code_block_background: Color::Rgb(50, 51, 46),
+        code_block_border: Color::Rgb(249, 38, 114),
+        code_block_text: Color::Rgb(248, 248, 242),
+        code_block_keyword: Color::Rgb(230, 219, 116),
+        code_block_string: Color::Rgb(166, 226, 46),
+        code_block_comment: Color::Rgb(117, 113, 94),
         placeholder: Color::Rgb(117, 113, 94),
         error: Color::Rgb(249, 38, 114),
         info: Color::Rgb(166, 226, 46),
+        agent_thought: Color::Rgb(102, 217, 239),
+        agent_action: Color::Rgb(230, 219, 116),
+        agent_action_input: Color::Rgb(174, 129, 255),
+        agent_observation: Color::Rgb(166, 226, 46),
+        agent_final_answer: Color::Rgb(249, 38, 114),
+        agent_badge_running_fg: Color::Rgb(39, 40, 34),
+        agent_badge_running_bg: Color::Rgb(230, 219, 116),
+        agent_badge_idle_fg: Color::Rgb(39, 40, 34),
+        agent_badge_idle_bg: Color::Rgb(102, 217, 239),
+        operating_chat_fg: Color::Rgb(39, 40, 34),
+        operating_chat_bg: Color::Rgb(102, 217, 239),
+        operating_code_fg: Color::Rgb(39, 40, 34),
+        operating_code_bg: Color::Rgb(174, 129, 255),
     }
 }

@@ -507,11 +970,17 @@ fn monokai() -> Theme {
 fn material_dark() -> Theme {
     Theme {
         name: "material-dark".to_string(),
         text: Color::Rgb(238, 255, 255), // #eeffff
         background: Color::Rgb(38, 50, 56), // #263238
         focused_panel_border: Color::Rgb(128, 203, 196), // #80cbc4 (cyan)
         unfocused_panel_border: Color::Rgb(84, 110, 122), // #546e7a
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(130, 170, 255), // #82aaff (blue)
         assistant_message_role: Color::Rgb(199, 146, 234), // #c792ea (purple)
         tool_output: Color::Rgb(84, 110, 122),
         thinking_panel_title: Color::Rgb(255, 203, 107), // #ffcb6b (yellow)
@@ -527,9 +996,28 @@ fn material_dark() -> Theme {
         selection_bg: Color::Rgb(84, 110, 122),
         selection_fg: Color::Rgb(238, 255, 255),
         cursor: Color::Rgb(255, 204, 0),
+        code_block_background: Color::Rgb(33, 43, 48),
+        code_block_border: Color::Rgb(128, 203, 196),
+        code_block_text: Color::Rgb(238, 255, 255),
+        code_block_keyword: Color::Rgb(255, 203, 107),
+        code_block_string: Color::Rgb(195, 232, 141),
+        code_block_comment: Color::Rgb(84, 110, 122),
         placeholder: Color::Rgb(84, 110, 122),
         error: Color::Rgb(240, 113, 120),
         info: Color::Rgb(195, 232, 141),
+        agent_thought: Color::Rgb(128, 203, 196),
+        agent_action: Color::Rgb(255, 203, 107),
+        agent_action_input: Color::Rgb(199, 146, 234),
+        agent_observation: Color::Rgb(195, 232, 141),
+        agent_final_answer: Color::Rgb(240, 113, 120),
+        agent_badge_running_fg: Color::Rgb(38, 50, 56),
+        agent_badge_running_bg: Color::Rgb(255, 203, 107),
+        agent_badge_idle_fg: Color::Rgb(38, 50, 56),
+        agent_badge_idle_bg: Color::Rgb(128, 203, 196),
+        operating_chat_fg: Color::Rgb(38, 50, 56),
+        operating_chat_bg: Color::Rgb(130, 170, 255),
+        operating_code_fg: Color::Rgb(38, 50, 56),
+        operating_code_bg: Color::Rgb(199, 146, 234),
     }
 }

@@ -541,6 +1029,12 @@ fn material_light() -> Theme {
         background: Color::Rgb(236, 239, 241),
         focused_panel_border: Color::Rgb(0, 150, 136),
         unfocused_panel_border: Color::Rgb(176, 190, 197),
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
         user_message_role: Color::Rgb(68, 138, 255),
         assistant_message_role: Color::Rgb(124, 77, 255),
         tool_output: Color::Rgb(144, 164, 174),
@@ -557,9 +1051,83 @@ fn material_light() -> Theme {
         selection_bg: Color::Rgb(176, 190, 197),
         selection_fg: Color::Rgb(33, 33, 33),
         cursor: Color::Rgb(194, 24, 91),
+        code_block_background: Color::Rgb(248, 249, 250),
+        code_block_border: Color::Rgb(0, 150, 136),
+        code_block_text: Color::Rgb(33, 33, 33),
+        code_block_keyword: Color::Rgb(245, 124, 0),
+        code_block_string: Color::Rgb(56, 142, 60),
+        code_block_comment: Color::Rgb(144, 164, 174),
         placeholder: Color::Rgb(144, 164, 174),
         error: Color::Rgb(211, 47, 47),
         info: Color::Rgb(56, 142, 60),
+        agent_thought: Color::Rgb(68, 138, 255),
+        agent_action: Color::Rgb(245, 124, 0),
+        agent_action_input: Color::Rgb(124, 77, 255),
+        agent_observation: Color::Rgb(56, 142, 60),
+        agent_final_answer: Color::Rgb(211, 47, 47),
+        agent_badge_running_fg: Color::White,
+        agent_badge_running_bg: Color::Rgb(245, 124, 0),
+        agent_badge_idle_fg: Color::White,
+        agent_badge_idle_bg: Color::Rgb(0, 150, 136),
+        operating_chat_fg: Color::White,
+        operating_chat_bg: Color::Rgb(68, 138, 255),
+        operating_code_fg: Color::White,
+        operating_code_bg: Color::Rgb(124, 77, 255),
+    }
+}
+
+/// Grayscale high-contrast theme
+fn grayscale_high_contrast() -> Theme {
+    Theme {
+        name: "grayscale_high_contrast".to_string(),
+        text: Color::Rgb(247, 247, 247),
+        background: Color::Black,
+        focused_panel_border: Color::White,
+        unfocused_panel_border: Color::Rgb(76, 76, 76),
+        focus_beacon_fg: Theme::default_focus_beacon_fg(),
+        focus_beacon_bg: Theme::default_focus_beacon_bg(),
+        unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
+        pane_header_active: Theme::default_pane_header_active(),
+        pane_header_inactive: Theme::default_pane_header_inactive(),
+        pane_hint_text: Theme::default_pane_hint_text(),
+        user_message_role: Color::Rgb(240, 240, 240),
+        assistant_message_role: Color::Rgb(214, 214, 214),
+        tool_output: Color::Rgb(189, 189, 189),
+        thinking_panel_title: Color::Rgb(224, 224, 224),
+        command_bar_background: Color::Black,
+        status_background: Color::Rgb(15, 15, 15),
+        mode_normal: Color::White,
+        mode_editing: Color::Rgb(230, 230, 230),
+        mode_model_selection: Color::Rgb(204, 204, 204),
+        mode_provider_selection: Color::Rgb(179, 179, 179),
+        mode_help: Color::Rgb(153, 153, 153),
+        mode_visual: Color::Rgb(242, 242, 242),
+        mode_command: Color::Rgb(208, 208, 208),
+        selection_bg: Color::Rgb(240, 240, 240),
+        selection_fg: Color::Black,
+        cursor: Color::White,
+        code_block_background: Color::Rgb(15, 15, 15),
+        code_block_border: Color::White,
+        code_block_text: Color::Rgb(247, 247, 247),
+        code_block_keyword: Color::Rgb(204, 204, 204),
+        code_block_string: Color::Rgb(214, 214, 214),
+        code_block_comment: Color::Rgb(122, 122, 122),
+        placeholder: Color::Rgb(122, 122, 122),
+        error: Color::White,
+        info: Color::Rgb(200, 200, 200),
+        agent_thought: Color::Rgb(230, 230, 230),
+        agent_action: Color::Rgb(204, 204, 204),
+        agent_action_input: Color::Rgb(176, 176, 176),
+        agent_observation: Color::Rgb(153, 153, 153),
+        agent_final_answer: Color::White,
+        agent_badge_running_fg: Color::Black,
+        agent_badge_running_bg: Color::Rgb(247, 247, 247),
+        agent_badge_idle_fg: Color::Black,
+        agent_badge_idle_bg: Color::Rgb(189, 189, 189),
+        operating_chat_fg: Color::Black,
+        operating_chat_bg: Color::Rgb(242, 242, 242),
+        operating_code_fg: Color::Black,
+        operating_code_bg: Color::Rgb(191, 191, 191),
     }
 }
@@ -582,16 +1150,16 @@ where
 }
 
 fn parse_color(s: &str) -> Result<Color, String> {
-    if let Some(hex) = s.strip_prefix('#') {
-        if hex.len() == 6 {
-            let r = u8::from_str_radix(&hex[0..2], 16)
-                .map_err(|_| format!("Invalid hex color: {}", s))?;
-            let g = u8::from_str_radix(&hex[2..4], 16)
-                .map_err(|_| format!("Invalid hex color: {}", s))?;
-            let b = u8::from_str_radix(&hex[4..6], 16)
-                .map_err(|_| format!("Invalid hex color: {}", s))?;
-            return Ok(Color::Rgb(r, g, b));
-        }
+    if let Some(hex) = s.strip_prefix('#')
+        && hex.len() == 6
+    {
+        let r =
+            u8::from_str_radix(&hex[0..2], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
+        let g =
+            u8::from_str_radix(&hex[2..4], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
+        let b =
+            u8::from_str_radix(&hex[4..6], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
+        return Ok(Color::Rgb(r, g, b));
     }
 
     // Try named colors
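The rewritten parse_color keeps the old behaviour; the flattened `if let ... &&` chain just needs a toolchain with let-chains support. A minimal usage sketch, assuming the function and Color enum shown in this hunk:

    // Illustrative call sites only, not taken from the diff.
    assert_eq!(parse_color("#ff8800"), Ok(Color::Rgb(255, 136, 0)));
    assert!(parse_color("#zzzzzz").is_err()); // invalid hex digits are rejected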
@@ -656,5 +1224,6 @@ mod tests {
         assert!(themes.contains_key("default_dark"));
         assert!(themes.contains_key("gruvbox"));
         assert!(themes.contains_key("dracula"));
+        assert!(themes.contains_key("grayscale-high-contrast"));
     }
 }
@@ -8,11 +8,12 @@
 pub mod code_exec;
 pub mod fs_tools;
 pub mod registry;
+pub mod web_scrape;
 pub mod web_search;
 pub mod web_search_detailed;
 
 use async_trait::async_trait;
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 use std::collections::HashMap;
 use std::time::Duration;
 
@@ -91,5 +92,6 @@ impl ToolResult {
 pub use code_exec::CodeExecTool;
 pub use fs_tools::{ResourcesDeleteTool, ResourcesGetTool, ResourcesListTool, ResourcesWriteTool};
 pub use registry::ToolRegistry;
+pub use web_scrape::WebScrapeTool;
 pub use web_search::WebSearchTool;
 pub use web_search_detailed::WebSearchDetailedTool;
@@ -2,9 +2,9 @@ use std::sync::Arc;
 use std::time::Instant;
 
 use crate::Result;
-use anyhow::{anyhow, Context};
+use anyhow::{Context, anyhow};
 use async_trait::async_trait;
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 
 use super::{Tool, ToolResult};
 use crate::sandbox::{SandboxConfig, SandboxedProcess};
@@ -7,6 +7,7 @@ use serde_json::Value;
 
 use super::{Tool, ToolResult};
 use crate::config::Config;
+use crate::mode::Mode;
 use crate::ui::UiController;
 
 pub struct ToolRegistry {
@@ -41,13 +42,33 @@ impl ToolRegistry {
         self.tools.values().cloned().collect()
     }
 
-    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
+    pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
         let tool = self
             .get(name)
             .with_context(|| format!("Tool not registered: {}", name))?;
 
         let mut config = self.config.lock().await;
 
+        // Check mode-based tool availability first
+        if !config.modes.is_tool_allowed(mode, name) {
+            let alternate_mode = match mode {
+                Mode::Chat => Mode::Code,
+                Mode::Code => Mode::Chat,
+            };
+
+            if config.modes.is_tool_allowed(alternate_mode, name) {
+                return Ok(ToolResult::error(&format!(
+                    "Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
+                    name, mode, alternate_mode, alternate_mode
+                )));
+            } else {
+                return Ok(ToolResult::error(&format!(
+                    "Tool '{}' is not available in any mode. Check your configuration.",
+                    name
+                )));
+            }
+        }
+
         let is_enabled = match name {
             "web_search" => config.tools.web_search.enabled,
             "code_exec" => config.tools.code_exec.enabled,
@@ -77,6 +98,16 @@ impl ToolRegistry {
         tool.execute(args).await
     }
 
+    /// Get all tools available in the given mode
+    pub async fn available_tools(&self, mode: Mode) -> Vec<String> {
+        let config = self.config.lock().await;
+        self.tools
+            .keys()
+            .filter(|name| config.modes.is_tool_allowed(mode, name))
+            .cloned()
+            .collect()
+    }
+
    pub fn tools(&self) -> Vec<String> {
        self.tools.keys().cloned().collect()
    }
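Callers now pass the active operating mode through to the registry. A minimal sketch of the new call shape (the `registry` binding is illustrative; see the mode_tool_filter tests later in this diff for a full setup):

    let result = registry
        .execute("web_search", serde_json::json!({ "query": "rust" }), Mode::Chat)
        .await?;
    if !result.success {
        // Tools rejected by the mode filter come back as Ok(ToolResult::error(..)), not as Err.
    }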
102
crates/owlen-core/src/tools/web_scrape.rs
Normal file
@@ -0,0 +1,102 @@
use super::{Tool, ToolResult};
use crate::Result;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::{Value, json};

/// Tool that fetches the raw HTML content for a list of URLs.
///
/// Input schema expects:
///   urls: array of strings (max 5 URLs)
///   timeout_secs: optional integer per-request timeout (default 10)
pub struct WebScrapeTool {
    // No special dependencies; uses reqwest_011 for compatibility with existing web_search.
    client: reqwest_011::Client,
}

impl Default for WebScrapeTool {
    fn default() -> Self {
        Self::new()
    }
}

impl WebScrapeTool {
    pub fn new() -> Self {
        let client = reqwest_011::Client::builder()
            .user_agent("OwlenWebScrape/0.1")
            .build()
            .expect("Failed to build reqwest client");
        Self { client }
    }
}

#[async_trait]
impl Tool for WebScrapeTool {
    fn name(&self) -> &'static str {
        "web_scrape"
    }

    fn description(&self) -> &'static str {
        "Fetch raw HTML content for a list of URLs"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "urls": {
                    "type": "array",
                    "items": { "type": "string", "format": "uri" },
                    "minItems": 1,
                    "maxItems": 5,
                    "description": "List of URLs to scrape"
                },
                "timeout_secs": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 30,
                    "default": 10,
                    "description": "Per-request timeout in seconds"
                }
            },
            "required": ["urls"],
            "additionalProperties": false
        })
    }

    fn requires_network(&self) -> bool {
        true
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let urls = args
            .get("urls")
            .and_then(|v| v.as_array())
            .context("Missing 'urls' array")?;
        let timeout_secs = args
            .get("timeout_secs")
            .and_then(|v| v.as_u64())
            .unwrap_or(10);

        let mut results = Vec::new();
        for url_val in urls {
            let url = url_val.as_str().unwrap_or("");
            let resp = self
                .client
                .get(url)
                .timeout(std::time::Duration::from_secs(timeout_secs))
                .send()
                .await;
            match resp {
                Ok(r) => {
                    let text = r.text().await.unwrap_or_default();
                    results.push(json!({ "url": url, "content": text }));
                }
                Err(e) => {
                    results.push(json!({ "url": url, "error": e.to_string() }));
                }
            }
        }
        Ok(ToolResult::success(json!({ "pages": results })))
    }
}
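For reference, a request matching the schema above could look like the following sketch (the URL is a placeholder, not taken from the diff):

    let args = serde_json::json!({
        "urls": ["https://example.com"],
        "timeout_secs": 5
    });
    // execute returns ToolResult::success with a "pages" array, where each entry
    // carries either "content" or "error" for its URL.
    let result = WebScrapeTool::new().execute(args).await?;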
@@ -4,7 +4,7 @@ use std::time::Instant;
 use crate::Result;
 use anyhow::Context;
 use async_trait::async_trait;
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 
 use super::{Tool, ToolResult};
 use crate::consent::ConsentManager;
@@ -4,7 +4,7 @@ use std::time::Instant;
 use crate::Result;
 use anyhow::Context;
 use async_trait::async_trait;
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 
 use super::{Tool, ToolResult};
 use crate::consent::ConsentManager;
@@ -86,7 +86,9 @@ impl Tool for WebSearchDetailedTool {
             .expect("Consent manager mutex poisoned");
 
         if !consent.has_consent(self.name()) {
-            return Ok(ToolResult::error("Consent not granted for detailed web search. This should have been handled by the TUI."));
+            return Ok(ToolResult::error(
+                "Consent not granted for detailed web search. This should have been handled by the TUI.",
+            ));
         }
     }
@@ -3,170 +3,30 @@
 //! This module contains reusable UI components that can be shared between
 //! different TUI applications (chat, code, etc.)
 
-use std::fmt;
-
 /// Application state
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum AppState {
-    Running,
-    Quit,
-}
+pub use crate::state::AppState;
 
 /// Input modes for TUI applications
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum InputMode {
-    Normal,
-    Editing,
-    ProviderSelection,
-    ModelSelection,
-    Help,
-    Visual,
-    Command,
-    SessionBrowser,
-    ThemeBrowser,
-}
-
-impl fmt::Display for InputMode {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        let label = match self {
-            InputMode::Normal => "Normal",
-            InputMode::Editing => "Editing",
-            InputMode::ModelSelection => "Model",
-            InputMode::ProviderSelection => "Provider",
-            InputMode::Help => "Help",
-            InputMode::Visual => "Visual",
-            InputMode::Command => "Command",
-            InputMode::SessionBrowser => "Sessions",
-            InputMode::ThemeBrowser => "Themes",
-        };
-        f.write_str(label)
-    }
-}
+pub use crate::state::InputMode;
 
 /// Represents which panel is currently focused
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum FocusedPanel {
-    Chat,
-    Thinking,
-    Input,
-}
+pub use crate::state::FocusedPanel;
 
 /// Auto-scroll state manager for scrollable panels
-#[derive(Debug, Clone)]
-pub struct AutoScroll {
-    pub scroll: usize,
-    pub content_len: usize,
-    pub stick_to_bottom: bool,
-}
-
-impl Default for AutoScroll {
-    fn default() -> Self {
-        Self {
-            scroll: 0,
-            content_len: 0,
-            stick_to_bottom: true,
-        }
-    }
-}
-
-impl AutoScroll {
-    /// Update scroll position based on viewport height
-    pub fn on_viewport(&mut self, viewport_h: usize) {
-        let max = self.content_len.saturating_sub(viewport_h);
-        if self.stick_to_bottom {
-            self.scroll = max;
-        } else {
-            self.scroll = self.scroll.min(max);
-        }
-    }
-
-    /// Handle user scroll input
-    pub fn on_user_scroll(&mut self, delta: isize, viewport_h: usize) {
-        let max = self.content_len.saturating_sub(viewport_h) as isize;
-        let s = (self.scroll as isize + delta).clamp(0, max) as usize;
-        self.scroll = s;
-        self.stick_to_bottom = s as isize == max;
-    }
-
-    /// Scroll down half page
-    pub fn scroll_half_page_down(&mut self, viewport_h: usize) {
-        let delta = (viewport_h / 2) as isize;
-        self.on_user_scroll(delta, viewport_h);
-    }
-
-    /// Scroll up half page
-    pub fn scroll_half_page_up(&mut self, viewport_h: usize) {
-        let delta = -((viewport_h / 2) as isize);
-        self.on_user_scroll(delta, viewport_h);
-    }
-
-    /// Scroll down full page
-    pub fn scroll_full_page_down(&mut self, viewport_h: usize) {
-        let delta = viewport_h as isize;
-        self.on_user_scroll(delta, viewport_h);
-    }
-
-    /// Scroll up full page
-    pub fn scroll_full_page_up(&mut self, viewport_h: usize) {
-        let delta = -(viewport_h as isize);
-        self.on_user_scroll(delta, viewport_h);
-    }
-
-    /// Jump to top
-    pub fn jump_to_top(&mut self) {
-        self.scroll = 0;
-        self.stick_to_bottom = false;
-    }
-
-    /// Jump to bottom
-    pub fn jump_to_bottom(&mut self, viewport_h: usize) {
-        self.stick_to_bottom = true;
-        self.on_viewport(viewport_h);
-    }
-}
+pub use crate::state::AutoScroll;
 
 /// Visual selection state for text selection
-#[derive(Debug, Clone, Default)]
-pub struct VisualSelection {
-    pub start: Option<(usize, usize)>, // (row, col)
-    pub end: Option<(usize, usize)>,   // (row, col)
-}
-
-impl VisualSelection {
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    pub fn start_at(&mut self, pos: (usize, usize)) {
-        self.start = Some(pos);
-        self.end = Some(pos);
-    }
-
-    pub fn extend_to(&mut self, pos: (usize, usize)) {
-        self.end = Some(pos);
-    }
-
-    pub fn clear(&mut self) {
-        self.start = None;
-        self.end = None;
-    }
-
-    pub fn is_active(&self) -> bool {
-        self.start.is_some() && self.end.is_some()
-    }
-
-    pub fn get_normalized(&self) -> Option<((usize, usize), (usize, usize))> {
-        if let (Some(s), Some(e)) = (self.start, self.end) {
-            // Normalize selection so start is always before end
-            if s.0 < e.0 || (s.0 == e.0 && s.1 <= e.1) {
-                Some((s, e))
-            } else {
-                Some((e, s))
-            }
-        } else {
-            None
-        }
-    }
+pub use crate::state::VisualSelection;
+
+use serde::{Deserialize, Serialize};
+
+/// How role labels should be rendered alongside chat messages.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum RoleLabelDisplay {
+    Inline,
+    Above,
+    None,
 }
 
 /// Extract text from a selection range in a list of lines
@@ -235,37 +95,7 @@ pub fn extract_text_from_selection(
 }
 
 /// Cursor position for navigating scrollable content
-#[derive(Debug, Clone, Copy, Default)]
-pub struct CursorPosition {
-    pub row: usize,
-    pub col: usize,
-}
-
-impl CursorPosition {
-    pub fn new(row: usize, col: usize) -> Self {
-        Self { row, col }
-    }
-
-    pub fn move_up(&mut self, amount: usize) {
-        self.row = self.row.saturating_sub(amount);
-    }
-
-    pub fn move_down(&mut self, amount: usize, max: usize) {
-        self.row = (self.row + amount).min(max);
-    }
-
-    pub fn move_left(&mut self, amount: usize) {
-        self.col = self.col.saturating_sub(amount);
-    }
-
-    pub fn move_right(&mut self, amount: usize, max: usize) {
-        self.col = (self.col + amount).min(max);
-    }
-
-    pub fn as_tuple(&self) -> (usize, usize) {
-        (self.row, self.col)
-    }
-}
+pub use crate::state::CursorPosition;
 
 /// Word boundary detection for navigation
 pub fn find_next_word_boundary(line: &str, col: usize) -> Option<usize> {
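With `#[serde(rename_all = "lowercase")]`, the new RoleLabelDisplay enum round-trips as the strings "inline", "above", and "none". A small sanity sketch of that serde behaviour (illustrative only; serde_json is already a dependency throughout this diff):

    assert_eq!(
        serde_json::from_str::<RoleLabelDisplay>("\"above\"").unwrap(),
        RoleLabelDisplay::Above
    );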
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use anyhow::{Context, Result};
 use jsonschema::{JSONSchema, ValidationError};
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 
 pub struct SchemaValidator {
     schemas: HashMap<String, JSONSchema>,
310
crates/owlen-core/tests/agent_tool_flow.rs
Normal file
@@ -0,0 +1,310 @@
use std::{any::Any, collections::HashMap, sync::Arc};

use async_trait::async_trait;
use futures::StreamExt;
use owlen_core::{
    Config, Error, Mode, Provider,
    config::McpMode,
    consent::ConsentScope,
    mcp::{
        McpClient, McpToolCall, McpToolDescriptor, McpToolResponse,
        failover::{FailoverMcpClient, ServerEntry},
    },
    session::{ControllerEvent, SessionController, SessionOutcome},
    storage::StorageManager,
    types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall},
    ui::NoOpUiController,
};
use tempfile::tempdir;
use tokio::sync::mpsc;

struct StreamingToolProvider;

#[async_trait]
impl Provider for StreamingToolProvider {
    fn name(&self) -> &str {
        "mock-streaming-provider"
    }

    async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            id: "mock-model".into(),
            name: "Mock Model".into(),
            description: Some("A mock model that emits tool calls".into()),
            provider: self.name().into(),
            context_window: Some(4096),
            capabilities: vec!["chat".into(), "tools".into()],
            supports_tools: true,
        }])
    }

    async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result<ChatResponse> {
        let mut message = Message::assistant("tool-call".to_string());
        message.tool_calls = Some(vec![ToolCall {
            id: "call-1".to_string(),
            name: "resources/write".to_string(),
            arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
        }]);

        Ok(ChatResponse {
            message,
            usage: None,
            is_streaming: false,
            is_final: true,
        })
    }

    async fn stream_prompt(
        &self,
        _request: ChatRequest,
    ) -> owlen_core::Result<owlen_core::ChatStream> {
        let mut first_chunk = Message::assistant(
            "Thought: need to update README.\nAction: resources/write".to_string(),
        );
        first_chunk.tool_calls = Some(vec![ToolCall {
            id: "call-1".to_string(),
            name: "resources/write".to_string(),
            arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
        }]);

        let chunk = ChatResponse {
            message: first_chunk,
            usage: None,
            is_streaming: true,
            is_final: false,
        };

        Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)])))
    }

    async fn health_check(&self) -> owlen_core::Result<()> {
        Ok(())
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

fn tool_descriptor() -> McpToolDescriptor {
    McpToolDescriptor {
        name: "web_search".to_string(),
        description: "search".to_string(),
        input_schema: serde_json::json!({"type": "object"}),
        requires_network: true,
        requires_filesystem: vec![],
    }
}

struct TimeoutClient;

#[async_trait]
impl McpClient for TimeoutClient {
    async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
        Ok(vec![tool_descriptor()])
    }

    async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
        Err(Error::Network(
            "timeout while contacting remote web search endpoint".into(),
        ))
    }
}

#[derive(Clone)]
struct CachedResponseClient {
    response: Arc<McpToolResponse>,
}

impl CachedResponseClient {
    fn new() -> Self {
        let mut metadata = HashMap::new();
        metadata.insert("source".to_string(), "cache".to_string());
        metadata.insert("cached".to_string(), "true".to_string());

        let response = McpToolResponse {
            name: "web_search".to_string(),
            success: true,
            output: serde_json::json!({
                "query": "rust",
                "results": [
                    {"title": "Rust Programming Language", "url": "https://www.rust-lang.org"}
                ],
                "note": "cached result"
            }),
            metadata,
            duration_ms: 0,
        };

        Self {
            response: Arc::new(response),
        }
    }
}

#[async_trait]
impl McpClient for CachedResponseClient {
    async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
        Ok(vec![tool_descriptor()])
    }

    async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
        Ok((*self.response).clone())
    }
}

#[tokio::test(flavor = "multi_thread")]
async fn streaming_file_write_consent_denied_returns_resolution() {
    let temp_dir = tempdir().expect("temp dir");
    let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
        .await
        .expect("storage");

    let mut config = Config::default();
    config.general.enable_streaming = true;
    config.privacy.encrypt_local_data = false;
    config.privacy.require_consent_per_session = true;
    config.general.default_model = Some("mock-model".into());
    config.mcp.mode = McpMode::LocalOnly;
    config
        .refresh_mcp_servers(None)
        .expect("refresh MCP servers");

    let provider: Arc<dyn Provider> = Arc::new(StreamingToolProvider);
    let ui = Arc::new(NoOpUiController);
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ControllerEvent>();

    let mut session = SessionController::new(
        provider,
        config,
        Arc::new(storage),
        ui,
        true,
        Some(event_tx),
    )
    .await
    .expect("session controller");

    session
        .set_operating_mode(Mode::Code)
        .await
        .expect("code mode");

    let outcome = session
        .send_message(
            "Please write to README".to_string(),
            ChatParameters {
                stream: true,
                ..Default::default()
            },
        )
        .await
        .expect("send message");

    let (response_id, mut stream) = if let SessionOutcome::Streaming {
        response_id,
        stream,
    } = outcome
    {
        (response_id, stream)
    } else {
        panic!("expected streaming outcome");
    };

    session
        .mark_stream_placeholder(response_id, "▌")
        .expect("placeholder");

    let chunk = stream
        .next()
        .await
        .expect("stream chunk")
        .expect("chunk result");
    session
        .apply_stream_chunk(response_id, &chunk)
        .expect("apply chunk");

    let tool_calls = session
        .check_streaming_tool_calls(response_id)
        .expect("tool calls");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].name, "resources/write");

    let event = event_rx.recv().await.expect("controller event");
    let request_id = match event {
        ControllerEvent::ToolRequested {
            request_id,
            tool_name,
            data_types,
            endpoints,
            ..
        } => {
            assert_eq!(tool_name, "resources/write");
            assert!(data_types.iter().any(|t| t.contains("file")));
            assert!(endpoints.iter().any(|e| e.contains("filesystem")));
            request_id
        }
    };

    let resolution = session
        .resolve_tool_consent(request_id, ConsentScope::Denied)
        .expect("resolution");
    assert_eq!(resolution.scope, ConsentScope::Denied);
    assert_eq!(resolution.tool_name, "resources/write");
    assert_eq!(resolution.tool_calls.len(), tool_calls.len());

    let err = session
        .resolve_tool_consent(request_id, ConsentScope::Denied)
        .expect_err("second resolution should fail");
    matches!(err, Error::InvalidInput(_));

    let conversation = session.conversation().clone();
    let assistant = conversation
        .messages
        .iter()
        .find(|message| message.role == Role::Assistant)
        .expect("assistant message present");
    assert!(
        assistant
            .tool_calls
            .as_ref()
            .and_then(|calls| calls.first())
            .is_some_and(|call| call.name == "resources/write"),
        "stream chunk should capture the tool call on the assistant message"
    );
}

#[tokio::test]
async fn web_tool_timeout_fails_over_to_cached_result() {
    let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
    let cached = CachedResponseClient::new();
    let backup: Arc<dyn McpClient> = Arc::new(cached.clone());

    let client = FailoverMcpClient::with_servers(vec![
        ServerEntry::new("primary".into(), primary, 1),
        ServerEntry::new("cache".into(), backup, 2),
    ]);

    let call = McpToolCall {
        name: "web_search".to_string(),
        arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
    };

    let response = client.call_tool(call.clone()).await.expect("fallback");

    assert_eq!(response.name, "web_search");
    assert_eq!(
        response.metadata.get("source").map(String::as_str),
        Some("cache")
    );
    assert_eq!(
        response.output.get("note").and_then(|value| value.as_str()),
        Some("cached result")
    );

    let statuses = client.get_server_status().await;
    assert!(statuses.iter().any(|(name, health)| name == "primary"
        && !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
    assert!(statuses.iter().any(|(name, health)| name == "cache"
        && matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
}
@@ -1,6 +1,5 @@
-use owlen_core::mcp::client::McpClient;
+use owlen_core::McpToolCall;
 use owlen_core::mcp::remote_client::RemoteMcpClient;
-use owlen_core::mcp::McpToolCall;
 use std::fs::File;
 use std::io::Write;
 use tempfile::tempdir;
@@ -22,7 +21,7 @@ async fn remote_file_server_read_and_list() {
         .join("../..")
         .join("Cargo.toml");
     let build_status = std::process::Command::new("cargo")
-        .args(&["build", "-p", "owlen-mcp-server", "--manifest-path"])
+        .args(["build", "-p", "owlen-mcp-server", "--manifest-path"])
         .arg(manifest_path)
         .status()
         .expect("failed to run cargo build for MCP server");
@@ -1,13 +1,12 @@
-use owlen_core::mcp::client::McpClient;
+use owlen_core::McpToolCall;
 use owlen_core::mcp::remote_client::RemoteMcpClient;
-use owlen_core::mcp::McpToolCall;
 use tempfile::tempdir;
 
 #[tokio::test]
 async fn remote_write_and_delete() {
     // Build the server binary first
     let status = std::process::Command::new("cargo")
-        .args(&["build", "-p", "owlen-mcp-server"])
+        .args(["build", "-p", "owlen-mcp-server"])
         .status()
         .expect("failed to build MCP server");
     assert!(status.success());
@@ -42,7 +41,7 @@ async fn remote_write_and_delete() {
 async fn write_outside_root_is_rejected() {
     // Build server (already built in previous test, but ensure it exists)
     let status = std::process::Command::new("cargo")
-        .args(&["build", "-p", "owlen-mcp-server"])
+        .args(["build", "-p", "owlen-mcp-server"])
         .status()
         .expect("failed to build MCP server");
     assert!(status.success());
110
crates/owlen-core/tests/mode_tool_filter.rs
Normal file
@@ -0,0 +1,110 @@
//! Tests for mode-based tool availability filtering.
//!
//! These tests verify that `ToolRegistry::execute` respects the
//! `ModeConfig` settings in `Config`. The default configuration only
//! allows `web_search` in chat mode and all tools in code mode.
//!
//! We create a simple mock tool (`EchoTool`) that just echoes the
//! provided arguments. By customizing the `Config` we can test both the
//! allowed-in-chat and disallowed-in-any-mode paths.

use std::sync::Arc;

use owlen_core::config::Config;
use owlen_core::mode::{Mode, ModeConfig, ModeToolConfig};
use owlen_core::tools::registry::ToolRegistry;
use owlen_core::tools::{Tool, ToolResult};
use owlen_core::ui::{NoOpUiController, UiController};
use serde_json::json;
use tokio::sync::Mutex;

/// A trivial tool that returns the provided arguments as its output.
#[derive(Debug)]
struct EchoTool;

#[async_trait::async_trait]
impl Tool for EchoTool {
    fn name(&self) -> &'static str {
        "echo"
    }
    fn description(&self) -> &'static str {
        "Echo the input arguments"
    }
    fn schema(&self) -> serde_json::Value {
        // Accept any object.
        json!({ "type": "object" })
    }
    async fn execute(&self, args: serde_json::Value) -> owlen_core::Result<ToolResult> {
        Ok(ToolResult::success(args))
    }
}

#[tokio::test]
async fn test_tool_allowed_in_chat_mode() {
    // Build a config where the `echo` tool is explicitly allowed in chat.
    let cfg = Config {
        modes: ModeConfig {
            chat: ModeToolConfig {
                allowed_tools: vec!["echo".to_string()],
            },
            code: ModeToolConfig {
                allowed_tools: vec!["*".to_string()],
            },
        },
        ..Default::default()
    };
    let cfg = Arc::new(Mutex::new(cfg));

    let ui: Arc<dyn UiController> = Arc::new(NoOpUiController);
    let mut reg = ToolRegistry::new(cfg.clone(), ui);
    reg.register(EchoTool);

    let args = json!({ "msg": "hello" });
    let result = reg
        .execute("echo", args.clone(), Mode::Chat)
        .await
        .expect("execution should succeed");

    assert!(result.success, "Tool should succeed when allowed");
    assert_eq!(result.output, args, "Output should echo the input");
}

#[tokio::test]
async fn test_tool_not_allowed_in_any_mode() {
    // Config that does NOT list `echo` in either mode.
    let cfg = Config {
        modes: ModeConfig {
            chat: ModeToolConfig {
                allowed_tools: vec!["web_search".to_string()],
            },
            code: ModeToolConfig {
                // Strict denial - only web_search allowed
                allowed_tools: vec!["web_search".to_string()],
            },
        },
        ..Default::default()
    };
    let cfg = Arc::new(Mutex::new(cfg));

    let ui: Arc<dyn UiController> = Arc::new(NoOpUiController);
    let mut reg = ToolRegistry::new(cfg.clone(), ui);
    reg.register(EchoTool);

    let args = json!({ "msg": "hello" });
    let result = reg
        .execute("echo", args, Mode::Chat)
        .await
        .expect("execution should return a ToolResult");

    // Expect an error indicating the tool is unavailable in any mode.
    assert!(!result.success, "Tool should be rejected when not allowed");
    let err_msg = result
        .output
        .get("error")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    assert!(
        err_msg.contains("not available in any mode"),
        "Error message should explain unavailability"
    );
}
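The tests above lean on "*" acting as a wildcard in allowed_tools. The diff does not show the is_tool_allowed implementation itself, so the following is only a sketch of the semantics the tests imply:

    // Hypothetical helper, inferred from the tests, not copied from owlen-core:
    fn is_allowed(allowed_tools: &[String], name: &str) -> bool {
        allowed_tools.iter().any(|t| t == "*" || t == name)
    }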
311
crates/owlen-core/tests/phase9_remoting.rs
Normal file
@@ -0,0 +1,311 @@
//! Integration tests for Phase 9: Remoting / Cloud Hybrid Deployment
//!
//! Tests WebSocket transport, failover mechanisms, and health checking.

use owlen_core::mcp::failover::{FailoverConfig, FailoverMcpClient, ServerEntry, ServerHealth};
use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor};
use owlen_core::{Error, Result};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

/// Mock MCP client for testing failover behavior
struct MockMcpClient {
    name: String,
    fail_count: AtomicUsize,
    max_failures: usize,
}

impl MockMcpClient {
    fn new(name: &str, max_failures: usize) -> Self {
        Self {
            name: name.to_string(),
            fail_count: AtomicUsize::new(0),
            max_failures,
        }
    }

    fn always_healthy(name: &str) -> Self {
        Self::new(name, 0)
    }

    fn fail_n_times(name: &str, n: usize) -> Self {
        Self::new(name, n)
    }
}

#[async_trait::async_trait]
impl McpClient for MockMcpClient {
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
        if current < self.max_failures {
            Err(Error::Network(format!(
                "Mock failure {} from '{}'",
                current + 1,
                self.name
            )))
        } else {
            Ok(vec![McpToolDescriptor {
                name: format!("test_tool_{}", self.name),
                description: format!("Tool from {}", self.name),
                input_schema: serde_json::json!({}),
                requires_network: false,
                requires_filesystem: vec![],
            }])
        }
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<owlen_core::mcp::McpToolResponse> {
        let current = self.fail_count.load(Ordering::SeqCst);
        if current < self.max_failures {
            Err(Error::Network(format!("Mock failure from '{}'", self.name)))
        } else {
            Ok(owlen_core::mcp::McpToolResponse {
                name: call.name,
                success: true,
                output: serde_json::json!({ "server": self.name }),
                metadata: std::collections::HashMap::new(),
                duration_ms: 0,
            })
        }
    }
}

#[tokio::test]
async fn test_failover_basic_priority() {
    // Create two healthy servers with different priorities
    let primary = Arc::new(MockMcpClient::always_healthy("primary"));
    let backup = Arc::new(MockMcpClient::always_healthy("backup"));

    let servers = vec![
        ServerEntry::new("primary".to_string(), primary as Arc<dyn McpClient>, 1),
        ServerEntry::new("backup".to_string(), backup as Arc<dyn McpClient>, 2),
    ];

    let client = FailoverMcpClient::with_servers(servers);

    // Should use primary (lower priority number)
    let tools = client.list_tools().await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].name, "test_tool_primary");
}

#[tokio::test]
async fn test_failover_with_retry() {
    // Primary fails 2 times, then succeeds
    let primary = Arc::new(MockMcpClient::fail_n_times("primary", 2));
    let backup = Arc::new(MockMcpClient::always_healthy("backup"));

    let servers = vec![
        ServerEntry::new("primary".to_string(), primary as Arc<dyn McpClient>, 1),
        ServerEntry::new("backup".to_string(), backup as Arc<dyn McpClient>, 2),
    ];

    let config = FailoverConfig {
        max_retries: 3,
        base_retry_delay: Duration::from_millis(10),
        health_check_interval: Duration::from_secs(30),
        health_check_timeout: Duration::from_secs(5),
        circuit_breaker_threshold: 5,
    };

    let client = FailoverMcpClient::new(servers, config);

    // Should eventually succeed after retries
    let tools = client.list_tools().await.unwrap();
    assert_eq!(tools.len(), 1);
    // After 2 failures and 1 success, should get the tool
    assert!(tools[0].name.contains("test_tool"));
}

#[tokio::test]
async fn test_failover_to_backup() {
    // Primary always fails, backup always succeeds
    let primary = Arc::new(MockMcpClient::fail_n_times("primary", 999));
    let backup = Arc::new(MockMcpClient::always_healthy("backup"));

    let servers = vec![
        ServerEntry::new("primary".to_string(), primary as Arc<dyn McpClient>, 1),
        ServerEntry::new("backup".to_string(), backup as Arc<dyn McpClient>, 2),
    ];

    let config = FailoverConfig {
        max_retries: 5,
        base_retry_delay: Duration::from_millis(5),
        health_check_interval: Duration::from_secs(30),
        health_check_timeout: Duration::from_secs(5),
        circuit_breaker_threshold: 3,
    };

    let client = FailoverMcpClient::new(servers, config);

    // Should failover to backup after exhausting retries on primary
    let tools = client.list_tools().await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].name, "test_tool_backup");
}

#[tokio::test]
async fn test_server_health_tracking() {
    let client = Arc::new(MockMcpClient::always_healthy("test"));
    let entry = ServerEntry::new("test".to_string(), client, 1);

    // Initial state should be healthy
    assert!(entry.is_available().await);
    assert_eq!(entry.get_health().await, ServerHealth::Healthy);

    // Mark as degraded
    entry.mark_degraded().await;
    assert!(!entry.is_available().await);
    match entry.get_health().await {
        ServerHealth::Degraded { .. } => {}
        _ => panic!("Expected Degraded state"),
    }

    // Mark as down
    entry.mark_down().await;
    assert!(!entry.is_available().await);
    match entry.get_health().await {
        ServerHealth::Down { .. } => {}
        _ => panic!("Expected Down state"),
    }

    // Recover to healthy
    entry.mark_healthy().await;
    assert!(entry.is_available().await);
    assert_eq!(entry.get_health().await, ServerHealth::Healthy);
}

#[tokio::test]
async fn test_health_check_all() {
    let healthy = Arc::new(MockMcpClient::always_healthy("healthy"));
    let unhealthy = Arc::new(MockMcpClient::fail_n_times("unhealthy", 999));

    let servers = vec![
        ServerEntry::new("healthy".to_string(), healthy as Arc<dyn McpClient>, 1),
        ServerEntry::new("unhealthy".to_string(), unhealthy as Arc<dyn McpClient>, 2),
    ];

    let client = FailoverMcpClient::with_servers(servers);

    // Run health check
    client.health_check_all().await;

    // Give spawned tasks time to complete
    tokio::time::sleep(Duration::from_millis(100)).await;

    // Check server status
    let status = client.get_server_status().await;
    assert_eq!(status.len(), 2);

    // Healthy server should be healthy
    let healthy_status = status.iter().find(|(name, _)| name == "healthy").unwrap();
    assert_eq!(healthy_status.1, ServerHealth::Healthy);

    // Unhealthy server should be down
    let unhealthy_status = status.iter().find(|(name, _)| name == "unhealthy").unwrap();
    match unhealthy_status.1 {
        ServerHealth::Down { .. } => {}
        _ => panic!("Expected unhealthy server to be Down"),
    }
}

#[tokio::test]
async fn test_call_tool_failover() {
    // Primary fails, backup succeeds
    let primary = Arc::new(MockMcpClient::fail_n_times("primary", 999));
    let backup = Arc::new(MockMcpClient::always_healthy("backup"));

    let servers = vec![
        ServerEntry::new("primary".to_string(), primary as Arc<dyn McpClient>, 1),
        ServerEntry::new("backup".to_string(), backup as Arc<dyn McpClient>, 2),
    ];

    let config = FailoverConfig {
        max_retries: 5,
        base_retry_delay: Duration::from_millis(5),
        ..Default::default()
    };

    let client = FailoverMcpClient::new(servers, config);

    // Call a tool - should failover to backup
    let call = McpToolCall {
        name: "test_tool".to_string(),
        arguments: serde_json::json!({}),
    };

    let response = client.call_tool(call).await.unwrap();
    assert!(response.success);
    assert_eq!(response.output["server"], "backup");
}

#[tokio::test]
async fn test_exponential_backoff() {
    // Test that retry delays increase exponentially
    let client = Arc::new(MockMcpClient::fail_n_times("test", 2));
    let entry = ServerEntry::new("test".to_string(), client, 1);

    let config = FailoverConfig {
        max_retries: 3,
        base_retry_delay: Duration::from_millis(10),
        ..Default::default()
    };

    let failover = FailoverMcpClient::new(vec![entry], config);

    let start = std::time::Instant::now();
    let _ = failover.list_tools().await;
    let elapsed = start.elapsed();

    // With base delay of 10ms and 2 retries:
    // Attempt 1: immediate
    // Attempt 2: 10ms delay (2^0 * 10)
    // Attempt 3: 20ms delay (2^1 * 10)
    // Total should be at least 30ms
    assert!(
        elapsed >= Duration::from_millis(30),
        "Expected at least 30ms, got {:?}",
        elapsed
    );
}

#[tokio::test]
async fn test_no_servers_configured() {
    let config = FailoverConfig::default();
    let client = FailoverMcpClient::new(vec![], config);

    let result = client.list_tools().await;
    assert!(result.is_err());
    match result {
        Err(Error::Network(msg)) => assert!(msg.contains("No servers configured")),
        _ => panic!("Expected Network error"),
    }
}

#[tokio::test]
async fn test_all_servers_fail() {
    // Both servers always fail
    let primary = Arc::new(MockMcpClient::fail_n_times("primary", 999));
    let backup = Arc::new(MockMcpClient::fail_n_times("backup", 999));

    let servers = vec![
        ServerEntry::new("primary".to_string(), primary as Arc<dyn McpClient>, 1),
        ServerEntry::new("backup".to_string(), backup as Arc<dyn McpClient>, 2),
    ];

    let config = FailoverConfig {
        max_retries: 2,
        base_retry_delay: Duration::from_millis(5),
        ..Default::default()
    };

    let client = FailoverMcpClient::new(servers, config);

    let result = client.list_tools().await;
    assert!(result.is_err());
    match result {
        Err(Error::Network(_)) => {} // Expected
        _ => panic!("Expected Network error"),
    }
}
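The exponential-backoff test encodes the expected retry schedule in its comments; as a quick sketch of that arithmetic (illustrative only, not the failover implementation itself):

    // delay for the n-th retry = base_retry_delay * 2^(n - 1)
    // base = 10ms  ->  retry 1: 10ms, retry 2: 20ms, retry 3: 40ms, ...
    fn backoff(base_ms: u64, retry: u32) -> std::time::Duration {
        std::time::Duration::from_millis(base_ms * 2u64.pow(retry.saturating_sub(1)))
    }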
75
crates/owlen-core/tests/prompt_server.rs
Normal file
@@ -0,0 +1,75 @@
//! Integration test for the MCP prompt rendering server.

use owlen_core::Result;
use owlen_core::config::McpServerConfig;
use owlen_core::mcp::client::RemoteMcpClient;
use owlen_core::mcp::{McpToolCall, McpToolResponse};
use serde_json::json;
use std::path::PathBuf;

#[tokio::test]
async fn test_render_prompt_via_external_server() -> Result<()> {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let workspace_root = manifest_dir
        .parent()
        .and_then(|p| p.parent())
        .expect("workspace root");

    let candidates = [
        workspace_root
            .join("target")
            .join("debug")
            .join("owlen-mcp-prompt-server"),
        workspace_root
            .join("owlen-mcp-prompt-server")
            .join("target")
            .join("debug")
            .join("owlen-mcp-prompt-server"),
    ];

    let binary = if let Some(path) = candidates.iter().find(|path| path.exists()) {
        path.clone()
    } else {
        eprintln!(
            "Skipping prompt server integration test: binary not found. \
             Build it with `cargo build -p owlen-mcp-prompt-server`. Tried {:?}",
            candidates
        );
        return Ok(());
    };

    let config = McpServerConfig {
        name: "prompt_server".into(),
        command: binary.to_string_lossy().into_owned(),
        args: Vec::new(),
        transport: "stdio".into(),
        env: std::collections::HashMap::new(),
        oauth: None,
    };

    let client = match RemoteMcpClient::new_with_config(&config) {
        Ok(client) => client,
        Err(err) => {
            eprintln!(
                "Skipping prompt server integration test: failed to launch {} ({err})",
                config.command
            );
            return Ok(());
        }
    };

    let call = McpToolCall {
        name: "render_prompt".into(),
        arguments: json!({
            "template_name": "example",
            "variables": {"name": "Alice", "role": "Tester"}
        }),
    };

    let resp: McpToolResponse = client.call_tool(call).await?;
    assert!(resp.success, "Tool reported failure: {:?}", resp);
    let output = resp.output.as_str().unwrap_or("");
    assert!(output.contains("Alice"), "Output missing name: {}", output);
    assert!(output.contains("Tester"), "Output missing role: {}", output);
    Ok(())
}
@@ -1,6 +1,6 @@
 #![allow(non_snake_case)]

-use owlen_core::wrap_cursor::{build_cursor_map, ScreenPos};
+use owlen_core::wrap_cursor::{ScreenPos, build_cursor_map};

 fn assert_cursor_pos(map: &[ScreenPos], byte_idx: usize, expected: ScreenPos) {
     assert_eq!(map[byte_idx], expected, "Mismatch at byte {}", byte_idx);
10 crates/owlen-markdown/Cargo.toml Normal file
@@ -0,0 +1,10 @@
[package]
name = "owlen-markdown"
version.workspace = true
edition.workspace = true
license.workspace = true
description = "Lightweight markdown to ratatui::Text renderer for OWLEN"

[dependencies]
ratatui = { workspace = true }
unicode-width = "0.1"
270 crates/owlen-markdown/src/lib.rs Normal file
@@ -0,0 +1,270 @@
use ratatui::prelude::*;
use ratatui::text::{Line, Span, Text};
use unicode_width::UnicodeWidthStr;

/// Convert a markdown string into a `ratatui::Text`.
///
/// This lightweight renderer supports common constructs (headings, lists, bold,
/// italics, and inline code) and is designed to keep dependencies minimal for
/// the OWLEN project.
pub fn from_str(input: &str) -> Text<'static> {
    let mut lines = Vec::new();
    let mut in_code_block = false;

    for raw_line in input.lines() {
        let line = raw_line.trim_end_matches('\r');
        let trimmed = line.trim_start();
        let indent = &line[..line.len() - trimmed.len()];

        if trimmed.starts_with("```") {
            in_code_block = !in_code_block;
            continue;
        }

        if in_code_block {
            let mut spans = Vec::new();
            if !indent.is_empty() {
                spans.push(Span::raw(indent.to_string()));
            }
            spans.push(Span::styled(
                trimmed.to_string(),
                Style::default()
                    .fg(Color::LightYellow)
                    .add_modifier(Modifier::DIM),
            ));
            lines.push(Line::from(spans));
            continue;
        }

        if trimmed.is_empty() {
            lines.push(Line::from(Vec::<Span<'static>>::new()));
            continue;
        }

        if trimmed.starts_with('#') {
            let level = trimmed.chars().take_while(|c| *c == '#').count().min(6);
            let content = trimmed[level..].trim_start();
            let mut style = Style::default().add_modifier(Modifier::BOLD);
            style = match level {
                1 => style.fg(Color::LightCyan),
                2 => style.fg(Color::Cyan),
                _ => style.fg(Color::LightBlue),
            };
            let mut spans = Vec::new();
            if !indent.is_empty() {
                spans.push(Span::raw(indent.to_string()));
            }
            spans.push(Span::styled(content.to_string(), style));
            lines.push(Line::from(spans));
            continue;
        }

        if let Some(rest) = trimmed.strip_prefix("- ") {
            let mut spans = Vec::new();
            if !indent.is_empty() {
                spans.push(Span::raw(indent.to_string()));
            }
            spans.push(Span::styled(
                "• ".to_string(),
                Style::default().fg(Color::LightGreen),
            ));
            spans.extend(parse_inline(rest));
            lines.push(Line::from(spans));
            continue;
        }

        if let Some(rest) = trimmed.strip_prefix("* ") {
            let mut spans = Vec::new();
            if !indent.is_empty() {
                spans.push(Span::raw(indent.to_string()));
            }
            spans.push(Span::styled(
                "• ".to_string(),
                Style::default().fg(Color::LightGreen),
            ));
            spans.extend(parse_inline(rest));
            lines.push(Line::from(spans));
            continue;
        }

        if let Some((number, rest)) = parse_ordered_item(trimmed) {
            let mut spans = Vec::new();
            if !indent.is_empty() {
                spans.push(Span::raw(indent.to_string()));
            }
            spans.push(Span::styled(
                format!("{number}. "),
                Style::default().fg(Color::LightGreen),
            ));
            spans.extend(parse_inline(rest));
            lines.push(Line::from(spans));
            continue;
        }

        let mut spans = Vec::new();
        if !indent.is_empty() {
            spans.push(Span::raw(indent.to_string()));
        }
        spans.extend(parse_inline(trimmed));
        lines.push(Line::from(spans));
    }

    if input.is_empty() {
        lines.push(Line::from(Vec::<Span<'static>>::new()));
    }

    Text::from(lines)
}

fn parse_ordered_item(line: &str) -> Option<(u32, &str)> {
    let mut parts = line.splitn(2, '.');
    let number = parts.next()?.trim();
    let rest = parts.next()?;
    if number.chars().all(|c| c.is_ascii_digit()) {
        let value = number.parse().ok()?;
        let rest = rest.trim_start();
        Some((value, rest))
    } else {
        None
    }
}

fn parse_inline(text: &str) -> Vec<Span<'static>> {
    let mut spans = Vec::new();
    let bytes = text.as_bytes();
    let mut i = 0;
    let len = bytes.len();
    let mut plain_start = 0;

    while i < len {
        if bytes[i] == b'`' {
            if let Some(offset) = text[i + 1..].find('`') {
                if i > plain_start {
                    spans.push(Span::raw(text[plain_start..i].to_string()));
                }
                let content = &text[i + 1..i + 1 + offset];
                spans.push(Span::styled(
                    content.to_string(),
                    Style::default()
                        .fg(Color::LightYellow)
                        .add_modifier(Modifier::BOLD),
                ));
                i += offset + 2;
                plain_start = i;
                continue;
            } else {
                break;
            }
        }

        if bytes[i] == b'*' {
            if i + 1 < len && bytes[i + 1] == b'*' {
                if let Some(offset) = text[i + 2..].find("**") {
                    if i > plain_start {
                        spans.push(Span::raw(text[plain_start..i].to_string()));
                    }
                    let content = &text[i + 2..i + 2 + offset];
                    spans.push(Span::styled(
                        content.to_string(),
                        Style::default().add_modifier(Modifier::BOLD),
                    ));
                    i += offset + 4;
                    plain_start = i;
                    continue;
                }
            } else if let Some(offset) = text[i + 1..].find('*') {
                if i > plain_start {
                    spans.push(Span::raw(text[plain_start..i].to_string()));
                }
                let content = &text[i + 1..i + 1 + offset];
                spans.push(Span::styled(
                    content.to_string(),
                    Style::default().add_modifier(Modifier::ITALIC),
                ));
                i += offset + 2;
                plain_start = i;
                continue;
            }
        }

        if bytes[i] == b'_' {
            if i + 1 < len && bytes[i + 1] == b'_' {
                if let Some(offset) = text[i + 2..].find("__") {
                    if i > plain_start {
                        spans.push(Span::raw(text[plain_start..i].to_string()));
                    }
                    let content = &text[i + 2..i + 2 + offset];
                    spans.push(Span::styled(
                        content.to_string(),
                        Style::default().add_modifier(Modifier::BOLD),
                    ));
                    i += offset + 4;
                    plain_start = i;
                    continue;
                }
            } else if let Some(offset) = text[i + 1..].find('_') {
                if i > plain_start {
                    spans.push(Span::raw(text[plain_start..i].to_string()));
                }
                let content = &text[i + 1..i + 1 + offset];
                spans.push(Span::styled(
                    content.to_string(),
                    Style::default().add_modifier(Modifier::ITALIC),
                ));
                i += offset + 2;
                plain_start = i;
                continue;
            }
        }

        i += 1;
    }

    if plain_start < len {
        spans.push(Span::raw(text[plain_start..].to_string()));
    }

    if spans.is_empty() {
        spans.push(Span::raw(String::new()));
    }

    spans
}

#[allow(dead_code)]
fn visual_length(spans: &[Span<'_>]) -> usize {
    spans
        .iter()
        .map(|span| UnicodeWidthStr::width(span.content.as_ref()))
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn headings_are_bold() {
        let text = from_str("# Heading");
        assert_eq!(text.lines.len(), 1);
        let line = &text.lines[0];
        assert!(
            line.spans
                .iter()
                .any(|span| span.style.contains(Modifier::BOLD))
        );
    }

    #[test]
    fn inline_code_styled() {
        let text = from_str("Use `code` inline.");
        let styled = text
            .lines
            .iter()
            .flat_map(|line| &line.spans)
            .find(|span| span.content.as_ref() == "code")
            .cloned()
            .unwrap();
        assert!(styled.style.contains(Modifier::BOLD));
    }
}
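For orientation, here is a minimal usage sketch of the renderer added above; it assumes the crate is imported as `owlen_markdown` and simply feeds the resulting `Text` into a ratatui `Paragraph`. The `help_widget` helper is invented for illustration and is not part of the diff:

```rust
use ratatui::widgets::Paragraph;

/// Hypothetical helper showing how the new crate could be consumed by the TUI.
fn help_widget() -> Paragraph<'static> {
    // from_str handles headings, bullet lists, inline code, bold and italics.
    let text = owlen_markdown::from_str("# Help\n- Press `q` to quit\n- **Bold** and *italic* work");
    Paragraph::new(text)
}
```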
@@ -1,20 +0,0 @@
[package]
name = "owlen-mcp-llm-server"
version = "0.1.0"
edition = "2021"

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-ollama = { path = "../owlen-ollama" }
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
tokio-stream = "0.1"

[lib]
path = "src/lib.rs"

[[bin]]
name = "owlen-mcp-llm-server"
path = "src/lib.rs"
@@ -1,12 +0,0 @@
[package]
name = "owlen-mcp-server"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
path-clean = "1.0"
owlen-core = { path = "../owlen-core" }
@@ -1,9 +0,0 @@
# Owlen Ollama

This crate provides an implementation of the `owlen-core::Provider` trait for the [Ollama](https://ollama.ai) backend.

It allows Owlen to communicate with a local Ollama instance, sending requests and receiving responses from locally-run large language models. You can also target [Ollama Cloud](https://docs.ollama.com/cloud) by pointing the provider at `https://ollama.com` (or `https://api.ollama.com`) and providing an API key through your Owlen configuration (or the `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables). The client automatically adds the required Bearer authorization header when a key is supplied, accepts either host without rewriting, and expands inline environment references like `$OLLAMA_API_KEY` if you prefer not to check the secret into your config file. The generated configuration now includes both `providers.ollama` and `providers.ollama-cloud` entries; switch between them by updating `general.default_provider`.

## Configuration

To use this provider, you need to have Ollama installed and running. The default address is `http://localhost:11434`. You can configure this in your `config.toml` if your Ollama instance is running elsewhere.
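To make the configuration the removed README describes more concrete, a `config.toml` along these lines would select between the two providers. This is an illustrative sketch only; the exact key names are assumptions based on the `ProviderConfig` fields used elsewhere in this diff, not the documented schema:

```toml
# Illustrative sketch only; field names are assumptions, not the authoritative schema.
[general]
default_provider = "ollama"          # or "ollama-cloud"

[providers.ollama]
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama-cloud]
provider_type = "ollama-cloud"
base_url = "https://ollama.com"
api_key = "$OLLAMA_API_KEY"          # inline env reference, expanded at load time
```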
@@ -1,994 +0,0 @@
|
|||||||
//! Ollama provider for OWLEN LLM client
|
|
||||||
|
|
||||||
use futures_util::StreamExt;
|
|
||||||
use owlen_core::{
|
|
||||||
config::GeneralSettings,
|
|
||||||
model::ModelManager,
|
|
||||||
provider::{ChatStream, Provider, ProviderConfig},
|
|
||||||
types::{
|
|
||||||
ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
|
|
||||||
},
|
|
||||||
Result,
|
|
||||||
};
|
|
||||||
use reqwest::{header, Client, Url};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::env;
|
|
||||||
use std::io;
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
||||||
|
|
||||||
const DEFAULT_TIMEOUT_SECS: u64 = 120;
|
|
||||||
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
enum OllamaMode {
|
|
||||||
Local,
|
|
||||||
Cloud,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OllamaMode {
|
|
||||||
fn from_provider_type(provider_type: &str) -> Self {
|
|
||||||
if provider_type.eq_ignore_ascii_case("ollama-cloud") {
|
|
||||||
Self::Cloud
|
|
||||||
} else {
|
|
||||||
Self::Local
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_base_url(self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Local => "http://localhost:11434",
|
|
||||||
Self::Cloud => "https://ollama.com",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_scheme(self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Local => "http",
|
|
||||||
Self::Cloud => "https",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_ollama_host(host: &str) -> bool {
|
|
||||||
host.eq_ignore_ascii_case("ollama.com")
|
|
||||||
|| host.eq_ignore_ascii_case("www.ollama.com")
|
|
||||||
|| host.eq_ignore_ascii_case("api.ollama.com")
|
|
||||||
|| host.ends_with(".ollama.com")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_base_url(
|
|
||||||
input: Option<&str>,
|
|
||||||
mode_hint: OllamaMode,
|
|
||||||
) -> std::result::Result<String, String> {
|
|
||||||
let mut candidate = input
|
|
||||||
.map(str::trim)
|
|
||||||
.filter(|value| !value.is_empty())
|
|
||||||
.map(|value| value.to_string())
|
|
||||||
.unwrap_or_else(|| mode_hint.default_base_url().to_string());
|
|
||||||
|
|
||||||
if !candidate.contains("://") {
|
|
||||||
candidate = format!("{}://{}", mode_hint.default_scheme(), candidate);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut url =
|
|
||||||
Url::parse(&candidate).map_err(|err| format!("Invalid base_url '{candidate}': {err}"))?;
|
|
||||||
|
|
||||||
let mut is_cloud = matches!(mode_hint, OllamaMode::Cloud);
|
|
||||||
|
|
||||||
if let Some(host) = url.host_str() {
|
|
||||||
if is_ollama_host(host) {
|
|
||||||
is_cloud = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_cloud {
|
|
||||||
if url.scheme() != "https" {
|
|
||||||
url.set_scheme("https")
|
|
||||||
.map_err(|_| "Ollama Cloud requires an https URL".to_string())?;
|
|
||||||
}
|
|
||||||
|
|
||||||
match url.host_str() {
|
|
||||||
Some(host) => {
|
|
||||||
if host.eq_ignore_ascii_case("www.ollama.com") {
|
|
||||||
url.set_host(Some("ollama.com"))
|
|
||||||
.map_err(|_| "Failed to normalize Ollama Cloud host".to_string())?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
return Err("Ollama Cloud base_url must include a hostname".to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove trailing slash and discard query/fragment segments
|
|
||||||
let current_path = url.path().to_string();
|
|
||||||
let trimmed_path = current_path.trim_end_matches('/');
|
|
||||||
if trimmed_path.is_empty() {
|
|
||||||
url.set_path("");
|
|
||||||
} else {
|
|
||||||
url.set_path(trimmed_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
url.set_query(None);
|
|
||||||
url.set_fragment(None);
|
|
||||||
|
|
||||||
Ok(url.to_string().trim_end_matches('/').to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
|
|
||||||
let trimmed_base = base_url.trim_end_matches('/');
|
|
||||||
let trimmed_endpoint = endpoint.trim_start_matches('/');
|
|
||||||
|
|
||||||
if trimmed_base.ends_with("/api") {
|
|
||||||
format!("{trimmed_base}/{trimmed_endpoint}")
|
|
||||||
} else {
|
|
||||||
format!("{trimmed_base}/api/{trimmed_endpoint}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn env_var_non_empty(name: &str) -> Option<String> {
|
|
||||||
env::var(name)
|
|
||||||
.ok()
|
|
||||||
.map(|value| value.trim().to_string())
|
|
||||||
.filter(|value| !value.is_empty())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_api_key(configured: Option<String>) -> Option<String> {
|
|
||||||
let raw = configured?.trim().to_string();
|
|
||||||
if raw.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(variable) = raw
|
|
||||||
.strip_prefix("${")
|
|
||||||
.and_then(|value| value.strip_suffix('}'))
|
|
||||||
.or_else(|| raw.strip_prefix('$'))
|
|
||||||
{
|
|
||||||
let var_name = variable.trim();
|
|
||||||
if var_name.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
return env_var_non_empty(var_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn debug_requests_enabled() -> bool {
|
|
||||||
std::env::var("OWLEN_DEBUG_OLLAMA")
|
|
||||||
.ok()
|
|
||||||
.map(|value| {
|
|
||||||
matches!(
|
|
||||||
value.trim(),
|
|
||||||
"1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes"
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mask_token(token: &str) -> String {
|
|
||||||
if token.len() <= 8 {
|
|
||||||
return "***".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
let head = &token[..4];
|
|
||||||
let tail = &token[token.len() - 4..];
|
|
||||||
format!("{head}***{tail}")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mask_authorization(value: &str) -> String {
|
|
||||||
if let Some(token) = value.strip_prefix("Bearer ") {
|
|
||||||
format!("Bearer {}", mask_token(token))
|
|
||||||
} else {
|
|
||||||
"***".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama provider implementation with enhanced configuration and caching
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct OllamaProvider {
|
|
||||||
client: Client,
|
|
||||||
base_url: String,
|
|
||||||
api_key: Option<String>,
|
|
||||||
model_manager: ModelManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Options for configuring the Ollama provider
|
|
||||||
pub(crate) struct OllamaOptions {
|
|
||||||
base_url: String,
|
|
||||||
request_timeout: Duration,
|
|
||||||
model_cache_ttl: Duration,
|
|
||||||
api_key: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OllamaOptions {
|
|
||||||
pub(crate) fn new(base_url: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
base_url: base_url.into(),
|
|
||||||
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
|
||||||
model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
|
|
||||||
api_key: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_general(mut self, general: &GeneralSettings) -> Self {
|
|
||||||
self.model_cache_ttl = general.model_cache_ttl();
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama-specific message format
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct OllamaMessage {
|
|
||||||
role: String,
|
|
||||||
content: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tool_calls: Option<Vec<OllamaToolCall>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama tool call format
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct OllamaToolCall {
|
|
||||||
function: OllamaToolCallFunction,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct OllamaToolCallFunction {
|
|
||||||
name: String,
|
|
||||||
arguments: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama chat request format
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
struct OllamaChatRequest {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<OllamaMessage>,
|
|
||||||
stream: bool,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tools: Option<Vec<OllamaTool>>,
|
|
||||||
#[serde(flatten)]
|
|
||||||
options: HashMap<String, Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama tool definition
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct OllamaTool {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
tool_type: String,
|
|
||||||
function: OllamaToolFunction,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct OllamaToolFunction {
|
|
||||||
name: String,
|
|
||||||
description: String,
|
|
||||||
parameters: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama chat response format
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OllamaChatResponse {
|
|
||||||
message: Option<OllamaMessage>,
|
|
||||||
done: bool,
|
|
||||||
#[serde(default)]
|
|
||||||
prompt_eval_count: Option<u32>,
|
|
||||||
#[serde(default)]
|
|
||||||
eval_count: Option<u32>,
|
|
||||||
#[serde(default)]
|
|
||||||
error: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OllamaErrorResponse {
|
|
||||||
error: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama models list response
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OllamaModelsResponse {
|
|
||||||
models: Vec<OllamaModelInfo>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ollama model information
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OllamaModelInfo {
|
|
||||||
name: String,
|
|
||||||
#[serde(default)]
|
|
||||||
details: Option<OllamaModelDetails>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OllamaModelDetails {
|
|
||||||
#[serde(default)]
|
|
||||||
family: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OllamaProvider {
|
|
||||||
/// Create a new Ollama provider with sensible defaults
|
|
||||||
pub fn new(base_url: impl Into<String>) -> Result<Self> {
|
|
||||||
let mode = OllamaMode::Local;
|
|
||||||
let supplied = base_url.into();
|
|
||||||
let normalized =
|
|
||||||
normalize_base_url(Some(&supplied), mode).map_err(owlen_core::Error::Config)?;
|
|
||||||
|
|
||||||
Self::with_options(OllamaOptions::new(normalized))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn debug_log_request(&self, label: &str, request: &reqwest::Request, body_json: Option<&str>) {
|
|
||||||
if !debug_requests_enabled() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("--- OWLEN Ollama request ({label}) ---");
|
|
||||||
eprintln!("{} {}", request.method(), request.url());
|
|
||||||
|
|
||||||
match request
|
|
||||||
.headers()
|
|
||||||
.get(header::AUTHORIZATION)
|
|
||||||
.and_then(|value| value.to_str().ok())
|
|
||||||
{
|
|
||||||
Some(value) => eprintln!("Authorization: {}", mask_authorization(value)),
|
|
||||||
None => eprintln!("Authorization: <none>"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(body) = body_json {
|
|
||||||
eprintln!("Body:\n{body}");
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("---------------------------------------");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert MCP tool descriptors to Ollama tool format
|
|
||||||
fn convert_tools_to_ollama(tools: &[owlen_core::mcp::McpToolDescriptor]) -> Vec<OllamaTool> {
|
|
||||||
tools
|
|
||||||
.iter()
|
|
||||||
.map(|tool| OllamaTool {
|
|
||||||
tool_type: "function".to_string(),
|
|
||||||
function: OllamaToolFunction {
|
|
||||||
name: tool.name.clone(),
|
|
||||||
description: tool.description.clone(),
|
|
||||||
parameters: tool.input_schema.clone(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a provider from configuration settings
|
|
||||||
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
|
|
||||||
let mode = OllamaMode::from_provider_type(&config.provider_type);
|
|
||||||
let normalized_base_url = normalize_base_url(config.base_url.as_deref(), mode)
|
|
||||||
.map_err(owlen_core::Error::Config)?;
|
|
||||||
|
|
||||||
let mut options = OllamaOptions::new(normalized_base_url);
|
|
||||||
|
|
||||||
if let Some(timeout) = config
|
|
||||||
.extra
|
|
||||||
.get("timeout_secs")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
{
|
|
||||||
options.request_timeout = Duration::from_secs(timeout.max(5));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(cache_ttl) = config
|
|
||||||
.extra
|
|
||||||
.get("model_cache_ttl_secs")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
{
|
|
||||||
options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
|
|
||||||
}
|
|
||||||
|
|
||||||
options.api_key = resolve_api_key(config.api_key.clone())
|
|
||||||
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
|
|
||||||
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
|
|
||||||
|
|
||||||
if let Some(general) = general {
|
|
||||||
options = options.with_general(general);
|
|
||||||
}
|
|
||||||
|
|
||||||
Self::with_options(options)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a provider from explicit options
|
|
||||||
pub(crate) fn with_options(options: OllamaOptions) -> Result<Self> {
|
|
||||||
let OllamaOptions {
|
|
||||||
base_url,
|
|
||||||
request_timeout,
|
|
||||||
model_cache_ttl,
|
|
||||||
api_key,
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
let client = Client::builder()
|
|
||||||
.timeout(request_timeout)
|
|
||||||
.build()
|
|
||||||
.map_err(|e| owlen_core::Error::Config(format!("Failed to build HTTP client: {e}")))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
client,
|
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
|
||||||
api_key,
|
|
||||||
model_manager: ModelManager::new(model_cache_ttl),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Accessor for the underlying model manager
|
|
||||||
pub fn model_manager(&self) -> &ModelManager {
|
|
||||||
&self.model_manager
|
|
||||||
}
|
|
||||||
|
|
||||||
fn api_url(&self, endpoint: &str) -> String {
|
|
||||||
build_api_endpoint(&self.base_url, endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
|
||||||
if let Some(api_key) = &self.api_key {
|
|
||||||
request.bearer_auth(api_key)
|
|
||||||
} else {
|
|
||||||
request
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn convert_message(message: &Message) -> OllamaMessage {
|
|
||||||
let role = match message.role {
|
|
||||||
Role::User => "user".to_string(),
|
|
||||||
Role::Assistant => "assistant".to_string(),
|
|
||||||
Role::System => "system".to_string(),
|
|
||||||
Role::Tool => "tool".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool_calls = message.tool_calls.as_ref().map(|calls| {
|
|
||||||
calls
|
|
||||||
.iter()
|
|
||||||
.map(|tc| OllamaToolCall {
|
|
||||||
function: OllamaToolCallFunction {
|
|
||||||
name: tc.name.clone(),
|
|
||||||
arguments: tc.arguments.clone(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
});
|
|
||||||
|
|
||||||
OllamaMessage {
|
|
||||||
role,
|
|
||||||
content: message.content.clone(),
|
|
||||||
tool_calls,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn convert_ollama_message(message: &OllamaMessage) -> Message {
|
|
||||||
let role = match message.role.as_str() {
|
|
||||||
"user" => Role::User,
|
|
||||||
"assistant" => Role::Assistant,
|
|
||||||
"system" => Role::System,
|
|
||||||
"tool" => Role::Tool,
|
|
||||||
_ => Role::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut msg = Message::new(role, message.content.clone());
|
|
||||||
|
|
||||||
// Convert tool calls if present
|
|
||||||
if let Some(ollama_tool_calls) = &message.tool_calls {
|
|
||||||
let tool_calls: Vec<ToolCall> = ollama_tool_calls
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(idx, tc)| ToolCall {
|
|
||||||
id: format!("call_{}", idx),
|
|
||||||
name: tc.function.name.clone(),
|
|
||||||
arguments: tc.function.arguments.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
msg.tool_calls = Some(tool_calls);
|
|
||||||
}
|
|
||||||
|
|
||||||
msg
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_options(parameters: ChatParameters) -> HashMap<String, Value> {
|
|
||||||
let mut options = parameters.extra;
|
|
||||||
|
|
||||||
if let Some(temperature) = parameters.temperature {
|
|
||||||
options
|
|
||||||
.entry("temperature".to_string())
|
|
||||||
.or_insert(json!(temperature as f64));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_tokens) = parameters.max_tokens {
|
|
||||||
options
|
|
||||||
.entry("num_predict".to_string())
|
|
||||||
.or_insert(json!(max_tokens));
|
|
||||||
}
|
|
||||||
|
|
||||||
options
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
let url = self.api_url("tags");
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.apply_auth(self.client.get(&url))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| owlen_core::Error::Network(format!("Failed to fetch models: {e}")))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let code = response.status();
|
|
||||||
let error = parse_error_body(response).await;
|
|
||||||
return Err(owlen_core::Error::Network(format!(
|
|
||||||
"Ollama model listing failed ({code}): {error}"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = response.text().await.map_err(|e| {
|
|
||||||
owlen_core::Error::Network(format!("Failed to read models response: {e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let ollama_response: OllamaModelsResponse =
|
|
||||||
serde_json::from_str(&body).map_err(owlen_core::Error::Serialization)?;
|
|
||||||
|
|
||||||
let models = ollama_response
|
|
||||||
.models
|
|
||||||
.into_iter()
|
|
||||||
.map(|model| {
|
|
||||||
// Check if model supports tool calling based on known models
|
|
||||||
let supports_tools = Self::check_tool_support(&model.name);
|
|
||||||
|
|
||||||
ModelInfo {
|
|
||||||
id: model.name.clone(),
|
|
||||||
name: model.name.clone(),
|
|
||||||
description: model
|
|
||||||
.details
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|d| d.family.as_ref().map(|f| format!("Ollama {f} model"))),
|
|
||||||
provider: "ollama".to_string(),
|
|
||||||
context_window: None,
|
|
||||||
capabilities: vec!["chat".to_string()],
|
|
||||||
supports_tools,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(models)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a model supports tool calling based on its name
|
|
||||||
fn check_tool_support(model_name: &str) -> bool {
|
|
||||||
let name_lower = model_name.to_lowercase();
|
|
||||||
|
|
||||||
// Known models with tool calling support
|
|
||||||
let tool_supporting_models = [
|
|
||||||
"qwen",
|
|
||||||
"llama3.1",
|
|
||||||
"llama3.2",
|
|
||||||
"llama3.3",
|
|
||||||
"mistral-nemo",
|
|
||||||
"mistral:7b-instruct",
|
|
||||||
"command-r",
|
|
||||||
"firefunction",
|
|
||||||
"hermes",
|
|
||||||
"nexusraven",
|
|
||||||
"granite-code",
|
|
||||||
];
|
|
||||||
|
|
||||||
tool_supporting_models
|
|
||||||
.iter()
|
|
||||||
.any(|&supported| name_lower.contains(supported))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
|
||||||
impl Provider for OllamaProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"ollama"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
self.model_manager
|
|
||||||
.get_or_refresh(false, || async { self.fetch_models().await })
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
|
||||||
let ChatRequest {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
parameters,
|
|
||||||
tools,
|
|
||||||
} = request;
|
|
||||||
|
|
||||||
let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();
|
|
||||||
|
|
||||||
let options = Self::build_options(parameters);
|
|
||||||
|
|
||||||
// Only send the `tools` field if there is at least one tool.
|
|
||||||
// An empty array makes Ollama validate tool support and can cause a
|
|
||||||
// 400 Bad Request for models that do not support tools.
|
|
||||||
// Currently the `tools` field is omitted for compatibility; the variable is retained
|
|
||||||
// for potential future use.
|
|
||||||
let _ollama_tools = tools
|
|
||||||
.as_ref()
|
|
||||||
.filter(|t| !t.is_empty())
|
|
||||||
.map(|t| Self::convert_tools_to_ollama(t));
|
|
||||||
|
|
||||||
// Ollama currently rejects any presence of the `tools` field for models that
|
|
||||||
// do not support function calling. To be safe, we omit the field entirely.
|
|
||||||
let ollama_request = OllamaChatRequest {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream: false,
|
|
||||||
tools: None,
|
|
||||||
options,
|
|
||||||
};
|
|
||||||
|
|
||||||
let url = self.api_url("chat");
|
|
||||||
let debug_body = if debug_requests_enabled() {
|
|
||||||
serde_json::to_string_pretty(&ollama_request).ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut request_builder = self.client.post(&url).json(&ollama_request);
|
|
||||||
request_builder = self.apply_auth(request_builder);
|
|
||||||
|
|
||||||
let request = request_builder.build().map_err(|e| {
|
|
||||||
owlen_core::Error::Network(format!("Failed to build chat request: {e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
self.debug_log_request("chat", &request, debug_body.as_deref());
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.execute(request)
|
|
||||||
.await
|
|
||||||
.map_err(|e| owlen_core::Error::Network(format!("Chat request failed: {e}")))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let code = response.status();
|
|
||||||
let error = parse_error_body(response).await;
|
|
||||||
return Err(owlen_core::Error::Network(format!(
|
|
||||||
"Ollama chat failed ({code}): {error}"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = response.text().await.map_err(|e| {
|
|
||||||
owlen_core::Error::Network(format!("Failed to read chat response: {e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut ollama_response: OllamaChatResponse =
|
|
||||||
serde_json::from_str(&body).map_err(owlen_core::Error::Serialization)?;
|
|
||||||
|
|
||||||
if let Some(error) = ollama_response.error.take() {
|
|
||||||
return Err(owlen_core::Error::Provider(anyhow::anyhow!(error)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let message = match ollama_response.message {
|
|
||||||
Some(ref msg) => Self::convert_ollama_message(msg),
|
|
||||||
None => {
|
|
||||||
return Err(owlen_core::Error::Provider(anyhow::anyhow!(
|
|
||||||
"Ollama response missing message"
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let usage = if let (Some(prompt_tokens), Some(completion_tokens)) = (
|
|
||||||
ollama_response.prompt_eval_count,
|
|
||||||
ollama_response.eval_count,
|
|
||||||
) {
|
|
||||||
Some(TokenUsage {
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens: prompt_tokens + completion_tokens,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ChatResponse {
|
|
||||||
message,
|
|
||||||
usage,
|
|
||||||
is_streaming: false,
|
|
||||||
is_final: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
|
||||||
let ChatRequest {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
parameters,
|
|
||||||
tools,
|
|
||||||
} = request;
|
|
||||||
|
|
||||||
let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();
|
|
||||||
|
|
||||||
let options = Self::build_options(parameters);
|
|
||||||
|
|
||||||
// Only include the `tools` field if there is at least one tool.
|
|
||||||
// Sending an empty tools array causes Ollama to reject the request for
|
|
||||||
// models without tool support (400 Bad Request).
|
|
||||||
// Retain tools conversion for possible future extensions, but silence unused warnings.
|
|
||||||
let _ollama_tools = tools
|
|
||||||
.as_ref()
|
|
||||||
.filter(|t| !t.is_empty())
|
|
||||||
.map(|t| Self::convert_tools_to_ollama(t));
|
|
||||||
|
|
||||||
// Omit the `tools` field for compatibility with models lacking tool support.
|
|
||||||
let ollama_request = OllamaChatRequest {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream: true,
|
|
||||||
tools: None,
|
|
||||||
options,
|
|
||||||
};
|
|
||||||
|
|
||||||
let url = self.api_url("chat");
|
|
||||||
let debug_body = if debug_requests_enabled() {
|
|
||||||
serde_json::to_string_pretty(&ollama_request).ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut request_builder = self.client.post(&url).json(&ollama_request);
|
|
||||||
request_builder = self.apply_auth(request_builder);
|
|
||||||
|
|
||||||
let request = request_builder.build().map_err(|e| {
|
|
||||||
owlen_core::Error::Network(format!("Failed to build streaming request: {e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
self.debug_log_request("chat_stream", &request, debug_body.as_deref());
|
|
||||||
|
|
||||||
let response =
|
|
||||||
self.client.execute(request).await.map_err(|e| {
|
|
||||||
owlen_core::Error::Network(format!("Streaming request failed: {e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let code = response.status();
|
|
||||||
let error = parse_error_body(response).await;
|
|
||||||
return Err(owlen_core::Error::Network(format!(
|
|
||||||
"Ollama streaming chat failed ({code}): {error}"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
|
||||||
let mut stream = response.bytes_stream();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut buffer = String::new();
|
|
||||||
|
|
||||||
while let Some(chunk) = stream.next().await {
|
|
||||||
match chunk {
|
|
||||||
Ok(bytes) => {
|
|
||||||
if let Ok(text) = String::from_utf8(bytes.to_vec()) {
|
|
||||||
buffer.push_str(&text);
|
|
||||||
|
|
||||||
while let Some(pos) = buffer.find('\n') {
|
|
||||||
let mut line = buffer[..pos].trim().to_string();
|
|
||||||
buffer.drain(..=pos);
|
|
||||||
|
|
||||||
if line.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if line.ends_with('\r') {
|
|
||||||
line.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<OllamaChatResponse>(&line) {
|
|
||||||
Ok(mut ollama_response) => {
|
|
||||||
if let Some(error) = ollama_response.error.take() {
|
|
||||||
let _ = tx.send(Err(owlen_core::Error::Provider(
|
|
||||||
anyhow::anyhow!(error),
|
|
||||||
)));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(message) = ollama_response.message {
|
|
||||||
let mut chat_response = ChatResponse {
|
|
||||||
message: Self::convert_ollama_message(&message),
|
|
||||||
usage: None,
|
|
||||||
is_streaming: true,
|
|
||||||
is_final: ollama_response.done,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let (Some(prompt_tokens), Some(completion_tokens)) = (
|
|
||||||
ollama_response.prompt_eval_count,
|
|
||||||
ollama_response.eval_count,
|
|
||||||
) {
|
|
||||||
chat_response.usage = Some(TokenUsage {
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens: prompt_tokens + completion_tokens,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if tx.send(Ok(chat_response)).is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ollama_response.done {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let _ = tx.send(Err(owlen_core::Error::Serialization(e)));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let _ = tx.send(Err(owlen_core::Error::Serialization(
|
|
||||||
serde_json::Error::io(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidData,
|
|
||||||
"Non UTF-8 chunk from Ollama",
|
|
||||||
)),
|
|
||||||
)));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let _ = tx.send(Err(owlen_core::Error::Network(format!(
|
|
||||||
"Stream error: {e}"
|
|
||||||
))));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let stream = UnboundedReceiverStream::new(rx);
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_check(&self) -> Result<()> {
|
|
||||||
let url = self.api_url("version");
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.apply_auth(self.client.get(&url))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| owlen_core::Error::Network(format!("Health check failed: {e}")))?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(owlen_core::Error::Network(format!(
|
|
||||||
"Ollama health check failed: HTTP {}",
|
|
||||||
response.status()
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"base_url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Base URL for Ollama API",
|
|
||||||
"default": "http://localhost:11434"
|
|
||||||
},
|
|
||||||
"timeout_secs": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "HTTP request timeout in seconds",
|
|
||||||
"minimum": 5,
|
|
||||||
"default": DEFAULT_TIMEOUT_SECS
|
|
||||||
},
|
|
||||||
"model_cache_ttl_secs": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Seconds to cache model listings",
|
|
||||||
"minimum": 5,
|
|
||||||
"default": DEFAULT_MODEL_CACHE_TTL_SECS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn parse_error_body(response: reqwest::Response) -> String {
|
|
||||||
match response.bytes().await {
|
|
||||||
Ok(bytes) => {
|
|
||||||
if bytes.is_empty() {
|
|
||||||
return "unknown error".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(err) = serde_json::from_slice::<OllamaErrorResponse>(&bytes) {
|
|
||||||
if let Some(error) = err.error {
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match String::from_utf8(bytes.to_vec()) {
|
|
||||||
Ok(text) if !text.trim().is_empty() => text,
|
|
||||||
_ => "unknown error".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => "unknown error".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn normalizes_local_base_url_and_infers_scheme() {
|
|
||||||
let normalized =
|
|
||||||
normalize_base_url(Some("localhost:11434"), OllamaMode::Local).expect("valid URL");
|
|
||||||
assert_eq!(normalized, "http://localhost:11434");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn normalizes_cloud_base_url_and_host() {
|
|
||||||
let normalized =
|
|
||||||
normalize_base_url(Some("https://ollama.com"), OllamaMode::Cloud).expect("valid URL");
|
|
||||||
assert_eq!(normalized, "https://ollama.com");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn infers_scheme_for_cloud_hosts() {
|
|
||||||
let normalized =
|
|
||||||
normalize_base_url(Some("ollama.com"), OllamaMode::Cloud).expect("valid URL");
|
|
||||||
assert_eq!(normalized, "https://ollama.com");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn rewrites_www_cloud_host() {
|
|
||||||
let normalized = normalize_base_url(Some("https://www.ollama.com"), OllamaMode::Cloud)
|
|
||||||
.expect("valid URL");
|
|
||||||
assert_eq!(normalized, "https://ollama.com");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retains_explicit_api_suffix() {
|
|
||||||
let normalized = normalize_base_url(Some("https://api.ollama.com/api"), OllamaMode::Cloud)
|
|
||||||
.expect("valid URL");
|
|
||||||
assert_eq!(normalized, "https://api.ollama.com/api");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn builds_api_endpoint_without_duplicate_segments() {
|
|
||||||
let base = "http://localhost:11434";
|
|
||||||
assert_eq!(
|
|
||||||
build_api_endpoint(base, "chat"),
|
|
||||||
"http://localhost:11434/api/chat"
|
|
||||||
);
|
|
||||||
|
|
||||||
let base_with_api = "http://localhost:11434/api";
|
|
||||||
assert_eq!(
|
|
||||||
build_api_endpoint(base_with_api, "chat"),
|
|
||||||
"http://localhost:11434/api/chat"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_api_key_prefers_literal_value() {
|
|
||||||
assert_eq!(
|
|
||||||
resolve_api_key(Some("direct-key".into())),
|
|
||||||
Some("direct-key".into())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_api_key_expands_braced_env_reference() {
|
|
||||||
std::env::set_var("OWLEN_TEST_KEY_BRACED", "super-secret");
|
|
||||||
assert_eq!(
|
|
||||||
resolve_api_key(Some("${OWLEN_TEST_KEY_BRACED}".into())),
|
|
||||||
Some("super-secret".into())
|
|
||||||
);
|
|
||||||
std::env::remove_var("OWLEN_TEST_KEY_BRACED");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_api_key_expands_unbraced_env_reference() {
|
|
||||||
std::env::set_var("OWLEN_TEST_KEY_UNBRACED", "another-secret");
|
|
||||||
assert_eq!(
|
|
||||||
resolve_api_key(Some("$OWLEN_TEST_KEY_UNBRACED".into())),
|
|
||||||
Some("another-secret".into())
|
|
||||||
);
|
|
||||||
std::env::remove_var("OWLEN_TEST_KEY_UNBRACED");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,34 +1,20 @@
 [package]
-name = "owlen-ollama"
+name = "owlen-providers"
 version.workspace = true
 edition.workspace = true
 authors.workspace = true
 license.workspace = true
 repository.workspace = true
 homepage.workspace = true
-description = "Ollama provider for OWLEN LLM client"
+description = "Provider implementations for OWLEN"

 [dependencies]
 owlen-core = { path = "../owlen-core" }
-
-# HTTP client
-reqwest = { workspace = true }
-
-# Async runtime
-tokio = { workspace = true }
-tokio-stream = { workspace = true }
+anyhow = { workspace = true }
+async-trait = { workspace = true }
 futures = { workspace = true }
-futures-util = { workspace = true }
-
-# Serialization
 serde = { workspace = true }
 serde_json = { workspace = true }
-
-# Utilities
-anyhow = { workspace = true }
-thiserror = { workspace = true }
-uuid = { workspace = true }
-async-trait = { workspace = true }
-
-[dev-dependencies]
-tokio-test = { workspace = true }
+tokio = { workspace = true }
+tokio-stream = { workspace = true }
+reqwest = { package = "reqwest", version = "0.11", features = ["json", "stream"] }
3 crates/owlen-providers/src/lib.rs Normal file
@@ -0,0 +1,3 @@
//! Provider implementations for OWLEN.

pub mod ollama;
108 crates/owlen-providers/src/ollama/cloud.rs Normal file
@@ -0,0 +1,108 @@
use std::{env, time::Duration};

use async_trait::async_trait;
use owlen_core::{
    Error as CoreError, Result as CoreResult,
    config::OLLAMA_CLOUD_BASE_URL,
    provider::{
        GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
        ProviderStatus, ProviderType,
    },
};
use serde_json::{Number, Value};

use super::OllamaClient;

const API_KEY_ENV: &str = "OLLAMA_CLOUD_API_KEY";

/// ModelProvider implementation for the hosted Ollama Cloud service.
pub struct OllamaCloudProvider {
    client: OllamaClient,
}

impl OllamaCloudProvider {
    /// Construct a new cloud provider. An API key must be supplied either
    /// directly or via the `OLLAMA_CLOUD_API_KEY` environment variable.
    pub fn new(
        base_url: Option<String>,
        api_key: Option<String>,
        request_timeout: Option<Duration>,
    ) -> CoreResult<Self> {
        let (api_key, key_source) = resolve_api_key(api_key)?;
        let base_url = base_url.unwrap_or_else(|| OLLAMA_CLOUD_BASE_URL.to_string());

        let mut metadata =
            ProviderMetadata::new("ollama_cloud", "Ollama (Cloud)", ProviderType::Cloud, true);
        metadata
            .metadata
            .insert("base_url".into(), Value::String(base_url.clone()));
        metadata.metadata.insert(
            "api_key_source".into(),
            Value::String(key_source.to_string()),
        );
        metadata
            .metadata
            .insert("api_key_env".into(), Value::String(API_KEY_ENV.to_string()));

        if let Some(timeout) = request_timeout {
            let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
            metadata.metadata.insert(
                "request_timeout_ms".into(),
                Value::Number(Number::from(timeout_ms)),
            );
        }

        let client = OllamaClient::new(&base_url, Some(api_key), metadata, request_timeout)?;

        Ok(Self { client })
    }
}

#[async_trait]
impl ModelProvider for OllamaCloudProvider {
    fn metadata(&self) -> &ProviderMetadata {
        self.client.metadata()
    }

    async fn health_check(&self) -> CoreResult<ProviderStatus> {
        match self.client.health_check().await {
            Ok(status) => Ok(status),
            Err(CoreError::Auth(_)) => Ok(ProviderStatus::RequiresSetup),
            Err(err) => Err(err),
        }
    }

    async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
        self.client.list_models().await
    }

    async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
        self.client.generate_stream(request).await
    }
}

fn resolve_api_key(api_key: Option<String>) -> CoreResult<(String, &'static str)> {
    let key_from_config = api_key
        .as_ref()
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
        .map(str::to_string);

    if let Some(key) = key_from_config {
        return Ok((key, "config"));
    }

    let key_from_env = env::var(API_KEY_ENV)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty());

    if let Some(key) = key_from_env {
        return Ok((key, "env"));
    }

    Err(CoreError::Config(
        "Ollama Cloud API key not configured. Set OLLAMA_CLOUD_API_KEY or configure an API key."
            .into(),
    ))
}
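As a quick orientation for the new provider API, a caller could wire the cloud provider up roughly as follows. This is a sketch under the signatures shown above; the surrounding `demo` function is invented for illustration and is not part of the diff:

```rust
use owlen_core::provider::ModelProvider;
use owlen_providers::ollama::OllamaCloudProvider;

async fn demo() -> owlen_core::Result<()> {
    // With no explicit key, resolve_api_key falls back to OLLAMA_CLOUD_API_KEY.
    let provider = OllamaCloudProvider::new(None, None, None)?;
    let models = provider.list_models().await?;
    println!("Ollama Cloud exposes {} models", models.len());
    Ok(())
}
```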
80
crates/owlen-providers/src/ollama/local.rs
Normal file
80
crates/owlen-providers/src/ollama/local.rs
Normal file
@@ -0,0 +1,80 @@
use std::time::Duration;

use async_trait::async_trait;
use owlen_core::provider::{
    GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata, ProviderStatus,
    ProviderType,
};
use owlen_core::{Error as CoreError, Result as CoreResult};
use serde_json::{Number, Value};
use tokio::time::timeout;

use super::OllamaClient;

const DEFAULT_BASE_URL: &str = "http://localhost:11434";
const DEFAULT_HEALTH_TIMEOUT_SECS: u64 = 5;

/// ModelProvider implementation for a local Ollama daemon.
pub struct OllamaLocalProvider {
    client: OllamaClient,
    health_timeout: Duration,
}

impl OllamaLocalProvider {
    /// Construct a new local provider using the shared [`OllamaClient`].
    pub fn new(
        base_url: Option<String>,
        request_timeout: Option<Duration>,
        health_timeout: Option<Duration>,
    ) -> CoreResult<Self> {
        let base_url = base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
        let health_timeout =
            health_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_HEALTH_TIMEOUT_SECS));

        let mut metadata =
            ProviderMetadata::new("ollama_local", "Ollama (Local)", ProviderType::Local, false);
        metadata
            .metadata
            .insert("base_url".into(), Value::String(base_url.clone()));
        if let Some(timeout) = request_timeout {
            let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
            metadata.metadata.insert(
                "request_timeout_ms".into(),
                Value::Number(Number::from(timeout_ms)),
            );
        }

        let client = OllamaClient::new(&base_url, None, metadata, request_timeout)?;

        Ok(Self {
            client,
            health_timeout,
        })
    }
}

#[async_trait]
impl ModelProvider for OllamaLocalProvider {
    fn metadata(&self) -> &ProviderMetadata {
        self.client.metadata()
    }

    async fn health_check(&self) -> CoreResult<ProviderStatus> {
        match timeout(self.health_timeout, self.client.health_check()).await {
            Ok(Ok(status)) => Ok(status),
            Ok(Err(CoreError::Network(_))) | Ok(Err(CoreError::Timeout(_))) => {
                Ok(ProviderStatus::Unavailable)
            }
            Ok(Err(err)) => Err(err),
            Err(_) => Ok(ProviderStatus::Unavailable),
        }
    }

    async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
        self.client.list_models().await
    }

    async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
        self.client.generate_stream(request).await
    }
}
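Construction of the local provider is driven entirely by optional overrides: passing None everywhere picks the defaults shown above (http://localhost:11434 as the base URL, the shared client's 60-second request timeout, and a 5-second health-check timeout). A hedged usage sketch, assuming the crate is consumed as owlen_providers, that the ollama module re-exports are public, and that ProviderStatus implements Debug (none of which is shown in this diff):

use std::time::Duration;

use owlen_core::provider::ModelProvider;
use owlen_providers::ollama::OllamaLocalProvider; // assumed public re-export path

#[tokio::main]
async fn main() -> owlen_core::Result<()> {
    let provider = OllamaLocalProvider::new(
        None,                          // base_url: defaults to http://localhost:11434
        Some(Duration::from_secs(30)), // request timeout for the shared HTTP client
        None,                          // health timeout: defaults to 5 seconds
    )?;

    // Network and timeout failures are reported as ProviderStatus::Unavailable
    // rather than bubbling up as errors.
    let status = provider.health_check().await?;
    println!("ollama_local status: {status:?}");

    for model in provider.list_models().await? {
        println!("model: {}", model.name);
    }
    Ok(())
}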
7
crates/owlen-providers/src/ollama/mod.rs
Normal file
@@ -0,0 +1,7 @@
pub mod cloud;
pub mod local;
pub mod shared;

pub use cloud::OllamaCloudProvider;
pub use local::OllamaLocalProvider;
pub use shared::OllamaClient;
389
crates/owlen-providers/src/ollama/shared.rs
Normal file
@@ -0,0 +1,389 @@
use std::collections::HashMap;
use std::time::Duration;

use futures::StreamExt;
use owlen_core::provider::{
    GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ProviderMetadata, ProviderStatus,
};
use owlen_core::{Error as CoreError, Result as CoreResult};
use reqwest::{Client, Method, StatusCode, Url};
use serde::Deserialize;
use serde_json::{Map as JsonMap, Value};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

const DEFAULT_TIMEOUT_SECS: u64 = 60;

/// Shared Ollama HTTP client used by both local and cloud providers.
#[derive(Clone)]
pub struct OllamaClient {
    http: Client,
    base_url: Url,
    api_key: Option<String>,
    provider_metadata: ProviderMetadata,
}

impl OllamaClient {
    /// Create a new client with the given base URL and optional API key.
    pub fn new(
        base_url: impl AsRef<str>,
        api_key: Option<String>,
        provider_metadata: ProviderMetadata,
        request_timeout: Option<Duration>,
    ) -> CoreResult<Self> {
        let base_url = Url::parse(base_url.as_ref())
            .map_err(|err| CoreError::Config(format!("invalid base url: {}", err)))?;

        let timeout = request_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_TIMEOUT_SECS));
        let http = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(map_reqwest_error)?;

        Ok(Self {
            http,
            base_url,
            api_key,
            provider_metadata,
        })
    }

    /// Provider metadata associated with this client.
    pub fn metadata(&self) -> &ProviderMetadata {
        &self.provider_metadata
    }

    /// Perform a basic health check to determine provider availability.
    pub async fn health_check(&self) -> CoreResult<ProviderStatus> {
        let url = self.endpoint("api/tags")?;

        let response = self
            .request(Method::GET, url)
            .send()
            .await
            .map_err(map_reqwest_error)?;

        match response.status() {
            status if status.is_success() => Ok(ProviderStatus::Available),
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Ok(ProviderStatus::RequiresSetup),
            _ => Ok(ProviderStatus::Unavailable),
        }
    }

    /// Fetch the available models from the Ollama API.
    pub async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
        let url = self.endpoint("api/tags")?;

        let response = self
            .request(Method::GET, url)
            .send()
            .await
            .map_err(map_reqwest_error)?;

        let status = response.status();
        let bytes = response.bytes().await.map_err(map_reqwest_error)?;

        if !status.is_success() {
            return Err(map_http_error("tags", status, &bytes));
        }

        let payload: TagsResponse =
            serde_json::from_slice(&bytes).map_err(CoreError::Serialization)?;

        let models = payload
            .models
            .into_iter()
            .map(|model| self.parse_model_info(model))
            .collect();

        Ok(models)
    }

    /// Request a streaming generation session from Ollama.
    pub async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
        let url = self.endpoint("api/generate")?;

        let body = self.build_generate_body(request);

        let response = self
            .request(Method::POST, url)
            .json(&body)
            .send()
            .await
            .map_err(map_reqwest_error)?;

        let status = response.status();

        if !status.is_success() {
            let bytes = response.bytes().await.map_err(map_reqwest_error)?;
            return Err(map_http_error("generate", status, &bytes));
        }

        let stream = response.bytes_stream();
        let (tx, rx) = mpsc::channel::<CoreResult<GenerateChunk>>(32);

        tokio::spawn(async move {
            let mut stream = stream;
            let mut buffer: Vec<u8> = Vec::new();

            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buffer.extend_from_slice(&bytes);
                        while let Some(pos) = buffer.iter().position(|byte| *byte == b'\n') {
                            let line_bytes: Vec<u8> = buffer.drain(..=pos).collect();
                            let line = String::from_utf8_lossy(&line_bytes).trim().to_string();
                            if line.is_empty() {
                                continue;
                            }

                            match parse_stream_line(&line) {
                                Ok(item) => {
                                    if tx.send(Ok(item)).await.is_err() {
                                        return;
                                    }
                                }
                                Err(err) => {
                                    let _ = tx.send(Err(err)).await;
                                    return;
                                }
                            }
                        }
                    }
                    Err(err) => {
                        let _ = tx.send(Err(map_reqwest_error(err))).await;
                        return;
                    }
                }
            }

            if !buffer.is_empty() {
                let line = String::from_utf8_lossy(&buffer).trim().to_string();
                if !line.is_empty() {
                    match parse_stream_line(&line) {
                        Ok(item) => {
                            let _ = tx.send(Ok(item)).await;
                        }
                        Err(err) => {
                            let _ = tx.send(Err(err)).await;
                        }
                    }
                }
            }
        });

        let stream = ReceiverStream::new(rx);
        Ok(Box::pin(stream))
    }

    fn request(&self, method: Method, url: Url) -> reqwest::RequestBuilder {
        let mut builder = self.http.request(method, url);
        if let Some(api_key) = &self.api_key {
            builder = builder.bearer_auth(api_key);
        }
        builder
    }

    fn endpoint(&self, path: &str) -> CoreResult<Url> {
        self.base_url
            .join(path)
            .map_err(|err| CoreError::Config(format!("invalid endpoint '{}': {}", path, err)))
    }

    fn build_generate_body(&self, request: GenerateRequest) -> Value {
        let GenerateRequest {
            model,
            prompt,
            context,
            parameters,
            metadata,
        } = request;

        let mut body = JsonMap::new();
        body.insert("model".into(), Value::String(model));
        body.insert("stream".into(), Value::Bool(true));

        if let Some(prompt) = prompt {
            body.insert("prompt".into(), Value::String(prompt));
        }

        if !context.is_empty() {
            let items = context.into_iter().map(Value::String).collect();
            body.insert("context".into(), Value::Array(items));
        }

        if !parameters.is_empty() {
            body.insert("options".into(), Value::Object(to_json_map(parameters)));
        }

        if !metadata.is_empty() {
            body.insert("metadata".into(), Value::Object(to_json_map(metadata)));
        }

        Value::Object(body)
    }

    fn parse_model_info(&self, model: OllamaModel) -> ModelInfo {
        let mut metadata = HashMap::new();

        if let Some(digest) = model.digest {
            metadata.insert("digest".to_string(), Value::String(digest));
        }
        if let Some(modified) = model.modified_at {
            metadata.insert("modified_at".to_string(), Value::String(modified));
        }
        if let Some(details) = model.details {
            let mut details_map = JsonMap::new();
            if let Some(format) = details.format {
                details_map.insert("format".into(), Value::String(format));
            }
            if let Some(family) = details.family {
                details_map.insert("family".into(), Value::String(family));
            }
            if let Some(parameter_size) = details.parameter_size {
                details_map.insert("parameter_size".into(), Value::String(parameter_size));
            }
            if let Some(quantisation) = details.quantization_level {
                details_map.insert("quantization_level".into(), Value::String(quantisation));
            }

            if !details_map.is_empty() {
                metadata.insert("details".to_string(), Value::Object(details_map));
            }
        }

        ModelInfo {
            name: model.name,
            size_bytes: model.size,
            capabilities: Vec::new(),
            description: None,
            provider: self.provider_metadata.clone(),
            metadata,
        }
    }
}

#[derive(Debug, Deserialize)]
struct TagsResponse {
    #[serde(default)]
    models: Vec<OllamaModel>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[serde(default)]
    size: Option<u64>,
    #[serde(default)]
    digest: Option<String>,
    #[serde(default)]
    modified_at: Option<String>,
    #[serde(default)]
    details: Option<OllamaModelDetails>,
}

#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
    #[serde(default)]
    format: Option<String>,
    #[serde(default)]
    family: Option<String>,
    #[serde(default)]
    parameter_size: Option<String>,
    #[serde(default)]
    quantization_level: Option<String>,
}

fn to_json_map(source: HashMap<String, Value>) -> JsonMap<String, Value> {
    source.into_iter().collect()
}

fn to_metadata_map(value: &Value) -> HashMap<String, Value> {
    let mut metadata = HashMap::new();

    if let Value::Object(obj) = value {
        for (key, item) in obj {
            if key == "response" || key == "done" {
                continue;
            }
            metadata.insert(key.clone(), item.clone());
        }
    }

    metadata
}

fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
    let value: Value = serde_json::from_str(line).map_err(CoreError::Serialization)?;

    if let Some(error) = value.get("error").and_then(Value::as_str) {
        return Err(CoreError::Provider(anyhow::anyhow!(
            "ollama generation error: {}",
            error
        )));
    }

    let mut chunk = GenerateChunk {
        text: value
            .get("response")
            .and_then(Value::as_str)
            .map(str::to_string),
        is_final: value.get("done").and_then(Value::as_bool).unwrap_or(false),
        metadata: to_metadata_map(&value),
    };

    if chunk.is_final && chunk.text.is_none() && chunk.metadata.is_empty() {
        chunk
            .metadata
            .insert("status".into(), Value::String("done".into()));
    }

    Ok(chunk)
}

fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
    match status {
        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => CoreError::Auth(format!(
            "Ollama {} request unauthorized (status {})",
            endpoint, status
        )),
        StatusCode::TOO_MANY_REQUESTS => CoreError::Provider(anyhow::anyhow!(
            "Ollama {} request rate limited (status {})",
            endpoint,
            status
        )),
        _ => {
            let snippet = truncated_body(body);
            CoreError::Provider(anyhow::anyhow!(
                "Ollama {} request failed: HTTP {} - {}",
                endpoint,
                status,
                snippet
            ))
        }
    }
}

fn truncated_body(body: &[u8]) -> String {
    const MAX_CHARS: usize = 512;
    let text = String::from_utf8_lossy(body);
    let mut value = String::new();
    for (idx, ch) in text.chars().enumerate() {
        if idx >= MAX_CHARS {
            value.push('…');
            return value;
        }
        value.push(ch);
    }
    value
}

fn map_reqwest_error(err: reqwest::Error) -> CoreError {
    if err.is_timeout() {
        CoreError::Timeout(err.to_string())
    } else if err.is_connect() || err.is_request() {
        CoreError::Network(err.to_string())
    } else {
        CoreError::Provider(err.into())
    }
}
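generate_stream consumes Ollama's newline-delimited JSON responses: each complete line is handed to parse_stream_line, the response field becomes the chunk text, done marks the final chunk, and any remaining fields ride along as chunk metadata. A short sketch of that mapping, written as a hypothetical #[cfg(test)] module inside shared.rs (not part of this diff):

#[cfg(test)]
mod stream_parsing_sketch {
    use super::*;

    #[test]
    fn maps_response_and_done_fields_into_a_chunk() {
        let line = r#"{"model":"llama3","response":"Hel","done":false}"#;
        let chunk = parse_stream_line(line).unwrap();
        assert_eq!(chunk.text.as_deref(), Some("Hel"));
        assert!(!chunk.is_final);
        // Fields other than "response" and "done" are preserved as metadata.
        assert!(chunk.metadata.contains_key("model"));
    }

    #[test]
    fn error_payloads_surface_as_provider_errors() {
        let line = r#"{"error":"model not found"}"#;
        assert!(parse_stream_line(line).is_err());
    }
}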
106
crates/owlen-providers/tests/common/mock_provider.rs
Normal file
@@ -0,0 +1,106 @@
use std::sync::Arc;

use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use owlen_core::Result as CoreResult;
use owlen_core::provider::{
    GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
    ProviderStatus, ProviderType,
};

pub struct MockProvider {
    metadata: ProviderMetadata,
    models: Vec<ModelInfo>,
    status: ProviderStatus,
    #[allow(clippy::type_complexity)]
    generate_handler: Option<Arc<dyn Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync>>,
    generate_error: Option<Arc<dyn Fn() -> owlen_core::Error + Send + Sync>>,
}

impl MockProvider {
    pub fn new(id: &str) -> Self {
        let metadata = ProviderMetadata::new(
            id,
            format!("Mock Provider ({})", id),
            ProviderType::Local,
            false,
        );

        Self {
            metadata,
            models: vec![ModelInfo {
                name: format!("{}-primary", id),
                size_bytes: None,
                capabilities: vec!["chat".into()],
                description: Some("Mock model".into()),
                provider: ProviderMetadata::new(id, "Mock", ProviderType::Local, false),
                metadata: Default::default(),
            }],
            status: ProviderStatus::Available,
            generate_handler: None,
            generate_error: None,
        }
    }

    pub fn with_models(mut self, models: Vec<ModelInfo>) -> Self {
        self.models = models;
        self
    }

    pub fn with_status(mut self, status: ProviderStatus) -> Self {
        self.status = status;
        self
    }

    pub fn with_generate_handler<F>(mut self, handler: F) -> Self
    where
        F: Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync + 'static,
    {
        self.generate_handler = Some(Arc::new(handler));
        self
    }

    pub fn with_generate_error<F>(mut self, factory: F) -> Self
    where
        F: Fn() -> owlen_core::Error + Send + Sync + 'static,
    {
        self.generate_error = Some(Arc::new(factory));
        self
    }
}

#[async_trait]
impl ModelProvider for MockProvider {
    fn metadata(&self) -> &ProviderMetadata {
        &self.metadata
    }

    async fn health_check(&self) -> CoreResult<ProviderStatus> {
        Ok(self.status)
    }

    async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
        Ok(self.models.clone())
    }

    async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
        if let Some(factory) = &self.generate_error {
            return Err(factory());
        }

        let chunks = if let Some(handler) = &self.generate_handler {
            (handler)(request)
        } else {
            vec![GenerateChunk::final_chunk()]
        };

        let stream = stream::iter(chunks.into_iter().map(Ok)).boxed();
        Ok(Box::pin(stream))
    }
}

impl From<MockProvider> for Arc<dyn ModelProvider> {
    fn from(provider: MockProvider) -> Self {
        Arc::new(provider)
    }
}
1
crates/owlen-providers/tests/common/mod.rs
Normal file
@@ -0,0 +1 @@
pub mod mock_provider;
117
crates/owlen-providers/tests/integration_test.rs
Normal file
@@ -0,0 +1,117 @@
mod common;

use std::sync::Arc;

use futures::StreamExt;

use common::mock_provider::MockProvider;
use owlen_core::config::Config;
use owlen_core::provider::{
    GenerateChunk, GenerateRequest, ModelInfo, ProviderManager, ProviderType,
};

#[allow(dead_code)]
fn base_config() -> Config {
    Config {
        providers: Default::default(),
        ..Default::default()
    }
}

fn make_model(name: &str, provider: &str) -> ModelInfo {
    ModelInfo {
        name: name.into(),
        size_bytes: None,
        capabilities: vec!["chat".into()],
        description: Some("mock".into()),
        provider: owlen_core::provider::ProviderMetadata::new(
            provider,
            provider,
            ProviderType::Local,
            false,
        ),
        metadata: Default::default(),
    }
}

#[tokio::test]
async fn registers_providers_and_lists_ids() {
    let manager = ProviderManager::default();
    let provider: Arc<dyn owlen_core::provider::ModelProvider> = MockProvider::new("mock-a").into();

    manager.register_provider(provider).await;
    let ids = manager.provider_ids().await;

    assert_eq!(ids, vec!["mock-a".to_string()]);
}

#[tokio::test]
async fn aggregates_models_across_providers() {
    let manager = ProviderManager::default();
    let provider_a = MockProvider::new("mock-a").with_models(vec![make_model("alpha", "mock-a")]);
    let provider_b = MockProvider::new("mock-b").with_models(vec![make_model("beta", "mock-b")]);

    manager.register_provider(provider_a.into()).await;
    manager.register_provider(provider_b.into()).await;

    let models = manager.list_all_models().await.unwrap();
    assert_eq!(models.len(), 2);
    assert!(models.iter().any(|m| m.model.name == "alpha"));
    assert!(models.iter().any(|m| m.model.name == "beta"));
}

#[tokio::test]
async fn routes_generation_to_specific_provider() {
    let manager = ProviderManager::default();
    let provider = MockProvider::new("mock-gen").with_generate_handler(|_req| {
        vec![
            GenerateChunk::from_text("hello"),
            GenerateChunk::final_chunk(),
        ]
    });

    manager.register_provider(provider.into()).await;

    let request = GenerateRequest::new("mock-gen::primary");
    let mut stream = manager.generate("mock-gen", request).await.unwrap();
    let mut collected = Vec::new();
    while let Some(chunk) = stream.next().await {
        collected.push(chunk.unwrap());
    }

    assert_eq!(collected.len(), 2);
    assert_eq!(collected[0].text.as_deref(), Some("hello"));
    assert!(collected[1].is_final);
}

#[tokio::test]
async fn marks_provider_unavailable_on_error() {
    let manager = ProviderManager::default();
    let provider = MockProvider::new("flaky")
        .with_generate_error(|| owlen_core::Error::Network("boom".into()));

    manager.register_provider(provider.into()).await;
    let request = GenerateRequest::new("flaky::model");
    let result = manager.generate("flaky", request).await;
    assert!(result.is_err());

    let status = manager.provider_status("flaky").await.unwrap();
    assert!(matches!(
        status,
        owlen_core::provider::ProviderStatus::Unavailable
    ));
}

#[tokio::test]
async fn health_refresh_updates_status_cache() {
    let manager = ProviderManager::default();
    let provider =
        MockProvider::new("healthy").with_status(owlen_core::provider::ProviderStatus::Available);

    manager.register_provider(provider.into()).await;
    let statuses = manager.refresh_health().await;
    assert_eq!(
        statuses.get("healthy"),
        Some(&owlen_core::provider::ProviderStatus::Available)
    );
}
@@ -10,8 +10,7 @@ description = "Terminal User Interface for OWLEN LLM client"
 
 [dependencies]
 owlen-core = { path = "../owlen-core" }
-owlen-ollama = { path = "../owlen-ollama" }
-# Removed circular dependency on `owlen-cli`. The TUI no longer directly depends on the CLI crate.
+# Removed owlen-ollama dependency - all providers now accessed via MCP architecture (Phase 10)
 
 # TUI framework
 ratatui = { workspace = true }
@@ -19,7 +18,20 @@ crossterm = { workspace = true }
 tui-textarea = { workspace = true }
 textwrap = { workspace = true }
 unicode-width = "0.1"
+unicode-segmentation = "1.11"
 async-trait = "0.1"
+globset = "0.4"
+ignore = "0.4"
+pathdiff = "0.2"
+tree-sitter = "0.20"
+tree-sitter-rust = "0.20"
+dirs = { workspace = true }
+toml = { workspace = true }
+syntect = "5.3"
+once_cell = "1.19"
+owlen-markdown = { path = "../owlen-markdown" }
+shellexpand = { workspace = true }
+regex = { workspace = true }
 
 # Async runtime
 tokio = { workspace = true }
@@ -30,6 +42,9 @@ futures-util = { workspace = true }
 anyhow = { workspace = true }
 uuid = { workspace = true }
 serde_json.workspace = true
+serde.workspace = true
+chrono = { workspace = true }
+log = { workspace = true }
 
 [dev-dependencies]
 tokio-test = { workspace = true }
Some files were not shown because too many files have changed in this diff.