Compare commits
77 Commits: fab63d224b ... dev
| SHA1 | Author | Date | |
|---|---|---|---|
| d86888704f | |||
| de6b6e20a5 | |||
| 1e8a5e08ed | |||
| 218ebbf32f | |||
| c49e7f4b22 | |||
| 9588c8c562 | |||
| 1948ac1284 | |||
| 3f92b7d963 | |||
| 5553e61dbf | |||
| 7f987737f9 | |||
| 5182f86133 | |||
| a50099ad74 | |||
| 20ba5523ee | |||
| 0b2b3701dc | |||
| 438b05b8a3 | |||
| e2a31b192f | |||
| b827d3d047 | |||
| 9c0cf274a3 | |||
| 85ae319690 | |||
| 449f133a1f | |||
| 2f6b03ef65 | |||
| d4030dc598 | |||
| 3271697f6b | |||
| cbfef5a5df | |||
| 52efd5f341 | |||
| 200cdbc4bd | |||
| 8525819ab4 | |||
| bcd52d526c | |||
| 7effade1d3 | |||
| dc0fee2ee3 | |||
| ea04a25ed6 | |||
| 282dcdce88 | |||
| b49f58bc16 | |||
| cdc425ae93 | |||
| 3525cb3949 | |||
| 9d85420bf6 | |||
| 641c95131f | |||
| 708c626176 | |||
| 5210e196f2 | |||
| 30c375b6c5 | |||
| baf49b1e69 | |||
| 96e0436d43 | |||
| 498e6e61b6 | |||
| 99064b6c41 | |||
| ee58b0ac32 | |||
| 990f93d467 | |||
| 44a00619b5 | |||
| 6923ee439f | |||
| c997b19b53 | |||
| c9daf68fea | |||
| ba9d083088 | |||
| 825dfc0722 | |||
| 3e4eacd1d3 | |||
| 23253219a3 | |||
| cc2b85a86d | |||
| 58dd6f3efa | |||
| c81d0f1593 | |||
| ae0dd3fc51 | |||
| 80dffa9f41 | |||
| ab0ae4fe04 | |||
| d31e068277 | |||
| 690f5c7056 | |||
| 0da8a3f193 | |||
| 15f81d9728 | |||
| b80db89391 | |||
| f413a63c5a | |||
| 33ad3797a1 | |||
| 55e6b0583d | |||
| ae9c3af096 | |||
| 0bd560b408 | |||
| 083b621b7d | |||
| d2a193e5c1 | |||
| acbfe47a4b | |||
| 60c859b3ab | |||
| 82078afd6d | |||
| 7851af14a9 | |||
| c2f5ccea3b |
.github/workflows/macos-check.yml (new file, 34 lines, vendored)
@@ -0,0 +1,34 @@
name: macos-check

on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  build:
    name: cargo check (macOS)
    runs-on: macos-latest
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable

      - name: Cache Cargo registry
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
            target
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Cargo check
        run: cargo check --workspace --all-features
@@ -9,6 +9,7 @@ repos:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        args: ['--allow-multiple-documents']
      - id: check-toml
      - id: check-merge-conflict
      - id: check-added-large-files
@@ -1,3 +1,61 @@
---
kind: pipeline
name: pr-checks

when:
  event:
    - push
    - pull_request

steps:
  - name: fmt-clippy-test
    image: rust:1.83
    commands:
      - rustup component add rustfmt clippy
      - cargo fmt --all -- --check
      - cargo clippy --workspace --all-features -- -D warnings
      - cargo test --workspace --all-features

---
kind: pipeline
name: security-audit

when:
  event:
    - push
    - cron
  branch:
    - dev
  cron: weekly-security

steps:
  - name: cargo-audit
    image: rust:1.83
    commands:
      - cargo install cargo-audit --locked
      - cargo audit

---
kind: pipeline
name: release-tests

when:
  event: tag
  tag: v*

steps:
  - name: workspace-tests
    image: rust:1.83
    commands:
      - rustup component add llvm-tools-preview
      - cargo install cargo-llvm-cov --locked
      - cargo llvm-cov --workspace --all-features --summary-only
      - cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run

---
kind: pipeline
name: release

when:
  event: tag
  tag: v*
@@ -5,6 +63,9 @@ when:
variables:
  - &rust_image 'rust:1.83'

depends_on:
  - release-tests

matrix:
  include:
    # Linux
@@ -39,14 +100,6 @@ matrix:
      EXT: ".exe"

steps:
  - name: tests
    image: *rust_image
    commands:
      - rustup component add llvm-tools-preview
      - cargo install cargo-llvm-cov --locked
      - cargo llvm-cov --workspace --all-features --summary-only
      - cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run

  - name: build
    image: *rust_image
    commands:
@@ -124,6 +177,11 @@ steps:
          sha256sum ${ARTIFACT}.tar.gz > ${ARTIFACT}.tar.gz.sha256
        fi

  - name: release-notes
    image: *rust_image
    commands:
      - scripts/release-notes.sh "${CI_COMMIT_TAG}" release-notes.md

  - name: release
    image: plugins/gitea-release
    settings:
@@ -136,4 +194,4 @@ steps:
        - ${ARTIFACT}.zip
        - ${ARTIFACT}.zip.sha256
      title: Release ${CI_COMMIT_TAG}
      note: "Release ${CI_COMMIT_TAG}"
      note_file: release-notes.md
CHANGELOG.md (16 lines changed)
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Comprehensive documentation suite including guides for architecture, configuration, testing, and more.
- Rustdoc examples for core components like `Provider` and `SessionController`.
- Module-level documentation for `owlen-tui`.
- Provider integration tests (`crates/owlen-providers/tests`) covering registration, routing, and health status handling for the new `ProviderManager`.
- TUI message and generation tests that exercise the non-blocking event loop, background worker, and message dispatch.
- Ollama integration can now talk to Ollama Cloud when an API key is configured.
- Ollama provider will also read `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables when no key is stored in the config.
- `owlen config doctor`, `owlen config path`, and `owlen upgrade` CLI commands to automate migrations and surface manual update steps.
@@ -23,6 +25,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tabbed model selector that separates local and cloud providers, including cloud indicators in the UI.
- Footer status line includes provider connectivity/credential summaries (e.g., cloud auth failures, missing API keys).
- Secure credential vault integration for Ollama Cloud API keys when `privacy.encrypt_local_data = true`.
- Input panel respects a new `ui.input_max_rows` setting so long prompts expand predictably before scrolling kicks in.
- Command palette offers fuzzy `:model` filtering and `:provider` completions for fast switching.
- Message rendering caches wrapped lines and throttles streaming redraws to keep the TUI responsive on long sessions.
- Model picker badges now inspect provider capabilities so vision/audio/thinking models surface the correct icons even when descriptions are sparse.
- Chat history honors `ui.scrollback_lines`, trimming older rows to keep the TUI responsive and surfacing a "↓ New messages" badge whenever updates land off-screen.

### Changed
- The main `README.md` has been updated to be more concise and link to the new documentation.
@@ -31,11 +38,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Configuration loading performs structural validation and fails fast on missing default providers or invalid MCP definitions.
- Ollama provider error handling now distinguishes timeouts, missing models, and authentication failures.
- `owlen` warns when the active terminal likely lacks 256-color support.
- `config.toml` now carries a schema version (`1.1.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
- `config.toml` now carries a schema version (`1.2.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
- Model selector navigation (Tab/Shift-Tab) now switches between local and cloud tabs while preserving selection state.
- Header displays the active model together with its provider (e.g., `Model (Provider)`), improving clarity when swapping backends.
- Documentation refreshed to cover the message handler architecture, the background health worker, multi-provider configuration, and the new provider onboarding checklist.

---

## [0.1.11] - 2025-10-18

### Changed
- Bump workspace packages and distribution metadata to version `0.1.11`.

## [0.1.10] - 2025-10-03

### Added
@@ -10,6 +10,10 @@ This project and everyone participating in it is governed by the [Owlen Code of

## How Can I Contribute?

### Repository map

Need a quick orientation before diving in? Start with the curated [repo map](docs/repo-map.md) for a two-level directory overview. If you move folders around, regenerate it with `scripts/gen-repo-map.sh`.

### Reporting Bugs

This is one of the most helpful ways you can contribute. Before creating a bug report, please check a few things:
Cargo.toml (19 lines changed)
@@ -4,17 +4,20 @@ members = [
    "crates/owlen-core",
    "crates/owlen-tui",
    "crates/owlen-cli",
    "crates/owlen-mcp-server",
    "crates/owlen-mcp-llm-server",
    "crates/owlen-mcp-client",
    "crates/owlen-mcp-code-server",
    "crates/owlen-mcp-prompt-server",
    "crates/owlen-providers",
    "crates/mcp/server",
    "crates/mcp/llm-server",
    "crates/mcp/client",
    "crates/mcp/code-server",
    "crates/mcp/prompt-server",
    "crates/owlen-markdown",
    "xtask",
]
exclude = []

[workspace.package]
version = "0.1.9"
edition = "2021"
version = "0.1.11"
edition = "2024"
authors = ["Owlibou"]
license = "AGPL-3.0"
repository = "https://somegit.dev/Owlibou/owlen"
@@ -43,7 +46,7 @@ serde_json = { version = "1.0" }
# Utilities
uuid = { version = "1.0", features = ["v4", "serde"] }
anyhow = "1.0"
thiserror = "1.0"
thiserror = "2.0"
nix = "0.29"
which = "6.0"
tempfile = "3.8"
PKGBUILD (2 lines changed)
@@ -1,6 +1,6 @@
# Maintainer: vikingowl <christian@nachtigall.dev>
pkgname=owlen
pkgver=0.1.9
pkgver=0.1.11
pkgrel=1
pkgdesc="Terminal User Interface LLM client for Ollama with chat and code assistance features"
arch=('x86_64')
README.md (69 lines changed)
@@ -3,16 +3,17 @@
> Terminal-native assistant for running local language models with a comfortable TUI.

(status badges)

## What Is OWLEN?

OWLEN is a Rust-powered, terminal-first interface for interacting with local large
language models. It provides a responsive chat workflow that runs against
[Ollama](https://ollama.com/) with a focus on developer productivity, vim-style navigation,
and seamless session management—all without leaving your terminal.
OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud
language models. It provides a responsive chat workflow that now routes through a
multi-provider manager—handling local Ollama, Ollama Cloud, and future MCP-backed providers—
with a focus on developer productivity, vim-style navigation, and seamless session
management—all without leaving your terminal.

## Alpha Status
@@ -30,8 +31,11 @@ The OWLEN interface features a clean, multi-panel layout with vim-inspired navig
- **Streaming Responses**: Real-time token streaming from Ollama.
- **Advanced Text Editing**: Multi-line input, history, and clipboard support.
- **Session Management**: Save, load, and manage conversations.
- **Code Side Panel**: Switch to code mode (`:mode code`) and open files inline with `:open <path>` for LLM-assisted coding.
- **Theming System**: 10 built-in themes and support for custom themes.
- **Modular Architecture**: Extensible provider system (Ollama today, additional providers on the roadmap).
- **Modular Architecture**: Extensible provider system orchestrated by the new `ProviderManager`, ready for additional MCP-backed providers.
- **Dual-Source Model Picker**: Merge local and cloud catalogues with real-time availability badges powered by the background health worker.
- **Non-Blocking UI Loop**: Asynchronous generation tasks and provider health checks run off-thread, keeping the TUI responsive even while streaming long replies.
- **Guided Setup**: `owlen config doctor` upgrades legacy configs and verifies your environment in seconds.

## Security & Privacy
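Aside (not part of the diff): the wiring behind the `ProviderManager` bullet is visible later in this compare, in the new `crates/owlen-cli/src/bootstrap.rs`. A minimal sketch of that hand-off, using only the two calls that appear there (anything beyond them, such as how providers get registered, is an assumption):

```rust
// Sketch of the hand-off shown in bootstrap.rs: the TUI runtime owns the
// ProviderManager, so providers are managed in one place rather than inside UI code.
use std::sync::Arc;

use owlen_core::provider::ProviderManager;
use owlen_tui::app::App as RuntimeApp;

fn wire_runtime() -> RuntimeApp {
    // ProviderManager::default() and RuntimeApp::new(...) are the calls used in bootstrap.rs.
    let provider_manager = Arc::new(ProviderManager::default());
    RuntimeApp::new(provider_manager)
}
```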
@@ -53,20 +57,28 @@ Owlen is designed to keep data local by default while still allowing controlled

### Installation

#### Linux & macOS
The recommended way to install on Linux and macOS is to clone the repository and install using `cargo`.
Pick the option that matches your platform and appetite for source builds:

| Platform | Package / Command | Notes |
| --- | --- | --- |
| Arch Linux | `yay -S owlen-git` | Builds from the latest `dev` branch via AUR. |
| Other Linux | `cargo install --path crates/owlen-cli --locked --force` | Requires Rust 1.75+ and a running Ollama daemon. |
| macOS | `cargo install --path crates/owlen-cli --locked --force` | macOS 12+ tested. Install Ollama separately (`brew install ollama`). The binary links against the system OpenSSL – ensure Command Line Tools are installed. |
| Windows (experimental) | `cargo install --path crates/owlen-cli --locked --force` | Enable the GNU toolchain (`rustup target add x86_64-pc-windows-gnu`) and install Ollama for Windows preview builds. Some optional tools (e.g., Docker-based code execution) are currently disabled. |

If you prefer containerised builds, use the provided `Dockerfile` as a base image and copy out `target/release/owlen`.

Run the helper scripts to sanity-check platform coverage:

```bash
git clone https://github.com/Owlibou/owlen.git
cd owlen
cargo install --path crates/owlen-cli
# Windows compatibility smoke test (GNU toolchain)
scripts/check-windows.sh

# Reproduce CI packaging locally (choose a target from .woodpecker.yml)
dev/local_build.sh x86_64-unknown-linux-gnu
```
**Note for macOS**: While this method works, official binary releases for macOS are planned for the future.

#### Windows
The Windows build has not been thoroughly tested yet. Installation is possible via the same `cargo install` method, but it is considered experimental at this time.

From Unix hosts you can run `scripts/check-windows.sh` to ensure the code base still compiles for Windows (`rustup` will install the required target automatically).
> **Tip (macOS):** On the first launch macOS Gatekeeper may quarantine the binary. Clear the attribute (`xattr -d com.apple.quarantine $(which owlen)`) or build from source locally to avoid notarisation prompts.

### Running OWLEN
@@ -89,8 +101,16 @@ OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode)

- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`.
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:w`, `:session save`, `:theme`.
- **Quick Exit**: Press `Ctrl+C` twice in Normal mode to quit quickly (first press still cancels active generations).
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.

Model discovery commands worth remembering:

- `:models --local` or `:models --cloud` jump directly to the corresponding section in the picker.
- `:cloud setup [--force-cloud-base-url]` stores your cloud API key without clobbering an existing local base URL (unless you opt in with the flag).

When a catalogue is unreachable, Owlen now tags the picker with `Local unavailable` / `Cloud unavailable` so you can recover without guessing.

## Documentation
@@ -100,7 +120,10 @@ For more detailed information, please refer to the following documents:
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: A guide for adding new providers.
- **[docs/repo-map.md](docs/repo-map.md)**: Snapshot of the workspace layout and key crates.
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: Trait-level details for implementing providers.
- **[docs/adding-providers.md](docs/adding-providers.md)**: Step-by-step checklist for wiring a provider into the multi-provider architecture and test suite.
- **Experimental providers staging area**: [crates/providers/experimental/README.md](crates/providers/experimental/README.md) records the placeholder crates (OpenAI, Anthropic, Gemini) and their current status.
- **[docs/platform-support.md](docs/platform-support.md)**: Current OS support matrix and cross-check instructions.

## Configuration
@@ -118,6 +141,16 @@ You can also add custom themes alongside the config directory (e.g., `~/.config/

See the [themes/README.md](themes/README.md) for more details on theming.

## Testing

Owlen uses standard Rust tooling for verification. Run the full test suite with:

```bash
cargo test
```

Unit tests cover the command palette state machine, agent response parsing, and key MCP abstractions. Formatting and lint checks can be run with `cargo fmt --all` and `cargo clippy` respectively.

## Roadmap

Upcoming milestones focus on feature parity with modern code assistants while keeping Owlen local-first:
config.toml (new file, 29 lines)
@@ -0,0 +1,29 @@
[general]
default_provider = "ollama_local"
default_model = "llama3.2:latest"

[privacy]
encrypt_local_data = true

[providers.ollama_local]
enabled = true
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama_cloud]
enabled = false
provider_type = "ollama_cloud"
base_url = "https://ollama.com"
api_key_env = "OLLAMA_CLOUD_API_KEY"

[providers.openai]
enabled = false
provider_type = "openai"
base_url = "https://api.openai.com/v1"
api_key_env = "OPENAI_API_KEY"

[providers.anthropic]
enabled = false
provider_type = "anthropic"
base_url = "https://api.anthropic.com/v1"
api_key_env = "ANTHROPIC_API_KEY"
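Aside (not part of the diff): a rough sketch of how this new `config.toml` is consumed at startup, pieced together from calls that appear in `bootstrap.rs` and `commands/cloud.rs` later in this compare (`try_load_config`, `validate`, `general.default_provider`, `provider()`); anything beyond those names is an assumption.

```rust
// Sketch only: load the config, validate it, and look up the default provider entry.
use owlen_tui::config;

fn resolve_default_provider() -> anyhow::Result<()> {
    let cfg = config::try_load_config().unwrap_or_default();
    cfg.validate()?;

    // "ollama_local" per the [general] table in this file.
    let name = cfg.general.default_provider.clone();
    let provider_cfg = cfg
        .provider(&name)
        .ok_or_else(|| anyhow::anyhow!("no [providers.{name}] entry"))?;
    println!("{name} -> {}", provider_cfg.provider_type); // "ollama"
    Ok(())
}
```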
@@ -1,12 +1,12 @@
[package]
name = "owlen-mcp-client"
version = "0.1.0"
edition = "2021"
edition.workspace = true
description = "Dedicated MCP client library for Owlen, exposing remote MCP server communication"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-core = { path = "../../owlen-core" }

[features]
default = []
@@ -5,14 +5,12 @@
//! crates can depend only on `owlen-mcp-client` without pulling in the entire
//! core crate internals.

pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
pub use owlen_core::mcp::remote_client::RemoteMcpClient;
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};

// Re-export the Provider implementation so the client can also be used as an
// LLM provider when the remote MCP server hosts a language-model tool (e.g.
// `generate_text`).
// Re-export the core Provider trait so that the MCP client can also be used as an LLM provider.
pub use owlen_core::provider::Provider as McpProvider;
pub use owlen_core::Provider as McpProvider;

// Note: The `RemoteMcpClient` type provides its own `new` constructor in the core
// crate. Users can call `RemoteMcpClient::new()` directly. No additional wrapper
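Aside (not part of the diff): given the re-exports above, a hedged sketch of what "use the MCP client as an LLM provider" looks like in practice. `RemoteMcpClient::new()` and the `Provider` methods (`health_check`, `list_models`) appear elsewhere in this compare; the error conversions into `anyhow` are assumed.

```rust
// Sketch: the remote MCP client doubles as a Provider (re-exported here as McpProvider).
use owlen_mcp_client::{McpProvider, RemoteMcpClient};

async fn probe_remote() -> anyhow::Result<()> {
    let client = RemoteMcpClient::new()?;      // constructor provided by owlen-core
    let provider: &dyn McpProvider = &client;  // same object, viewed as an LLM provider
    provider.health_check().await?;            // Provider trait method
    for model in provider.list_models().await? {
        println!("{} ({})", model.name, model.provider);
    }
    Ok(())
}
```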
@@ -1,12 +1,12 @@
[package]
name = "owlen-mcp-code-server"
version = "0.1.0"
edition = "2021"
edition.workspace = true
description = "MCP server exposing safe code execution tools for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
@@ -10,11 +10,11 @@ pub mod sandbox;
pub mod tools;

use owlen_core::mcp::protocol::{
    methods, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError, RpcErrorResponse,
    RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
};
use owlen_core::tools::{Tool, ToolResult};
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
@@ -149,10 +149,10 @@ async fn handle_request(
                    supports_streaming: Some(false),
                },
            };
            Ok(RpcResponse::new(
                req.id,
                serde_json::to_value(result).unwrap(),
            ))
            let payload = serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        methods::TOOLS_LIST => {
            let tools = registry.list_tools();
@@ -176,10 +176,10 @@ async fn handle_request(
                metadata: result.metadata,
                duration_ms: result.duration.as_millis() as u128,
            };
            Ok(RpcResponse::new(
                req.id,
                serde_json::to_value(resp).unwrap(),
            ))
            let payload = serde_json::to_value(resp).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
@@ -1,12 +1,12 @@
//! Docker-based sandboxing for secure code execution

use anyhow::{Context, Result};
use bollard::Docker;
use bollard::container::{
    Config, CreateContainerOptions, RemoveContainerOptions, StartContainerOptions,
    WaitContainerOptions,
};
use bollard::models::{HostConfig, Mount, MountTypeEnum};
use bollard::Docker;
use std::collections::HashMap;
use std::path::Path;
@@ -2,9 +2,9 @@

use crate::sandbox::Sandbox;
use async_trait::async_trait;
use owlen_core::tools::{Tool, ToolResult};
use owlen_core::Result;
use serde_json::{json, Value};
use owlen_core::tools::{Tool, ToolResult};
use serde_json::{Value, json};
use std::path::PathBuf;

/// Tool for compiling projects (Rust, Node.js, Python)
@@ -1,10 +1,10 @@
[package]
name = "owlen-mcp-llm-server"
version = "0.1.0"
edition = "2021"
edition.workspace = true

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-core = { path = "../../owlen-core" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -7,18 +7,19 @@
    clippy::empty_line_after_outer_attr
)]

use owlen_core::config::{ensure_provider_config, Config as OwlenConfig};
use owlen_core::Provider;
use owlen_core::ProviderConfig;
use owlen_core::config::{Config as OwlenConfig, ensure_provider_config};
use owlen_core::mcp::protocol::{
    methods, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError, RpcErrorResponse,
    RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo,
    methods,
};
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
use owlen_core::provider::ProviderConfig;
use owlen_core::providers::OllamaProvider;
use owlen_core::types::{ChatParameters, ChatRequest, Message};
use owlen_core::Provider;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
@@ -125,7 +126,7 @@ fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
    })?;

    match provider_cfg.provider_type.as_str() {
        "ollama" | "ollama-cloud" => {
        "ollama" | "ollama_cloud" => {
            let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
                .map_err(|e| {
                    RpcError::internal_error(format!(
@@ -152,10 +153,12 @@ fn create_provider() -> Result<Arc<dyn Provider>, RpcError> {
}

fn canonical_provider_name(name: &str) -> String {
    if name.eq_ignore_ascii_case("ollama-cloud") {
        "ollama".to_string()
    } else {
        name.to_string()
    let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
    match normalized.as_str() {
        "" => "ollama_local".to_string(),
        "ollama" | "ollama_local" => "ollama_local".to_string(),
        "ollama_cloud" => "ollama_cloud".to_string(),
        other => other.to_string(),
    }
}
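Aside (not part of the diff): the behaviour implied by the new mapping, written out as assertions; illustrative only, no such test exists in this change set.

```rust
#[test]
fn canonical_provider_name_examples() {
    // trim + ASCII-lowercase + '-' -> '_' before matching, per the function above
    assert_eq!(canonical_provider_name(" Ollama-Cloud "), "ollama_cloud");
    assert_eq!(canonical_provider_name(""), "ollama_local");
    assert_eq!(canonical_provider_name("ollama"), "ollama_local");
    assert_eq!(canonical_provider_name("openai"), "openai");
}
```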
@@ -178,7 +181,7 @@ async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError

    // Use streaming API and collect output
    let mut stream = provider
        .chat_stream(request)
        .stream_prompt(request)
        .await
        .map_err(|e| RpcError::internal_error(format!("Chat request failed: {}", e)))?;
    let mut content = String::new();
@@ -228,7 +231,9 @@ async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
                supports_streaming: Some(true),
            },
        };
        Ok(serde_json::to_value(result).unwrap())
        serde_json::to_value(result).map_err(|e| {
            RpcError::internal_error(format!("Failed to serialize init result: {}", e))
        })
    }
    methods::TOOLS_LIST => {
        let tools = vec![
@@ -245,7 +250,9 @@ async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
            .list_models()
            .await
            .map_err(|e| RpcError::internal_error(format!("Failed to list models: {}", e)))?;
        Ok(serde_json::to_value(models).unwrap())
        serde_json::to_value(models).map_err(|e| {
            RpcError::internal_error(format!("Failed to serialize model list: {}", e))
        })
    }
    methods::TOOLS_CALL => {
        // For streaming we will send incremental notifications directly from here.
@@ -331,10 +338,24 @@ async fn main() -> anyhow::Result<()> {
                metadata: HashMap::new(),
                duration_ms: 0,
            };
            let final_resp = RpcResponse::new(
                id.clone(),
                serde_json::to_value(response).unwrap(),
            );
            let payload = match serde_json::to_value(&response) {
                Ok(value) => value,
                Err(e) => {
                    let err_resp = RpcErrorResponse::new(
                        id.clone(),
                        RpcError::internal_error(format!(
                            "Failed to serialize resource response: {}",
                            e
                        )),
                    );
                    let s = serde_json::to_string(&err_resp)?;
                    stdout.write_all(s.as_bytes()).await?;
                    stdout.write_all(b"\n").await?;
                    stdout.flush().await?;
                    continue;
                }
            };
            let final_resp = RpcResponse::new(id.clone(), payload);
            let s = serde_json::to_string(&final_resp)?;
            stdout.write_all(s.as_bytes()).await?;
            stdout.write_all(b"\n").await?;
@@ -375,10 +396,24 @@ async fn main() -> anyhow::Result<()> {
                metadata: HashMap::new(),
                duration_ms: 0,
            };
            let final_resp = RpcResponse::new(
                id.clone(),
                serde_json::to_value(response).unwrap(),
            );
            let payload = match serde_json::to_value(&response) {
                Ok(value) => value,
                Err(e) => {
                    let err_resp = RpcErrorResponse::new(
                        id.clone(),
                        RpcError::internal_error(format!(
                            "Failed to serialize directory listing: {}",
                            e
                        )),
                    );
                    let s = serde_json::to_string(&err_resp)?;
                    stdout.write_all(s.as_bytes()).await?;
                    stdout.write_all(b"\n").await?;
                    stdout.flush().await?;
                    continue;
                }
            };
            let final_resp = RpcResponse::new(id.clone(), payload);
            let s = serde_json::to_string(&final_resp)?;
            stdout.write_all(s.as_bytes()).await?;
            stdout.write_all(b"\n").await?;
@@ -454,7 +489,7 @@ async fn main() -> anyhow::Result<()> {
                parameters,
                tools: None,
            };
            let mut stream = match provider.chat_stream(request).await {
            let mut stream = match provider.stream_prompt(request).await {
                Ok(s) => s,
                Err(e) => {
                    let err_resp = RpcErrorResponse::new(
@@ -510,8 +545,24 @@ async fn main() -> anyhow::Result<()> {
                metadata: HashMap::new(),
                duration_ms: 0,
            };
            let final_resp =
                RpcResponse::new(id.clone(), serde_json::to_value(response).unwrap());
            let payload = match serde_json::to_value(&response) {
                Ok(value) => value,
                Err(e) => {
                    let err_resp = RpcErrorResponse::new(
                        id.clone(),
                        RpcError::internal_error(format!(
                            "Failed to serialize final streaming response: {}",
                            e
                        )),
                    );
                    let s = serde_json::to_string(&err_resp)?;
                    stdout.write_all(s.as_bytes()).await?;
                    stdout.write_all(b"\n").await?;
                    stdout.flush().await?;
                    continue;
                }
            };
            let final_resp = RpcResponse::new(id.clone(), payload);
            let s = serde_json::to_string(&final_resp)?;
            stdout.write_all(s.as_bytes()).await?;
            stdout.write_all(b"\n").await?;
@@ -1,12 +1,12 @@
[package]
name = "owlen-mcp-prompt-server"
version = "0.1.0"
edition = "2021"
edition.workspace = true
description = "MCP server that renders prompt templates (YAML) for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
@@ -6,7 +6,7 @@
use anyhow::{Context, Result};
use handlebars::Handlebars;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
@@ -14,8 +14,8 @@ use std::sync::Arc;
use tokio::sync::RwLock;

use owlen_core::mcp::protocol::{
    methods, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError, RpcErrorResponse,
    RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
};
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
@@ -148,7 +148,7 @@ FINAL_ANSWER: Summary of what was done"#
                    template.name, e
                );
            } else {
                let mut templates = futures::executor::block_on(self.templates.write());
                let mut templates = self.templates.blocking_write();
                templates.insert(template.name.clone(), template);
            }
        }
@@ -284,10 +284,10 @@ async fn handle_request(
                supports_streaming: Some(false),
            },
        };
        Ok(RpcResponse::new(
            req.id,
            serde_json::to_value(result).unwrap(),
        ))
        let payload = serde_json::to_value(result).map_err(|e| {
            RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
        })?;
        Ok(RpcResponse::new(req.id, payload))
    }
    methods::TOOLS_LIST => {
        let tools = vec![
@@ -349,9 +349,17 @@ async fn handle_request(

        let srv = server.lock().await;
        match srv.get_template(name).await {
            Some(template) => {
                json!({"success": true, "template": serde_json::to_value(template).unwrap()})
            }
            Some(template) => match serde_json::to_value(template) {
                Ok(serialized) => {
                    json!({"success": true, "template": serialized})
                }
                Err(e) => {
                    return Err(RpcError::internal_error(format!(
                        "Failed to serialize template '{}': {}",
                        name, e
                    )));
                }
            },
            None => json!({"success": false, "error": "Template not found"}),
        }
    }
@@ -397,10 +405,10 @@ async fn handle_request(
            duration_ms: 0,
        };

        Ok(RpcResponse::new(
            req.id,
            serde_json::to_value(resp).unwrap(),
        ))
        let payload = serde_json::to_value(resp).map_err(|e| {
            RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
        })?;
        Ok(RpcResponse::new(req.id, payload))
    }
    _ => Err(RpcError::method_not_found(&req.method)),
}
@@ -1,7 +1,7 @@
[package]
name = "owlen-mcp-server"
version = "0.1.0"
edition = "2021"
edition.workspace = true

[dependencies]
tokio = { workspace = true }
@@ -9,4 +9,4 @@ serde = { workspace = true }
serde_json = { workspace = true }
anyhow = { workspace = true }
path-clean = "1.0"
owlen-core = { path = "../owlen-core" }
owlen-core = { path = "../../owlen-core" }
@@ -1,6 +1,6 @@
use owlen_core::mcp::protocol::{
    is_compatible, ErrorCode, InitializeParams, InitializeResult, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, PROTOCOL_VERSION,
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, is_compatible,
};
use path_clean::PathClean;
use serde::Deserialize;
@@ -17,6 +17,11 @@ name = "owlen"
|
||||
path = "src/main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-code"
|
||||
path = "src/code_main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-agent"
|
||||
path = "src/agent_main.rs"
|
||||
@@ -24,6 +29,7 @@ required-features = ["chat-client"]
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-providers = { path = "../owlen-providers" }
|
||||
# Optional TUI dependency, enabled by the "chat-client" feature.
|
||||
owlen-tui = { path = "../owlen-tui", optional = true }
|
||||
log = { workspace = true }
|
||||
|
||||
crates/owlen-cli/src/bootstrap.rs (new file, 326 lines)
@@ -0,0 +1,326 @@
use std::borrow::Cow;
use std::io;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use crossterm::{
    event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
    execute,
    terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
};
use futures::stream;
use owlen_core::{
    ChatStream, Error, Provider,
    config::{Config, McpMode},
    mcp::remote_client::RemoteMcpClient,
    mode::Mode,
    provider::ProviderManager,
    providers::OllamaProvider,
    session::{ControllerEvent, SessionController},
    storage::StorageManager,
    types::{ChatRequest, ChatResponse, Message, ModelInfo},
};
use owlen_tui::{
    ChatApp, SessionEvent,
    app::App as RuntimeApp,
    config,
    tui_controller::{TuiController, TuiRequest},
    ui,
};
use ratatui::{Terminal, prelude::CrosstermBackend};
use tokio::sync::mpsc;

use crate::commands::cloud::{load_runtime_credentials, set_env_var};

pub async fn launch(initial_mode: Mode) -> Result<()> {
    set_env_var("OWLEN_AUTO_CONSENT", "1");

    let color_support = detect_terminal_color_support();
    let mut cfg = config::try_load_config().unwrap_or_default();
    let _ = cfg.refresh_mcp_servers(None);

    if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
        let term_label = match &color_support {
            TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
            TerminalColorSupport::Full => Cow::from("current terminal"),
        };
        eprintln!(
            "Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
            term_label, BASIC_THEME_NAME, previous_theme
        );
    } else if let TerminalColorSupport::Limited { term } = &color_support {
        eprintln!(
            "Warning: terminal '{}' may not fully support 256-color themes.",
            term
        );
    }

    cfg.validate()?;
    let storage = Arc::new(StorageManager::new().await?);
    load_runtime_credentials(&mut cfg, storage.clone()).await?;

    let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
    let tui_controller = Arc::new(TuiController::new(tui_tx));

    let provider = build_provider(&cfg)?;
    let mut offline_notice: Option<String> = None;
    let provider = match provider.health_check().await {
        Ok(_) => provider,
        Err(err) => {
            let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
                && !cfg.effective_mcp_servers().is_empty()
            {
                "Ensure the configured MCP server is running and reachable."
            } else {
                "Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
            };
            let notice =
                format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
            eprintln!("{notice}");
            offline_notice = Some(notice.clone());
            let fallback_model = cfg
                .general
                .default_model
                .clone()
                .unwrap_or_else(|| "offline".to_string());
            Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
        }
    };

    let (controller_event_tx, controller_event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
    let controller = SessionController::new(
        provider,
        cfg,
        storage.clone(),
        tui_controller,
        false,
        Some(controller_event_tx),
    )
    .await?;
    let provider_manager = Arc::new(ProviderManager::default());
    let mut runtime = RuntimeApp::new(provider_manager);
    let (mut app, mut session_rx) = ChatApp::new(controller, controller_event_rx).await?;
    app.initialize_models().await?;
    if let Some(notice) = offline_notice.clone() {
        app.set_status_message(&notice);
        app.set_system_status(notice);
    }

    app.set_mode(initial_mode).await;

    enable_raw_mode()?;
    let mut stdout = io::stdout();
    execute!(
        stdout,
        EnterAlternateScreen,
        EnableMouseCapture,
        EnableBracketedPaste
    )?;
    let backend = CrosstermBackend::new(stdout);
    let mut terminal = Terminal::new(backend)?;

    let result = run_app(&mut terminal, &mut runtime, &mut app, &mut session_rx).await;

    config::save_config(&app.config())?;

    disable_raw_mode()?;
    execute!(
        terminal.backend_mut(),
        LeaveAlternateScreen,
        DisableMouseCapture,
        DisableBracketedPaste
    )?;
    terminal.show_cursor()?;

    if let Err(err) = result {
        println!("{err:?}");
    }

    Ok(())
}

fn build_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
    match cfg.mcp.mode {
        McpMode::RemotePreferred => {
            let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
                RemoteMcpClient::new_with_config(mcp_server)
            } else {
                RemoteMcpClient::new()
            };

            match remote_result {
                Ok(client) => Ok(Arc::new(client) as Arc<dyn Provider>),
                Err(err) if cfg.mcp.allow_fallback => {
                    log::warn!(
                        "Remote MCP client unavailable ({}); falling back to local provider.",
                        err
                    );
                    build_local_provider(cfg)
                }
                Err(err) => Err(anyhow!(err)),
            }
        }
        McpMode::RemoteOnly => {
            let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
                anyhow!("[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"")
            })?;
            let client = RemoteMcpClient::new_with_config(mcp_server)?;
            Ok(Arc::new(client) as Arc<dyn Provider>)
        }
        McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
        McpMode::Disabled => Err(anyhow!(
            "MCP mode 'disabled' is not supported by the owlen TUI"
        )),
    }
}

fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
    let provider_name = cfg.general.default_provider.clone();
    let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
        anyhow!(format!(
            "No provider configuration found for '{provider_name}' in [providers]"
        ))
    })?;

    match provider_cfg.provider_type.as_str() {
        "ollama" | "ollama_cloud" => {
            let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
            Ok(Arc::new(provider) as Arc<dyn Provider>)
        }
        other => Err(anyhow!(format!(
            "Provider type '{other}' is not supported in legacy/local MCP mode"
        ))),
    }
}

const BASIC_THEME_NAME: &str = "ansi_basic";

#[derive(Debug, Clone)]
enum TerminalColorSupport {
    Full,
    Limited { term: String },
}

fn detect_terminal_color_support() -> TerminalColorSupport {
    let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
    let colorterm = std::env::var("COLORTERM").unwrap_or_default();
    let term_lower = term.to_lowercase();
    let color_lower = colorterm.to_lowercase();

    let supports_extended = term_lower.contains("256color")
        || color_lower.contains("truecolor")
        || color_lower.contains("24bit")
        || color_lower.contains("fullcolor");

    if supports_extended {
        TerminalColorSupport::Full
    } else {
        TerminalColorSupport::Limited { term }
    }
}

fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
    match support {
        TerminalColorSupport::Full => None,
        TerminalColorSupport::Limited { .. } => {
            if cfg.ui.theme != BASIC_THEME_NAME {
                let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
                Some(previous)
            } else {
                None
            }
        }
    }
}

struct OfflineProvider {
    reason: String,
    placeholder_model: String,
}

impl OfflineProvider {
    fn new(reason: String, placeholder_model: String) -> Self {
        Self {
            reason,
            placeholder_model,
        }
    }

    fn friendly_response(&self, requested_model: &str) -> ChatResponse {
        let mut message = String::new();
        message.push_str("⚠️ Owlen is running in offline mode.\n\n");
        message.push_str(&self.reason);
        if !requested_model.is_empty() && requested_model != self.placeholder_model {
            message.push_str(&format!(
                "\n\nYou requested model '{}', but no providers are reachable.",
                requested_model
            ));
        }
        message.push_str(
            "\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
        );

        ChatResponse {
            message: Message::assistant(message),
            usage: None,
            is_streaming: false,
            is_final: true,
        }
    }
}

#[async_trait]
impl Provider for OfflineProvider {
    fn name(&self) -> &str {
        "offline"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
        Ok(vec![ModelInfo {
            id: self.placeholder_model.clone(),
            provider: "offline".to_string(),
            name: format!("Offline (fallback: {})", self.placeholder_model),
            description: Some("Placeholder model used while no providers are reachable".into()),
            context_window: None,
            capabilities: vec![],
            supports_tools: false,
        }])
    }

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
        Ok(self.friendly_response(&request.model))
    }

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
        let response = self.friendly_response(&request.model);
        Ok(Box::pin(stream::iter(vec![Ok(response)])))
    }

    async fn health_check(&self) -> Result<(), Error> {
        Err(Error::Provider(anyhow!(
            "offline provider cannot reach any backing models"
        )))
    }

    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
        self
    }
}

async fn run_app(
    terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
    runtime: &mut RuntimeApp,
    app: &mut ChatApp,
    session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
) -> Result<()> {
    let mut render = |terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
                      state: &mut ChatApp|
     -> Result<()> {
        terminal.draw(|f| ui::render_chat(f, state))?;
        Ok(())
    };

    runtime.run(terminal, app, session_rx, &mut render).await?;
    Ok(())
}
crates/owlen-cli/src/code_main.rs (new file, 16 lines)
@@ -0,0 +1,16 @@
//! Owlen CLI entrypoint optimised for code-first workflows.
#![allow(dead_code, unused_imports)]

mod bootstrap;
mod commands;
mod mcp;

use anyhow::Result;
use owlen_core::config as core_config;
use owlen_core::mode::Mode;
use owlen_tui::config;

#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<()> {
    bootstrap::launch(Mode::Code).await
}
@@ -1,17 +1,24 @@
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::{anyhow, bail, Context, Result};
use anyhow::{Context, Result, anyhow, bail};
use clap::Subcommand;
use owlen_core::config as core_config;
use owlen_core::config::Config;
use owlen_core::LlmProvider;
use owlen_core::ProviderConfig;
use owlen_core::config::{
    self as core_config, Config, OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
    OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
};
use owlen_core::credentials::{ApiCredentials, CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
use owlen_core::encryption;
use owlen_core::provider::{LLMProvider, ProviderConfig};
use owlen_core::providers::OllamaProvider;
use owlen_core::storage::StorageManager;
use serde_json::Value;

const DEFAULT_CLOUD_ENDPOINT: &str = "https://ollama.com";
const DEFAULT_CLOUD_ENDPOINT: &str = OLLAMA_CLOUD_BASE_URL;
const CLOUD_ENDPOINT_KEY: &str = OLLAMA_CLOUD_ENDPOINT_KEY;
const CLOUD_PROVIDER_KEY: &str = "ollama_cloud";

#[derive(Debug, Subcommand)]
pub enum CloudCommand {
@@ -23,26 +30,29 @@ pub enum CloudCommand {
        /// Override the cloud endpoint (default: https://ollama.com)
        #[arg(long)]
        endpoint: Option<String>,
        /// Provider name to configure (default: ollama)
        #[arg(long, default_value = "ollama")]
        /// Provider name to configure (default: ollama_cloud)
        #[arg(long, default_value = "ollama_cloud")]
        provider: String,
        /// Overwrite the provider base URL with the cloud endpoint
        #[arg(long)]
        force_cloud_base_url: bool,
    },
    /// Check connectivity to Ollama Cloud
    Status {
        /// Provider name to check (default: ollama)
        #[arg(long, default_value = "ollama")]
        /// Provider name to check (default: ollama_cloud)
        #[arg(long, default_value = "ollama_cloud")]
        provider: String,
    },
    /// List available cloud-hosted models
    Models {
        /// Provider name to query (default: ollama)
        #[arg(long, default_value = "ollama")]
        /// Provider name to query (default: ollama_cloud)
        #[arg(long, default_value = "ollama_cloud")]
        provider: String,
    },
    /// Remove stored Ollama Cloud credentials
    Logout {
        /// Provider name to clear (default: ollama)
        #[arg(long, default_value = "ollama")]
        /// Provider name to clear (default: ollama_cloud)
        #[arg(long, default_value = "ollama_cloud")]
        provider: String,
    },
}
@@ -53,19 +63,30 @@ pub async fn run_cloud_command(command: CloudCommand) -> Result<()> {
            api_key,
            endpoint,
            provider,
        } => setup(provider, api_key, endpoint).await,
            force_cloud_base_url,
        } => setup(provider, api_key, endpoint, force_cloud_base_url).await,
        CloudCommand::Status { provider } => status(provider).await,
        CloudCommand::Models { provider } => models(provider).await,
        CloudCommand::Logout { provider } => logout(provider).await,
    }
}

async fn setup(provider: String, api_key: Option<String>, endpoint: Option<String>) -> Result<()> {
async fn setup(
    provider: String,
    api_key: Option<String>,
    endpoint: Option<String>,
    force_cloud_base_url: bool,
) -> Result<()> {
    let provider = canonical_provider_name(&provider);
    let mut config = crate::config::try_load_config().unwrap_or_default();
    let endpoint = endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
    let endpoint =
        normalize_endpoint(&endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string()));

    ensure_provider_entry(&mut config, &provider, &endpoint);
    let base_changed = {
        let entry = ensure_provider_entry(&mut config, &provider);
        entry.enabled = true;
        configure_cloud_endpoint(entry, &endpoint, force_cloud_base_url)
    };

    let key = match api_key {
        Some(value) if !value.trim().is_empty() => value,
@@ -93,10 +114,6 @@ async fn setup(provider: String, api_key: Option<String>, endpoint: Option<Strin
        entry.api_key = Some(key.clone());
    }

    if let Some(entry) = config.providers.get_mut(&provider) {
        entry.base_url = Some(endpoint.clone());
    }

    crate::config::save_config(&config)?;
    println!("Saved Ollama configuration for provider '{provider}'.");
    if config.privacy.encrypt_local_data {
@@ -104,6 +121,12 @@ async fn setup(provider: String, api_key: Option<String>, endpoint: Option<Strin
    } else {
        println!("API key stored in plaintext configuration (encryption disabled).");
    }
    if !force_cloud_base_url && !base_changed {
        println!(
            "Local base URL preserved; cloud endpoint stored as {}.",
            CLOUD_ENDPOINT_KEY
        );
    }
    Ok(())
}
@@ -118,25 +141,32 @@ async fn status(provider: String) -> Result<()> {
    };

    let api_key = hydrate_api_key(&mut config, manager.as_ref()).await?;
    ensure_provider_entry(&mut config, &provider, DEFAULT_CLOUD_ENDPOINT);
    {
        let entry = ensure_provider_entry(&mut config, &provider);
        entry.enabled = true;
        configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
    }

    let provider_cfg = config
        .provider(&provider)
        .cloned()
        .ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;

    let ollama = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
    let endpoint =
        resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
    let mut runtime_cfg = provider_cfg.clone();
    runtime_cfg.base_url = Some(endpoint.clone());
    runtime_cfg.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("cloud".to_string()),
    );

    let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
        .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;

    match ollama.health_check().await {
        Ok(_) => {
            println!(
                "✓ Connected to {provider} ({})",
                provider_cfg
                    .base_url
                    .as_deref()
                    .unwrap_or(DEFAULT_CLOUD_ENDPOINT)
            );
            println!("✓ Connected to {provider} ({})", endpoint);
            if api_key.is_none() && config.privacy.encrypt_local_data {
                println!(
                    "Warning: No API key stored; connection succeeded via environment variables."
@@ -162,13 +192,27 @@ async fn models(provider: String) -> Result<()> {
    };
    hydrate_api_key(&mut config, manager.as_ref()).await?;

    ensure_provider_entry(&mut config, &provider, DEFAULT_CLOUD_ENDPOINT);
    {
        let entry = ensure_provider_entry(&mut config, &provider);
        entry.enabled = true;
        configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
    }

    let provider_cfg = config
        .provider(&provider)
        .cloned()
        .ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;

    let ollama = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
    let endpoint =
        resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
    let mut runtime_cfg = provider_cfg.clone();
    runtime_cfg.base_url = Some(endpoint);
    runtime_cfg.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("cloud".to_string()),
    );

    let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
        .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;

    match ollama.list_models().await {
@@ -206,8 +250,9 @@ async fn logout(provider: String) -> Result<()> {
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(entry) = provider_entry_mut(&mut config) {
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
entry.enabled = false;
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
@@ -215,62 +260,95 @@ async fn logout(provider: String) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_provider_entry(config: &mut Config, provider: &str, endpoint: &str) {
|
||||
if provider == "ollama"
|
||||
&& config.providers.contains_key("ollama-cloud")
|
||||
&& !config.providers.contains_key("ollama")
|
||||
{
|
||||
if let Some(mut legacy) = config.providers.remove("ollama-cloud") {
|
||||
legacy.provider_type = "ollama".to_string();
|
||||
config.providers.insert("ollama".to_string(), legacy);
|
||||
}
|
||||
fn ensure_provider_entry<'a>(config: &'a mut Config, provider: &str) -> &'a mut ProviderConfig {
|
||||
core_config::ensure_provider_config_mut(config, provider)
|
||||
}
|
||||
|
||||
fn configure_cloud_endpoint(entry: &mut ProviderConfig, endpoint: &str, force: bool) -> bool {
|
||||
let normalized = normalize_endpoint(endpoint);
|
||||
let previous_base = entry.base_url.clone();
|
||||
entry.extra.insert(
|
||||
CLOUD_ENDPOINT_KEY.to_string(),
|
||||
Value::String(normalized.clone()),
|
||||
);
|
||||
|
||||
if entry.api_key_env.is_none() {
|
||||
entry.api_key_env = Some(OLLAMA_CLOUD_API_KEY_ENV.to_string());
|
||||
}
|
||||
|
||||
core_config::ensure_provider_config(config, provider);
|
||||
if force
|
||||
|| entry
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
entry.base_url = Some(normalized.clone());
|
||||
}
|
||||
|
||||
if let Some(cfg) = config.providers.get_mut(provider) {
|
||||
if cfg.provider_type != "ollama" {
|
||||
cfg.provider_type = "ollama".to_string();
|
||||
}
|
||||
if cfg.base_url.is_none() {
|
||||
cfg.base_url = Some(endpoint.to_string());
|
||||
}
|
||||
if force {
|
||||
entry.enabled = true;
|
||||
}
|
||||
|
||||
entry.base_url != previous_base
|
||||
}
|
||||
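The overwrite rule in `configure_cloud_endpoint` is easy to misread in diff form: the stored `base_url` is replaced only when the caller forces it or when the existing value is missing or blank. A hedged sketch of just that predicate follows; the helper name and the test are hypothetical and simply restate the condition shown above.

/// Hypothetical extraction of the overwrite condition used above.
fn should_overwrite_base_url(force: bool, base_url: Option<&str>) -> bool {
    force
        || base_url
            .map(|value| value.trim().is_empty())
            .unwrap_or(true)
}

#[test]
fn base_url_is_preserved_unless_forced_or_blank() {
    assert!(should_overwrite_base_url(false, None));
    assert!(should_overwrite_base_url(false, Some("   ")));
    assert!(!should_overwrite_base_url(false, Some("http://localhost:11434")));
    assert!(should_overwrite_base_url(true, Some("http://localhost:11434")));
}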
|
||||
fn resolve_cloud_endpoint(cfg: &ProviderConfig) -> Option<String> {
|
||||
if let Some(value) = cfg
|
||||
.extra
|
||||
.get(CLOUD_ENDPOINT_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.map(normalize_endpoint)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
|
||||
cfg.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim_end_matches('/').to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
fn normalize_endpoint(endpoint: &str) -> String {
    let trimmed = endpoint.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        DEFAULT_CLOUD_ENDPOINT.to_string()
    } else {
        trimmed.to_string()
    }
}
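Since `normalize_endpoint` is a pure helper, its contract can be pinned down with a couple of assertions. A minimal sketch, assuming it is exercised from a test module in this file; the module and test names are invented, and the second assertion deliberately avoids hard-coding the value of `DEFAULT_CLOUD_ENDPOINT`.

#[cfg(test)]
mod normalize_endpoint_sketch {
    use super::*;

    #[test]
    fn strips_whitespace_and_trailing_slashes() {
        // Surrounding whitespace and trailing slashes are removed.
        assert_eq!(normalize_endpoint(" https://ollama.com/ "), "https://ollama.com");
        // Blank input falls back to the default cloud endpoint.
        assert_eq!(normalize_endpoint("   "), DEFAULT_CLOUD_ENDPOINT);
    }
}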
|
||||
fn canonical_provider_name(provider: &str) -> String {
|
||||
let normalized = provider.trim().replace('_', "-").to_ascii_lowercase();
|
||||
let normalized = provider.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => "ollama".to_string(),
|
||||
"ollama-cloud" => "ollama".to_string(),
|
||||
"" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama_cloud" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
value => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_env_var<K, V>(key: K, value: V)
|
||||
where
|
||||
K: AsRef<OsStr>,
|
||||
V: AsRef<OsStr>,
|
||||
{
|
||||
// Safety: the CLI updates process-wide environment variables during startup while no
|
||||
// other threads are mutating the environment.
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_env_if_missing(var: &str, value: &str) {
    if std::env::var(var)
        .map(|v| v.trim().is_empty())
        .unwrap_or(true)
    {
        std::env::set_var(var, value);
        set_env_var(var, value);
    }
}
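The guard in `set_env_if_missing` treats a present-but-blank variable the same as an absent one. A small hedged illustration; the variable name is invented for the example, and the write goes through the crate's `set_env_var` wrapper shown earlier.

#[test]
fn blank_values_are_replaced_but_real_values_survive() {
    set_env_var("OWLEN_EXAMPLE_VAR", "   ");
    set_env_if_missing("OWLEN_EXAMPLE_VAR", "fallback");
    // Whitespace-only values count as missing, so the fallback wins.
    assert_eq!(std::env::var("OWLEN_EXAMPLE_VAR").unwrap(), "fallback");

    set_env_if_missing("OWLEN_EXAMPLE_VAR", "ignored");
    // A real value is left untouched on subsequent calls.
    assert_eq!(std::env::var("OWLEN_EXAMPLE_VAR").unwrap(), "fallback");
}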
|
||||
fn provider_entry_mut(config: &mut Config) -> Option<&mut ProviderConfig> {
|
||||
if config.providers.contains_key("ollama") {
|
||||
config.providers.get_mut("ollama")
|
||||
} else {
|
||||
config.providers.get_mut("ollama-cloud")
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_entry(config: &Config) -> Option<&ProviderConfig> {
|
||||
if let Some(entry) = config.providers.get("ollama") {
|
||||
return Some(entry);
|
||||
}
|
||||
config.providers.get("ollama-cloud")
|
||||
}
|
||||
|
||||
fn unlock_credential_manager(
|
||||
config: &Config,
|
||||
storage: Arc<StorageManager>,
|
||||
@@ -302,18 +380,20 @@ fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||
use std::env;
|
||||
|
||||
if path.exists() {
|
||||
if let Ok(password) = env::var("OWLEN_MASTER_PASSWORD") {
|
||||
if !password.trim().is_empty() {
|
||||
return encryption::unlock_with_password(path.to_path_buf(), &password)
|
||||
.context("Failed to unlock vault with OWLEN_MASTER_PASSWORD");
|
||||
}
|
||||
if let Some(password) = env::var("OWLEN_MASTER_PASSWORD")
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|password| !password.is_empty())
|
||||
{
|
||||
return encryption::unlock_with_password(path.to_path_buf(), &password)
|
||||
.context("Failed to unlock vault with OWLEN_MASTER_PASSWORD");
|
||||
}
|
||||
|
||||
for attempt in 0..3 {
|
||||
let password = encryption::prompt_password("Enter master password: ")?;
|
||||
match encryption::unlock_with_password(path.to_path_buf(), &password) {
|
||||
Ok(handle) => {
|
||||
env::set_var("OWLEN_MASTER_PASSWORD", password);
|
||||
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||
return Ok(handle);
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -334,7 +414,7 @@ fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||
.unwrap_or(true)
|
||||
{
|
||||
let password = encryption::prompt_password("Cache master password for this session: ")?;
|
||||
env::set_var("OWLEN_MASTER_PASSWORD", password);
|
||||
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||
}
|
||||
Ok(handle)
|
||||
}
|
||||
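The hunk above changes how `OWLEN_MASTER_PASSWORD` is read: a set-but-blank variable now falls through to the interactive prompt instead of being tried as a password. A hedged sketch of just that filter, with the vault machinery factored away; the helper and test are hypothetical.

fn env_master_password(raw: Option<&str>) -> Option<String> {
    // Mirrors the filter in unlock_vault: whitespace-only values are treated as unset.
    raw.map(|value| value.trim().to_string())
        .filter(|password| !password.is_empty())
}

#[test]
fn blank_passwords_fall_through_to_the_prompt() {
    assert_eq!(env_master_password(Some("   ")), None);
    assert_eq!(env_master_password(None), None);
    assert_eq!(
        env_master_password(Some(" hunter2 ")),
        Some("hunter2".to_string())
    );
}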
@@ -343,34 +423,32 @@ async fn hydrate_api_key(
|
||||
config: &mut Config,
|
||||
manager: Option<&Arc<CredentialManager>>,
|
||||
) -> Result<Option<String>> {
|
||||
if let Some(manager) = manager {
|
||||
if let Some(credentials) = manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await? {
|
||||
let key = credentials.api_key.trim().to_string();
|
||||
if !key.is_empty() {
|
||||
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||
}
|
||||
let credentials = match manager {
|
||||
Some(manager) => manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?,
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(cfg) = provider_entry_mut(config) {
|
||||
if cfg.base_url.is_none() && !credentials.endpoint.trim().is_empty() {
|
||||
cfg.base_url = Some(credentials.endpoint);
|
||||
}
|
||||
}
|
||||
return Ok(Some(key));
|
||||
if let Some(credentials) = credentials {
|
||||
let key = credentials.api_key.trim().to_string();
|
||||
if !key.is_empty() {
|
||||
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||
}
|
||||
|
||||
let cfg = core_config::ensure_provider_config_mut(config, CLOUD_PROVIDER_KEY);
|
||||
configure_cloud_endpoint(cfg, &credentials.endpoint, false);
|
||||
return Ok(Some(key));
|
||||
}
|
||||
|
||||
if let Some(cfg) = provider_entry(config) {
|
||||
if let Some(key) = cfg
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||
return Ok(Some(key.to_string()));
|
||||
}
|
||||
if let Some(key) = config
|
||||
.provider(CLOUD_PROVIDER_KEY)
|
||||
.and_then(|cfg| cfg.api_key.as_ref())
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||
return Ok(Some(key.to_string()));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
@@ -394,8 +472,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn canonicalises_provider_names() {
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), "ollama");
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), "ollama");
|
||||
assert_eq!(canonical_provider_name(""), "ollama");
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(""), CLOUD_PROVIDER_KEY);
|
||||
}
|
||||
}
|
||||
4
crates/owlen-cli/src/commands/mod.rs
Normal file
@@ -0,0 +1,4 @@
//! Command implementations for the `owlen` CLI.

pub mod cloud;
pub mod providers;
651
crates/owlen-cli/src/commands/providers.rs
Normal file
@@ -0,0 +1,651 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand};
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{self as core_config, Config};
|
||||
use owlen_core::provider::{
|
||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::storage::StorageManager;
|
||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
use super::cloud;
|
||||
|
||||
/// CLI subcommands for provider management.
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum ProvidersCommand {
|
||||
/// List configured providers and their metadata.
|
||||
List,
|
||||
/// Run health checks against providers.
|
||||
Status {
|
||||
/// Optional provider identifier to check.
|
||||
#[arg(value_name = "PROVIDER")]
|
||||
provider: Option<String>,
|
||||
},
|
||||
/// Enable a provider in the configuration.
|
||||
Enable {
|
||||
/// Provider identifier to enable.
|
||||
provider: String,
|
||||
},
|
||||
/// Disable a provider in the configuration.
|
||||
Disable {
|
||||
/// Provider identifier to disable.
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Arguments for the `owlen models` command.
|
||||
#[derive(Debug, Default, Args)]
|
||||
pub struct ModelsArgs {
|
||||
/// Restrict output to a specific provider.
|
||||
#[arg(long)]
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn run_providers_command(command: ProvidersCommand) -> Result<()> {
|
||||
match command {
|
||||
ProvidersCommand::List => list_providers(),
|
||||
ProvidersCommand::Status { provider } => status_providers(provider.as_deref()).await,
|
||||
ProvidersCommand::Enable { provider } => toggle_provider(&provider, true),
|
||||
ProvidersCommand::Disable { provider } => toggle_provider(&provider, false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_models_command(args: ModelsArgs) -> Result<()> {
|
||||
list_models(args.provider.as_deref()).await
|
||||
}
|
||||
|
||||
fn list_providers() -> Result<()> {
|
||||
let config = tui_config::try_load_config().unwrap_or_default();
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for (id, cfg) in &config.providers {
|
||||
let type_label = describe_provider_type(id, cfg);
|
||||
let auth_label = describe_auth(cfg, requires_auth(id, cfg));
|
||||
let enabled = if cfg.enabled { "yes" } else { "no" };
|
||||
let default = if id == &default_provider { "*" } else { "" };
|
||||
let base = cfg
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().to_string())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
|
||||
rows.push(ProviderListRow {
|
||||
id: id.to_string(),
|
||||
type_label,
|
||||
enabled: enabled.to_string(),
|
||||
default: default.to_string(),
|
||||
auth: auth_label,
|
||||
base_url: base,
|
||||
});
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let enabled_width = rows
|
||||
.iter()
|
||||
.map(|row| row.enabled.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Enabled".len());
|
||||
let default_width = rows
|
||||
.iter()
|
||||
.map(|row| row.default.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Default".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.type_label.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let auth_width = rows
|
||||
.iter()
|
||||
.map(|row| row.auth.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Auth".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} Base URL",
|
||||
"Provider",
|
||||
"Enabled",
|
||||
"Default",
|
||||
"Type",
|
||||
"Auth",
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} {}",
|
||||
row.id,
|
||||
row.enabled,
|
||||
row.default,
|
||||
row.type_label,
|
||||
row.auth,
|
||||
row.base_url,
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status_providers(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let health = manager.refresh_health().await;
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for record in records {
|
||||
let status = health.get(&record.id).copied();
|
||||
rows.push(ProviderStatusRow::from_record(record, status));
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
print_status_rows(&rows);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_models(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let models = manager
|
||||
.list_all_models()
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let statuses = manager.provider_statuses().await;
|
||||
|
||||
print_models(records, models, statuses);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_provider_filter(config: &Config, filter: Option<&str>) -> Result<()> {
|
||||
if let Some(filter) = filter
|
||||
&& !config.providers.contains_key(filter)
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Provider '{}' is not defined in configuration.",
|
||||
filter
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn toggle_provider(provider: &str, enable: bool) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let canonical = canonical_provider_id(provider);
|
||||
if canonical.is_empty() {
|
||||
return Err(anyhow!("Provider name cannot be empty."));
|
||||
}
|
||||
|
||||
let previous_default = config.general.default_provider.clone();
|
||||
let previous_fallback_enabled = config.providers.get("ollama_local").map(|cfg| cfg.enabled);
|
||||
|
||||
let previous_enabled;
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
previous_enabled = entry.enabled;
|
||||
if previous_enabled == enable {
|
||||
println!(
|
||||
"Provider '{}' is already {}.",
|
||||
canonical,
|
||||
if enable { "enabled" } else { "disabled" }
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
entry.enabled = enable;
|
||||
}
|
||||
|
||||
if !enable && config.general.default_provider == canonical {
|
||||
if let Some(candidate) = choose_fallback_provider(&config, &canonical) {
|
||||
config.general.default_provider = candidate.clone();
|
||||
println!(
|
||||
"Default provider set to '{}' because '{}' was disabled.",
|
||||
candidate, canonical
|
||||
);
|
||||
} else {
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
entry.enabled = true;
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
println!(
|
||||
"Enabled 'ollama_local' and made it default because no other providers are active."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = config.validate() {
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
entry.enabled = previous_enabled;
|
||||
}
|
||||
config.general.default_provider = previous_default;
|
||||
if let Some(enabled) = previous_fallback_enabled
|
||||
&& let Some(entry) = config.providers.get_mut("ollama_local")
|
||||
{
|
||||
entry.enabled = enabled;
|
||||
}
|
||||
return Err(anyhow!(err));
|
||||
}
|
||||
|
||||
tui_config::save_config(&config).map_err(|err| anyhow!(err))?;
|
||||
|
||||
println!(
|
||||
"{} provider '{}'.",
|
||||
if enable { "Enabled" } else { "Disabled" },
|
||||
canonical
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn choose_fallback_provider(config: &Config, exclude: &str) -> Option<String> {
    if exclude != "ollama_local"
        && let Some(cfg) = config.providers.get("ollama_local")
        && cfg.enabled
    {
        return Some("ollama_local".to_string());
    }

    let mut candidates: Vec<String> = config
        .providers
        .iter()
        .filter(|(id, cfg)| cfg.enabled && id.as_str() != exclude)
        .map(|(id, _)| id.clone())
        .collect();
    candidates.sort();
    candidates.into_iter().next()
}
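The selection rule here is: prefer `ollama_local` when it is enabled and is not the provider being disabled, otherwise take the lexicographically smallest remaining enabled provider. A hedged restatement over a plain list of enabled ids; the helper and test are illustrative only, the real function works on `Config`.

fn fallback_from(enabled: &[&str], exclude: &str) -> Option<String> {
    if exclude != "ollama_local" && enabled.contains(&"ollama_local") {
        return Some("ollama_local".to_string());
    }
    let mut candidates: Vec<String> = enabled
        .iter()
        .filter(|id| **id != exclude)
        .map(|id| id.to_string())
        .collect();
    candidates.sort();
    candidates.into_iter().next()
}

#[test]
fn prefers_local_then_smallest_id() {
    assert_eq!(
        fallback_from(&["anthropic", "ollama_local"], "ollama_cloud"),
        Some("ollama_local".to_string())
    );
    assert_eq!(
        fallback_from(&["openai", "anthropic"], "ollama_cloud"),
        Some("anthropic".to_string())
    );
    assert_eq!(fallback_from(&[], "ollama_cloud"), None);
}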
|
||||
async fn register_enabled_providers(
|
||||
manager: &ProviderManager,
|
||||
config: &Config,
|
||||
filter: Option<&str>,
|
||||
) -> Result<Vec<ProviderRecord>> {
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
let mut records = Vec::new();
|
||||
|
||||
for (id, cfg) in &config.providers {
|
||||
if let Some(filter) = filter
|
||||
&& id != filter
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut record = ProviderRecord::from_config(id, cfg, id == &default_provider);
|
||||
if !cfg.enabled {
|
||||
records.push(record);
|
||||
continue;
|
||||
}
|
||||
|
||||
match instantiate_provider(id, cfg) {
|
||||
Ok(provider) => {
|
||||
let metadata = provider.metadata().clone();
|
||||
record.provider_type_label = provider_type_label(metadata.provider_type);
|
||||
record.requires_auth = metadata.requires_auth;
|
||||
record.metadata = Some(metadata);
|
||||
manager.register_provider(provider).await;
|
||||
}
|
||||
Err(err) => {
|
||||
record.registration_error = Some(err.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
records.push(record);
|
||||
}
|
||||
|
||||
records.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn instantiate_provider(id: &str, cfg: &ProviderConfig) -> Result<Arc<dyn ModelProvider>> {
|
||||
let kind = cfg.provider_type.trim().to_ascii_lowercase();
|
||||
if kind == "ollama" || id == "ollama_local" {
|
||||
let provider = OllamaLocalProvider::new(cfg.base_url.clone(), None, None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else if kind == "ollama_cloud" || id == "ollama_cloud" {
|
||||
let provider = OllamaCloudProvider::new(cfg.base_url.clone(), cfg.api_key.clone(), None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Provider '{}' uses unsupported type '{}'.",
|
||||
id,
|
||||
if kind.is_empty() {
|
||||
"unknown"
|
||||
} else {
|
||||
kind.as_str()
|
||||
}
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn describe_provider_type(id: &str, cfg: &ProviderConfig) -> String {
|
||||
if cfg.provider_type.trim().eq_ignore_ascii_case("ollama") || id.ends_with("_local") {
|
||||
"Local".to_string()
|
||||
} else if cfg
|
||||
.provider_type
|
||||
.trim()
|
||||
.eq_ignore_ascii_case("ollama_cloud")
|
||||
|| id.contains("cloud")
|
||||
{
|
||||
"Cloud".to_string()
|
||||
} else {
|
||||
"Custom".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_auth(id: &str, cfg: &ProviderConfig) -> bool {
|
||||
cfg.api_key.is_some()
|
||||
|| cfg.api_key_env.is_some()
|
||||
|| matches!(id, "ollama_cloud" | "openai" | "anthropic")
|
||||
}
|
||||
|
||||
fn describe_auth(cfg: &ProviderConfig, required: bool) -> String {
|
||||
if let Some(env) = cfg
|
||||
.api_key_env
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
format!("env:{env}")
|
||||
} else if cfg
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map(|value| !value.trim().is_empty())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
"config".to_string()
|
||||
} else if required {
|
||||
"required".to_string()
|
||||
} else {
|
||||
"-".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_provider_id(raw: &str) -> String {
    let trimmed = raw.trim().to_ascii_lowercase();
    if trimmed.is_empty() {
        return trimmed;
    }

    match trimmed.as_str() {
        "ollama" | "ollama-local" => "ollama_local".to_string(),
        "ollama_cloud" | "ollama-cloud" => "ollama_cloud".to_string(),
        other => other.replace('-', "_"),
    }
}
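`canonical_provider_id` is the normalisation used throughout these commands: lower-case, hyphens to underscores, with the bare `ollama` alias mapping to the local provider. A short hedged sketch of the expected mappings; the test name is invented.

#[test]
fn maps_aliases_to_canonical_ids() {
    assert_eq!(canonical_provider_id("Ollama"), "ollama_local");
    assert_eq!(canonical_provider_id("ollama-local"), "ollama_local");
    assert_eq!(canonical_provider_id("ollama-cloud"), "ollama_cloud");
    assert_eq!(canonical_provider_id("My-Provider"), "my_provider");
    assert_eq!(canonical_provider_id("   "), "");
}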
|
||||
fn provider_type_label(provider_type: ProviderType) -> String {
|
||||
match provider_type {
|
||||
ProviderType::Local => "Local".to_string(),
|
||||
ProviderType::Cloud => "Cloud".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_status_strings(status: ProviderStatus) -> (&'static str, &'static str) {
|
||||
match status {
|
||||
ProviderStatus::Available => ("OK", "available"),
|
||||
ProviderStatus::Unavailable => ("ERR", "unavailable"),
|
||||
ProviderStatus::RequiresSetup => ("SETUP", "requires setup"),
|
||||
}
|
||||
}
|
||||
|
||||
fn print_status_rows(rows: &[ProviderStatusRow]) {
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.provider_type.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let status_width = rows
|
||||
.iter()
|
||||
.map(|row| row.indicator.len() + 1 + row.status_label.len())
|
||||
.max()
|
||||
.unwrap_or(6)
|
||||
.max("State".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} Details",
|
||||
"Provider",
|
||||
"Def",
|
||||
"Type",
|
||||
"State",
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
let def = if row.default_provider { "*" } else { "-" };
|
||||
let details = row.detail.as_deref().unwrap_or("-");
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} {}",
|
||||
row.id,
|
||||
def,
|
||||
row.provider_type,
|
||||
format!("{} {}", row.indicator, row.status_label),
|
||||
details,
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_models(
|
||||
records: Vec<ProviderRecord>,
|
||||
models: Vec<AnnotatedModelInfo>,
|
||||
statuses: HashMap<String, ProviderStatus>,
|
||||
) {
|
||||
let mut grouped: HashMap<String, Vec<AnnotatedModelInfo>> = HashMap::new();
|
||||
for info in models {
|
||||
grouped
|
||||
.entry(info.provider_id.clone())
|
||||
.or_default()
|
||||
.push(info);
|
||||
}
|
||||
|
||||
for record in records {
|
||||
let status = statuses.get(&record.id).copied().or_else(|| {
|
||||
if record.metadata.is_some() && record.registration_error.is_none() && record.enabled {
|
||||
Some(ProviderStatus::Unavailable)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let (indicator, label, status_value) = if !record.enabled {
|
||||
("-", "disabled", None)
|
||||
} else if record.registration_error.is_some() {
|
||||
("ERR", "error", None)
|
||||
} else if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
(indicator, label, Some(status))
|
||||
} else {
|
||||
("?", "unknown", None)
|
||||
};
|
||||
|
||||
let title = if record.default_provider {
|
||||
format!("{} (default)", record.id)
|
||||
} else {
|
||||
record.id.clone()
|
||||
};
|
||||
println!(
|
||||
"{} {} [{}] {}",
|
||||
indicator, title, record.provider_type_label, label
|
||||
);
|
||||
|
||||
if let Some(err) = &record.registration_error {
|
||||
println!(" error: {}", err);
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if !record.enabled {
|
||||
println!(" provider disabled");
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(entries) = grouped.get(&record.id) {
|
||||
let mut entries = entries.clone();
|
||||
entries.sort_by(|a, b| a.model.name.cmp(&b.model.name));
|
||||
if entries.is_empty() {
|
||||
println!(" (no models reported)");
|
||||
} else {
|
||||
for entry in entries {
|
||||
let mut line = format!(" - {}", entry.model.name);
|
||||
if let Some(description) = &entry.model.description
|
||||
&& !description.trim().is_empty()
|
||||
{
|
||||
line.push_str(&format!(" — {}", description.trim()));
|
||||
}
|
||||
println!("{}", line);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" (no models reported)");
|
||||
}
|
||||
|
||||
if let Some(ProviderStatus::RequiresSetup) = status_value
|
||||
&& record.requires_auth
|
||||
{
|
||||
println!(" configure provider credentials or API key");
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderListRow {
|
||||
id: String,
|
||||
type_label: String,
|
||||
enabled: String,
|
||||
default: String,
|
||||
auth: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
struct ProviderRecord {
|
||||
id: String,
|
||||
enabled: bool,
|
||||
default_provider: bool,
|
||||
provider_type_label: String,
|
||||
requires_auth: bool,
|
||||
registration_error: Option<String>,
|
||||
metadata: Option<owlen_core::provider::ProviderMetadata>,
|
||||
}
|
||||
|
||||
impl ProviderRecord {
|
||||
fn from_config(id: &str, cfg: &ProviderConfig, default_provider: bool) -> Self {
|
||||
Self {
|
||||
id: id.to_string(),
|
||||
enabled: cfg.enabled,
|
||||
default_provider,
|
||||
provider_type_label: describe_provider_type(id, cfg),
|
||||
requires_auth: requires_auth(id, cfg),
|
||||
registration_error: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderStatusRow {
|
||||
id: String,
|
||||
provider_type: String,
|
||||
default_provider: bool,
|
||||
indicator: String,
|
||||
status_label: String,
|
||||
detail: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderStatusRow {
|
||||
fn from_record(record: ProviderRecord, status: Option<ProviderStatus>) -> Self {
|
||||
if !record.enabled {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "-".to_string(),
|
||||
status_label: "disabled".to_string(),
|
||||
detail: None,
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(err) = record.registration_error {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "ERR".to_string(),
|
||||
status_label: "error".to_string(),
|
||||
detail: Some(err),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: indicator.to_string(),
|
||||
status_label: label.to_string(),
|
||||
detail: if matches!(status, ProviderStatus::RequiresSetup) && record.requires_auth {
|
||||
Some("credentials required".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "?".to_string(),
|
||||
status_label: "unknown".to_string(),
|
||||
detail: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,38 +1,22 @@
|
||||
#![allow(clippy::collapsible_if)] // TODO: Remove once Rust 2024 let-chains are available
|
||||
|
||||
//! OWLEN CLI - Chat TUI client
|
||||
|
||||
mod cloud;
|
||||
mod bootstrap;
|
||||
mod commands;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use cloud::{load_runtime_credentials, CloudCommand};
|
||||
use commands::{
|
||||
cloud::{CloudCommand, run_cloud_command},
|
||||
providers::{ModelsArgs, ProvidersCommand, run_models_command, run_providers_command},
|
||||
};
|
||||
use mcp::{McpCommand, run_mcp_command};
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::{
|
||||
config::{Config, McpMode},
|
||||
mcp::remote_client::RemoteMcpClient,
|
||||
mode::Mode,
|
||||
provider::ChatStream,
|
||||
providers::OllamaProvider,
|
||||
session::SessionController,
|
||||
storage::StorageManager,
|
||||
types::{ChatRequest, ChatResponse, Message, ModelInfo},
|
||||
Error, Provider,
|
||||
};
|
||||
use owlen_tui::tui_controller::{TuiController, TuiRequest};
|
||||
use owlen_tui::{config, ui, AppState, ChatApp, Event, EventHandler, SessionEvent};
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crossterm::{
|
||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
||||
execute,
|
||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
};
|
||||
use futures::stream;
|
||||
use ratatui::{prelude::CrosstermBackend, Terminal};
|
||||
use owlen_core::config::McpMode;
|
||||
use owlen_core::mode::Mode;
|
||||
use owlen_tui::config;
|
||||
|
||||
/// Owlen - Terminal UI for LLM chat
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -54,6 +38,14 @@ enum OwlenCommand {
|
||||
/// Manage Ollama Cloud credentials
|
||||
#[command(subcommand)]
|
||||
Cloud(CloudCommand),
|
||||
/// Manage model providers
|
||||
#[command(subcommand)]
|
||||
Providers(ProvidersCommand),
|
||||
/// List models exposed by configured providers
|
||||
Models(ModelsArgs),
|
||||
/// Manage MCP server registrations
|
||||
#[command(subcommand)]
|
||||
Mcp(McpCommand),
|
||||
/// Show manual steps for updating Owlen to the latest revision
|
||||
Upgrade,
|
||||
}
|
||||
@@ -66,72 +58,17 @@ enum ConfigCommand {
|
||||
Path,
|
||||
}
|
||||
|
||||
fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.mcp_servers.first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server)
|
||||
} else {
|
||||
RemoteMcpClient::new()
|
||||
};
|
||||
|
||||
match remote_result {
|
||||
Ok(client) => {
|
||||
let provider: Arc<dyn Provider> = Arc::new(client);
|
||||
Ok(provider)
|
||||
}
|
||||
Err(err) if cfg.mcp.allow_fallback => {
|
||||
log::warn!(
|
||||
"Remote MCP client unavailable ({}); falling back to local provider.",
|
||||
err
|
||||
);
|
||||
build_local_provider(cfg)
|
||||
}
|
||||
Err(err) => Err(anyhow::Error::from(err)),
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.mcp_servers.first().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\""
|
||||
)
|
||||
})?;
|
||||
let client = RemoteMcpClient::new_with_config(mcp_server)?;
|
||||
let provider: Arc<dyn Provider> = Arc::new(client);
|
||||
Ok(provider)
|
||||
}
|
||||
McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
|
||||
McpMode::Disabled => Err(anyhow::anyhow!(
|
||||
"MCP mode 'disabled' is not supported by the owlen TUI"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_local_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
let provider_name = cfg.general.default_provider.clone();
|
||||
let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
|
||||
anyhow::anyhow!(format!(
|
||||
"No provider configuration found for '{provider_name}' in [providers]"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama-cloud" => {
|
||||
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(anyhow::anyhow!(format!(
|
||||
"Provider type '{other}' is not supported in legacy/local MCP mode"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_command(command: OwlenCommand) -> Result<()> {
|
||||
match command {
|
||||
OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
|
||||
OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await,
|
||||
OwlenCommand::Cloud(cloud_cmd) => run_cloud_command(cloud_cmd).await,
|
||||
OwlenCommand::Providers(provider_cmd) => run_providers_command(provider_cmd).await,
|
||||
OwlenCommand::Models(args) => run_models_command(args).await,
|
||||
OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
|
||||
OwlenCommand::Upgrade => {
|
||||
println!("To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force");
|
||||
println!(
|
||||
"To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
|
||||
);
|
||||
println!(
|
||||
"If you installed from the AUR, use your package manager (e.g., yay -S owlen-git)."
|
||||
);
|
||||
@@ -155,46 +92,85 @@ fn run_config_doctor() -> Result<()> {
|
||||
let config_path = core_config::default_config_path();
|
||||
let existed = config_path.exists();
|
||||
let mut config = config::try_load_config().unwrap_or_default();
|
||||
let _ = config.refresh_mcp_servers(None);
|
||||
let mut changes = Vec::new();
|
||||
|
||||
if !existed {
|
||||
changes.push("created configuration file from defaults".to_string());
|
||||
}
|
||||
|
||||
if !config
|
||||
.providers
|
||||
.contains_key(&config.general.default_provider)
|
||||
{
|
||||
config.general.default_provider = "ollama".to_string();
|
||||
changes.push("default provider missing; reset to 'ollama'".to_string());
|
||||
if config.provider(&config.general.default_provider).is_none() {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push("default provider missing; reset to 'ollama_local'".to_string());
|
||||
}
|
||||
|
||||
if let Some(mut legacy) = config.providers.remove("ollama-cloud") {
|
||||
legacy.provider_type = "ollama".to_string();
|
||||
use std::collections::hash_map::Entry;
|
||||
match config.providers.entry("ollama".to_string()) {
|
||||
Entry::Occupied(mut existing) => {
|
||||
let entry = existing.get_mut();
|
||||
if entry.api_key.is_none() {
|
||||
entry.api_key = legacy.api_key.take();
|
||||
for key in ["ollama_local", "ollama_cloud", "openai", "anthropic"] {
|
||||
if !config.providers.contains_key(key) {
|
||||
core_config::ensure_provider_config_mut(&mut config, key);
|
||||
changes.push(format!("added default configuration for provider '{key}'"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(entry) = config.providers.get_mut("ollama_local") {
|
||||
if entry.provider_type.trim().is_empty() || entry.provider_type != "ollama" {
|
||||
entry.provider_type = "ollama".to_string();
|
||||
changes.push("normalised providers.ollama_local.provider_type to 'ollama'".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let mut ensure_default_enabled = true;
|
||||
|
||||
if !config.providers.values().any(|cfg| cfg.enabled) {
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
if !entry.enabled {
|
||||
entry.enabled = true;
|
||||
changes.push("no providers were enabled; enabled 'ollama_local'".to_string());
|
||||
}
|
||||
if config.general.default_provider != "ollama_local" {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push(
|
||||
"default provider reset to 'ollama_local' because no providers were enabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
ensure_default_enabled = false;
|
||||
}
|
||||
|
||||
if ensure_default_enabled {
|
||||
let default_id = config.general.default_provider.clone();
|
||||
if let Some(default_cfg) = config.providers.get(&default_id) {
|
||||
if !default_cfg.enabled {
|
||||
if let Some(new_default) = config
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(id, cfg)| cfg.enabled && *id != &default_id)
|
||||
.map(|(id, _)| id.clone())
|
||||
.min()
|
||||
{
|
||||
config.general.default_provider = new_default.clone();
|
||||
changes.push(format!(
|
||||
"default provider '{default_id}' was disabled; switched default to '{new_default}'"
|
||||
));
|
||||
} else {
|
||||
let entry =
|
||||
core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
if !entry.enabled {
|
||||
entry.enabled = true;
|
||||
changes.push(
|
||||
"enabled 'ollama_local' because default provider was disabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
if config.general.default_provider != "ollama_local" {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push(
|
||||
"default provider reset to 'ollama_local' because previous default was disabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
if entry.base_url.is_none() && legacy.base_url.is_some() {
|
||||
entry.base_url = legacy.base_url.take();
|
||||
}
|
||||
entry.extra.extend(legacy.extra);
|
||||
}
|
||||
Entry::Vacant(slot) => {
|
||||
slot.insert(legacy);
|
||||
}
|
||||
}
|
||||
changes.push(
|
||||
"migrated legacy 'ollama-cloud' provider into unified 'ollama' entry".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
if !config.providers.contains_key("ollama") {
|
||||
core_config::ensure_provider_config(&mut config, "ollama");
|
||||
changes.push("added default ollama provider configuration".to_string());
|
||||
}
|
||||
|
||||
match config.mcp.mode {
|
||||
@@ -203,7 +179,7 @@ fn run_config_doctor() -> Result<()> {
|
||||
config.mcp.warn_on_legacy = true;
|
||||
changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
|
||||
}
|
||||
McpMode::RemoteOnly if config.mcp_servers.is_empty() => {
|
||||
McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
|
||||
config.mcp.mode = McpMode::RemotePreferred;
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
@@ -211,7 +187,9 @@ fn run_config_doctor() -> Result<()> {
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
McpMode::RemotePreferred if !config.mcp.allow_fallback && config.mcp_servers.is_empty() => {
|
||||
McpMode::RemotePreferred
|
||||
if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
|
||||
{
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
"enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
|
||||
@@ -238,116 +216,6 @@ fn run_config_doctor() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const BASIC_THEME_NAME: &str = "ansi_basic";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum TerminalColorSupport {
|
||||
Full,
|
||||
Limited { term: String },
|
||||
}
|
||||
|
||||
fn detect_terminal_color_support() -> TerminalColorSupport {
    let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
    let colorterm = std::env::var("COLORTERM").unwrap_or_default();
    let term_lower = term.to_lowercase();
    let color_lower = colorterm.to_lowercase();

    let supports_extended = term_lower.contains("256color")
        || color_lower.contains("truecolor")
        || color_lower.contains("24bit")
        || color_lower.contains("fullcolor");

    if supports_extended {
        TerminalColorSupport::Full
    } else {
        TerminalColorSupport::Limited { term }
    }
}
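The detection above only looks at `TERM` and `COLORTERM`. A hedged sketch of the classification, with the environment reads factored into a pure helper so it can be asserted directly; the helper and test are hypothetical.

fn supports_extended_colors(term: &str, colorterm: &str) -> bool {
    let term = term.to_lowercase();
    let color = colorterm.to_lowercase();
    term.contains("256color")
        || color.contains("truecolor")
        || color.contains("24bit")
        || color.contains("fullcolor")
}

#[test]
fn classifies_common_terminals() {
    assert!(supports_extended_colors("xterm-256color", ""));
    assert!(supports_extended_colors("xterm", "truecolor"));
    assert!(!supports_extended_colors("linux", ""));
}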
|
||||
fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
|
||||
match support {
|
||||
TerminalColorSupport::Full => None,
|
||||
TerminalColorSupport::Limited { .. } => {
|
||||
if cfg.ui.theme != BASIC_THEME_NAME {
|
||||
let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
|
||||
Some(previous)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct OfflineProvider {
|
||||
reason: String,
|
||||
placeholder_model: String,
|
||||
}
|
||||
|
||||
impl OfflineProvider {
|
||||
fn new(reason: String, placeholder_model: String) -> Self {
|
||||
Self {
|
||||
reason,
|
||||
placeholder_model,
|
||||
}
|
||||
}
|
||||
|
||||
fn friendly_response(&self, requested_model: &str) -> ChatResponse {
|
||||
let mut message = String::new();
|
||||
message.push_str("⚠️ Owlen is running in offline mode.\n\n");
|
||||
message.push_str(&self.reason);
|
||||
if !requested_model.is_empty() && requested_model != self.placeholder_model {
|
||||
message.push_str(&format!(
|
||||
"\n\nYou requested model '{}', but no providers are reachable.",
|
||||
requested_model
|
||||
));
|
||||
}
|
||||
message.push_str(
|
||||
"\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::assistant(message),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OfflineProvider {
|
||||
fn name(&self) -> &str {
|
||||
"offline"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: self.placeholder_model.clone(),
|
||||
provider: "offline".to_string(),
|
||||
name: format!("Offline (fallback: {})", self.placeholder_model),
|
||||
description: Some("Placeholder model used while no providers are reachable".into()),
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
|
||||
Ok(self.friendly_response(&request.model))
|
||||
}
|
||||
|
||||
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, Error> {
|
||||
let response = self.friendly_response(&request.model);
|
||||
Ok(Box::pin(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<(), Error> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"offline provider cannot reach any backing models"
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<()> {
|
||||
// Parse command-line arguments
|
||||
@@ -356,151 +224,5 @@ async fn main() -> Result<()> {
|
||||
return run_command(command).await;
|
||||
}
|
||||
let initial_mode = if code { Mode::Code } else { Mode::Chat };
|
||||
|
||||
// Set auto-consent for TUI mode to prevent blocking stdin reads
|
||||
std::env::set_var("OWLEN_AUTO_CONSENT", "1");
|
||||
|
||||
let color_support = detect_terminal_color_support();
|
||||
// Load configuration (or fall back to defaults) for the session controller.
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
TerminalColorSupport::Full => Cow::from("current terminal"),
|
||||
};
|
||||
eprintln!(
|
||||
"Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
|
||||
term_label, BASIC_THEME_NAME, previous_theme
|
||||
);
|
||||
} else if let TerminalColorSupport::Limited { term } = &color_support {
|
||||
eprintln!(
|
||||
"Warning: terminal '{}' may not fully support 256-color themes.",
|
||||
term
|
||||
);
|
||||
}
|
||||
cfg.validate()?;
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
load_runtime_credentials(&mut cfg, storage.clone()).await?;
|
||||
|
||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
||||
|
||||
// Create provider according to MCP configuration (supports legacy/local fallback)
|
||||
let provider = build_provider(&cfg)?;
|
||||
let mut offline_notice: Option<String> = None;
|
||||
let provider = match provider.health_check().await {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.mcp_servers.is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
"Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
|
||||
};
|
||||
let notice =
|
||||
format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
|
||||
eprintln!("{notice}");
|
||||
offline_notice = Some(notice.clone());
|
||||
let fallback_model = cfg
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "offline".to_string());
|
||||
Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
|
||||
}
|
||||
};
|
||||
|
||||
let controller =
|
||||
SessionController::new(provider, cfg, storage.clone(), tui_controller, false).await?;
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller).await?;
|
||||
app.initialize_models().await?;
|
||||
if let Some(notice) = offline_notice {
|
||||
app.set_status_message(¬ice);
|
||||
app.set_system_status(notice);
|
||||
}
|
||||
|
||||
// Set the initial mode
|
||||
app.set_mode(initial_mode).await;
|
||||
|
||||
// Event infrastructure
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
||||
let event_handler = EventHandler::new(event_tx, cancellation_token.clone());
|
||||
let event_handle = tokio::spawn(async move { event_handler.run().await });
|
||||
|
||||
// Terminal setup
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
execute!(
|
||||
stdout,
|
||||
EnterAlternateScreen,
|
||||
EnableMouseCapture,
|
||||
EnableBracketedPaste
|
||||
)?;
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
let result = run_app(&mut terminal, &mut app, event_rx, &mut session_rx).await;
|
||||
|
||||
// Shutdown
|
||||
cancellation_token.cancel();
|
||||
event_handle.await?;
|
||||
|
||||
// Persist configuration updates (e.g., selected model)
|
||||
config::save_config(&app.config())?;
|
||||
|
||||
disable_raw_mode()?;
|
||||
execute!(
|
||||
terminal.backend_mut(),
|
||||
LeaveAlternateScreen,
|
||||
DisableMouseCapture,
|
||||
DisableBracketedPaste
|
||||
)?;
|
||||
terminal.show_cursor()?;
|
||||
|
||||
if let Err(err) = result {
|
||||
println!("{err:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_app(
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
app: &mut ChatApp,
|
||||
mut event_rx: mpsc::UnboundedReceiver<Event>,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
// Advance loading animation frame
|
||||
app.advance_loading_animation();
|
||||
|
||||
terminal.draw(|f| ui::render_chat(f, app))?;
|
||||
|
||||
// Process any pending LLM requests AFTER UI has been drawn
|
||||
if let Err(e) = app.process_pending_llm_request().await {
|
||||
eprintln!("Error processing LLM request: {}", e);
|
||||
}
|
||||
|
||||
// Process any pending tool executions AFTER UI has been drawn
|
||||
if let Err(e) = app.process_pending_tool_execution().await {
|
||||
eprintln!("Error processing tool execution: {}", e);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
Some(event) = event_rx.recv() => {
|
||||
if let AppState::Quit = app.handle_event(event).await? {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Some(session_event) = session_rx.recv() => {
|
||||
app.handle_session_event(session_event)?;
|
||||
}
|
||||
// Add a timeout to keep the animation going even when there are no events
|
||||
_ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
|
||||
// This will cause the loop to continue and advance the animation
|
||||
}
|
||||
}
|
||||
}
|
||||
bootstrap::launch(initial_mode).await
|
||||
}
|
||||
|
||||
259
crates/owlen-cli/src/mcp.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand, ValueEnum};
|
||||
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum McpCommand {
|
||||
/// Add or update an MCP server in the selected scope
|
||||
Add(AddArgs),
|
||||
/// List MCP servers across scopes
|
||||
List(ListArgs),
|
||||
/// Remove an MCP server from a scope
|
||||
Remove(RemoveArgs),
|
||||
}
|
||||
|
||||
pub fn run_mcp_command(command: McpCommand) -> Result<()> {
|
||||
match command {
|
||||
McpCommand::Add(args) => handle_add(args),
|
||||
McpCommand::List(args) => handle_list(args),
|
||||
McpCommand::Remove(args) => handle_remove(args),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||||
pub enum ScopeArg {
|
||||
User,
|
||||
#[default]
|
||||
Project,
|
||||
Local,
|
||||
}
|
||||
|
||||
impl From<ScopeArg> for McpConfigScope {
|
||||
fn from(value: ScopeArg) -> Self {
|
||||
match value {
|
||||
ScopeArg::User => McpConfigScope::User,
|
||||
ScopeArg::Project => McpConfigScope::Project,
|
||||
ScopeArg::Local => McpConfigScope::Local,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct AddArgs {
|
||||
/// Logical name used to reference the server
|
||||
pub name: String,
|
||||
/// Command or endpoint invoked for the server
|
||||
pub command: String,
|
||||
/// Transport mechanism (stdio, http, websocket)
|
||||
#[arg(long, default_value = "stdio")]
|
||||
pub transport: String,
|
||||
/// Configuration scope to write the server into
|
||||
#[arg(long, value_enum, default_value_t = ScopeArg::Project)]
|
||||
pub scope: ScopeArg,
|
||||
/// Environment variables (KEY=VALUE) passed to the server process
|
||||
#[arg(long = "env")]
|
||||
pub env: Vec<String>,
|
||||
/// Additional arguments appended when launching the server
|
||||
#[arg(trailing_var_arg = true, value_name = "ARG")]
|
||||
pub args: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Default)]
|
||||
pub struct ListArgs {
|
||||
/// Restrict output to a specific configuration scope
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
/// Display only the effective servers (after precedence resolution)
|
||||
#[arg(long)]
|
||||
pub effective_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct RemoveArgs {
|
||||
/// Name of the server to remove
|
||||
pub name: String,
|
||||
/// Optional explicit scope to remove from
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
}
|
||||
|
||||
fn handle_add(args: AddArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope: McpConfigScope = args.scope.into();
|
||||
let mut env_map = HashMap::new();
|
||||
for pair in &args.env {
|
||||
let (key, value) = pair
|
||||
.split_once('=')
|
||||
.ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
|
||||
if key.trim().is_empty() {
|
||||
return Err(anyhow!("Environment variable name cannot be empty"));
|
||||
}
|
||||
env_map.insert(key.trim().to_string(), value.to_string());
|
||||
}
|
||||
|
||||
let server = McpServerConfig {
|
||||
name: args.name.clone(),
|
||||
command: args.command.clone(),
|
||||
args: args.args.clone(),
|
||||
transport: args.transport.to_lowercase(),
|
||||
env: env_map,
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
config.add_mcp_server(scope, server.clone(), None)?;
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope ({})",
|
||||
server.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope.",
|
||||
server.name, scope
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
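One detail of the `--env` parsing above worth calling out: `split_once('=')` splits at the first `=` only, so values may themselves contain `=` (common for tokens and connection strings), while a pair with no `=` at all is rejected. A hedged illustration; the values are made up.

#[test]
fn env_pairs_split_on_the_first_equals_sign() {
    assert_eq!(
        "API_KEY=abc=def".split_once('='),
        Some(("API_KEY", "abc=def"))
    );
    assert_eq!("MALFORMED".split_once('='), None);
}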
|
||||
fn handle_list(args: ListArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
config.refresh_mcp_servers(None)?;
|
||||
|
||||
let scoped = config.scoped_mcp_servers();
|
||||
if scoped.is_empty() {
|
||||
println!("No MCP servers configured.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let filter_scope = args.scope.map(|scope| scope.into());
|
||||
let effective = config.effective_mcp_servers();
|
||||
let mut active = HashSet::new();
|
||||
for server in effective {
|
||||
active.insert((
|
||||
server.name.clone(),
|
||||
server.command.clone(),
|
||||
server.transport.to_lowercase(),
|
||||
));
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<2} {:<8} {:<20} {:<10} Command",
|
||||
"", "Scope", "Name", "Transport"
|
||||
);
|
||||
for entry in scoped {
|
||||
if filter_scope
|
||||
.as_ref()
|
||||
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let payload = format_command_line(&entry.config.command, &entry.config.args);
|
||||
let key = (
|
||||
entry.config.name.clone(),
|
||||
entry.config.command.clone(),
|
||||
entry.config.transport.to_lowercase(),
|
||||
);
|
||||
let marker = if active.contains(&key) { "*" } else { " " };
|
||||
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
println!(
|
||||
"{} {:<8} {:<20} {:<10} {}",
|
||||
marker, entry.scope, entry.config.name, entry.config.transport, payload
|
||||
);
|
||||
}
|
||||
|
||||
let scoped_resources = config.scoped_mcp_resources();
|
||||
if !scoped_resources.is_empty() {
|
||||
println!();
|
||||
println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
|
||||
let effective_keys: HashSet<(String, String)> = config
|
||||
.effective_mcp_resources()
|
||||
.iter()
|
||||
.map(|res| (res.server.clone(), res.uri.clone()))
|
||||
.collect();
|
||||
|
||||
for entry in scoped_resources {
|
||||
if filter_scope
|
||||
.as_ref()
|
||||
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = (entry.config.server.clone(), entry.config.uri.clone());
|
||||
let marker = if effective_keys.contains(&key) {
|
||||
"*"
|
||||
} else {
|
||||
" "
|
||||
};
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
|
||||
let title = entry.config.title.as_deref().unwrap_or("—");
|
||||
|
||||
println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_remove(args: RemoveArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope_hint = args.scope.map(|scope| scope.into());
|
||||
let result = config.remove_mcp_server(scope_hint, &args.name, None)?;
|
||||
|
||||
match result {
|
||||
Some(scope) => {
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Removed MCP server '{}' from {} scope ({})",
|
||||
args.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!("Removed MCP server '{}' from {} scope.", args.name, scope);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
println!("No MCP server named '{}' was found.", args.name);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_config() -> Result<Config> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
config.refresh_mcp_servers(None)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn format_command_line(command: &str, args: &[String]) -> String {
|
||||
if args.is_empty() {
|
||||
command.to_string()
|
||||
} else {
|
||||
format!("{} {}", command, args.join(" "))
|
||||
}
|
||||
}
|
||||
@@ -50,3 +50,4 @@ ollama-rs = { version = "0.3", features = ["stream", "headers"] }

[dev-dependencies]
tokio-test = { workspace = true }
httpmock = "0.7"
@@ -3,8 +3,8 @@
//! This module provides the core agent orchestration logic that allows an LLM
//! to reason about tasks, execute tools, and observe results in an iterative loop.

use crate::Provider;
use crate::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::provider::Provider;
use crate::types::{ChatParameters, ChatRequest, Message};
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
@@ -189,7 +189,7 @@ impl AgentExecutor {
    fn build_system_prompt(&self, tools: &[McpToolDescriptor]) -> String {
        let mut prompt = String::from(
            "You are an AI assistant that uses the ReAct (Reasoning and Acting) pattern to solve tasks.\n\n\
             You have access to the following tools:\n\n"
             You have access to the following tools:\n\n",
        );

        for tool in tools {
@@ -230,7 +230,7 @@ impl AgentExecutor {
            tools: None,
        };

        let response = self.llm_client.chat(request).await?;
        let response = self.llm_client.send_prompt(request).await?;
        Ok(response.message.content)
    }

@@ -364,13 +364,13 @@ impl AgentExecutor {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::test_utils::MockProvider;
    use crate::mcp::test_utils::MockMcpClient;
    use crate::provider::test_utils::MockProvider;

    #[test]
    fn test_parse_tool_call() {
        let executor = AgentExecutor {
            llm_client: Arc::new(MockProvider),
            llm_client: Arc::new(MockProvider::default()),
            tool_client: Arc::new(MockMcpClient),
            config: AgentConfig::default(),
        };
@@ -399,7 +399,7 @@ ACTION_INPUT: {"query": "Rust programming language"}
    #[test]
    fn test_parse_final_answer() {
        let executor = AgentExecutor {
            llm_client: Arc::new(MockProvider),
            llm_client: Arc::new(MockProvider::default()),
            tool_client: Arc::new(MockMcpClient),
            config: AgentConfig::default(),
        };

File diff suppressed because it is too large
@@ -58,17 +58,21 @@ impl ConsentManager {
    /// Load consent records from vault storage
    pub fn from_vault(vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Self {
        let guard = vault.lock().expect("Vault mutex poisoned");
        if let Some(consent_data) = guard.settings().get("consent_records") {
            if let Ok(permanent_records) =
                serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
            {
                return Self {
                    permanent_records,
                    session_records: HashMap::new(),
                    once_records: HashMap::new(),
                    pending_requests: HashMap::new(),
                };
            }
        if let Some(permanent_records) =
            guard
                .settings()
                .get("consent_records")
                .and_then(|consent_data| {
                    serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
                        .ok()
                })
        {
            return Self {
                permanent_records,
                session_records: HashMap::new(),
                once_records: HashMap::new(),
                pending_requests: HashMap::new(),
            };
        }
        Self::default()
    }
@@ -91,17 +95,21 @@ impl ConsentManager {
        endpoints: Vec<String>,
    ) -> Result<ConsentScope> {
        // Check if already granted permanently
        if let Some(existing) = self.permanent_records.get(tool_name) {
            if existing.scope == ConsentScope::Permanent {
                return Ok(ConsentScope::Permanent);
            }
        if self
            .permanent_records
            .get(tool_name)
            .is_some_and(|existing| existing.scope == ConsentScope::Permanent)
        {
            return Ok(ConsentScope::Permanent);
        }

        // Check if granted for session
        if let Some(existing) = self.session_records.get(tool_name) {
            if existing.scope == ConsentScope::Session {
                return Ok(ConsentScope::Session);
            }
        if self
            .session_records
            .get(tool_name)
            .is_some_and(|existing| existing.scope == ConsentScope::Session)
        {
            return Ok(ConsentScope::Session);
        }

        // Check if request is already pending (prevent duplicate prompts)

@@ -1,6 +1,6 @@
use crate::Result;
use crate::storage::StorageManager;
use crate::types::{Conversation, Message};
use crate::Result;
use serde_json::{Number, Value};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
@@ -213,6 +213,34 @@ impl ConversationManager {
        Ok(())
    }

    pub fn cancel_stream(&mut self, message_id: Uuid, notice: impl Into<String>) -> Result<()> {
        let index = self
            .message_index
            .get(&message_id)
            .copied()
            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;

        if let Some(message) = self.active_mut().messages.get_mut(index) {
            message.content = notice.into();
            message.timestamp = std::time::SystemTime::now();
            message
                .metadata
                .insert(STREAMING_FLAG.to_string(), Value::Bool(false));
            message.metadata.remove(PLACEHOLDER_FLAG);
            let millis = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            message.metadata.insert(
                LAST_CHUNK_TS.to_string(),
                Value::Number(Number::from(millis)),
            );
        }

        self.streaming.remove(&message_id);
        Ok(())
    }

    /// Set tool calls on a streaming message
    pub fn set_tool_calls_on_message(
        &mut self,

@@ -2,7 +2,7 @@ use std::sync::Arc;

use serde::{Deserialize, Serialize};

use crate::{storage::StorageManager, Error, Result};
use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};

#[derive(Serialize, Deserialize, Debug)]
pub struct ApiCredentials {
@@ -31,6 +31,10 @@ impl CredentialManager {
        format!("{}_{}", self.namespace, tool_name)
    }

    fn oauth_storage_key(&self, resource: &str) -> String {
        self.namespaced_key(&format!("oauth_{resource}"))
    }

    pub async fn store_credentials(
        &self,
        tool_name: &str,
@@ -68,4 +72,37 @@ impl CredentialManager {
        let key = self.namespaced_key(tool_name);
        self.storage.delete_secure_item(&key).await
    }

    pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
        let key = self.oauth_storage_key(resource);
        let payload = serde_json::to_vec(token).map_err(|err| {
            Error::Storage(format!(
                "Failed to serialize OAuth token for secure storage: {err}"
            ))
        })?;
        self.storage
            .store_secure_item(&key, &payload, &self.master_key)
            .await
    }

    pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
        let key = self.oauth_storage_key(resource);
        let raw = self
            .storage
            .load_secure_item(&key, &self.master_key)
            .await?;
        if let Some(bytes) = raw {
            let token = serde_json::from_slice(&bytes).map_err(|err| {
                Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
            })?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }

    pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
        let key = self.oauth_storage_key(resource);
        self.storage.delete_secure_item(&key).await
    }
}

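The new OAuth helpers above serialize the token with serde_json and keep it in secure storage under an `oauth_<resource>` key. A minimal round-trip sketch, not part of the diff: it assumes the crate is importable as `owlen_core` and that a `CredentialManager` value named `credentials` already exists; the `"github"` resource name is illustrative.

    use owlen_core::oauth::OAuthToken;

    // Sketch only: `credentials` is an existing, already-unlocked CredentialManager.
    async fn remember_token(credentials: &owlen_core::CredentialManager) -> owlen_core::Result<()> {
        let token = OAuthToken {
            access_token: "example-access-token".to_string(),
            ..Default::default()
        };
        // Persisted under the namespaced "oauth_github" key, encrypted with the master key.
        credentials.store_oauth_token("github", &token).await?;
        assert_eq!(credentials.load_oauth_token("github").await?, Some(token));
        // Deleting the entry makes later loads return None.
        credentials.delete_oauth_token("github").await?;
        Ok(())
    }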
@@ -3,10 +3,10 @@ use std::fs;
use std::path::PathBuf;

use aes_gcm::{
    aead::{Aead, KeyInit},
    Aes256Gcm, Nonce,
    aead::{Aead, KeyInit},
};
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use ring::digest;
use ring::rand::{SecureRandom, SystemRandom};
use serde::{Deserialize, Serialize};

32 crates/owlen-core/src/facade/llm_client.rs Normal file
@@ -0,0 +1,32 @@
use std::sync::Arc;

use async_trait::async_trait;

use crate::{
    Result,
    llm::ChatStream,
    mcp::{McpToolCall, McpToolDescriptor, McpToolResponse},
    types::{ChatRequest, ChatResponse, ModelInfo},
};

/// Object-safe facade for interacting with LLM backends.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// List the models exposed by this client.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Issue a one-shot chat request and wait for the complete response.
    async fn send_chat(&self, request: ChatRequest) -> Result<ChatResponse>;

    /// Stream chat responses incrementally.
    async fn stream_chat(&self, request: ChatRequest) -> Result<ChatStream>;

    /// Enumerate tools exposed by the backing provider.
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>>;

    /// Invoke a tool exposed by the provider.
    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;
}

/// Convenience alias for trait-object clients.
pub type DynLlmClient = Arc<dyn LlmClient>;
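Because `LlmClient` is object-safe, call sites can hold a `DynLlmClient` and stay agnostic about the concrete backend. A hedged consumer sketch, assuming the crate is importable as `owlen_core`; only facade methods shown in the diff are used, and building the `ChatRequest` is left to the caller.

    use owlen_core::facade::llm_client::DynLlmClient;
    use owlen_core::types::ChatRequest;

    // Sketch only: returns the content of the first complete response.
    async fn first_reply(client: &DynLlmClient, request: ChatRequest) -> owlen_core::Result<String> {
        let response = client.send_chat(request).await?;
        Ok(response.message.content)
    }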
1 crates/owlen-core/src/facade/mod.rs Normal file
@@ -0,0 +1 @@
pub mod llm_client;
@@ -1,19 +1,20 @@
use crate::types::Message;
use crate::ui::RoleLabelDisplay;

/// Formats messages for display across different clients.
#[derive(Debug, Clone)]
pub struct MessageFormatter {
    wrap_width: usize,
    show_role_labels: bool,
    role_label_mode: RoleLabelDisplay,
    preserve_empty_lines: bool,
}

impl MessageFormatter {
    /// Create a new formatter
    pub fn new(wrap_width: usize, show_role_labels: bool) -> Self {
    pub fn new(wrap_width: usize, role_label_mode: RoleLabelDisplay) -> Self {
        Self {
            wrap_width: wrap_width.max(20),
            show_role_labels,
            role_label_mode,
            preserve_empty_lines: false,
        }
    }
@@ -29,9 +30,19 @@ impl MessageFormatter {
        self.wrap_width = width.max(20);
    }

    /// Whether role labels should be shown alongside messages
    /// The configured role label layout preference.
    pub fn role_label_mode(&self) -> RoleLabelDisplay {
        self.role_label_mode
    }

    /// Whether any role label should be shown alongside messages.
    pub fn show_role_labels(&self) -> bool {
        self.show_role_labels
        !matches!(self.role_label_mode, RoleLabelDisplay::None)
    }

    /// Update the role label layout preference.
    pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) {
        self.role_label_mode = mode;
    }

    pub fn format_message(&self, message: &Message) -> Vec<String> {

@@ -191,6 +191,12 @@ impl InputBuffer {
            self.history.pop_back();
        }
    }

    /// Clear saved input history entries.
    pub fn clear_history(&mut self) {
        self.history.clear();
        self.history_index = None;
    }
}

fn prev_char_boundary(buffer: &str, cursor: usize) -> usize {

@@ -1,3 +1,5 @@
#![allow(clippy::collapsible_if)] // TODO: Remove once we can rely on Rust 2024 let-chains

//! Core traits and types for OWLEN LLM client
//!
//! This crate provides the foundational abstractions for building
@@ -9,16 +11,20 @@ pub mod consent;
pub mod conversation;
pub mod credentials;
pub mod encryption;
pub mod facade;
pub mod formatting;
pub mod input;
pub mod llm;
pub mod mcp;
pub mod mode;
pub mod model;
pub mod oauth;
pub mod provider;
pub mod providers;
pub mod router;
pub mod sandbox;
pub mod session;
pub mod state;
pub mod storage;
pub mod theme;
pub mod tools;
@@ -35,19 +41,24 @@ pub use credentials::*;
pub use encryption::*;
pub use formatting::*;
pub use input::*;
pub use oauth::*;
// Export MCP types but exclude test_utils to avoid ambiguity
pub use facade::llm_client::*;
pub use llm::{
    ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,
};
pub use mcp::{
    client, factory, failover, permission, protocol, remote_client, LocalMcpClient, McpServer,
    McpToolCall, McpToolDescriptor, McpToolResponse,
    LocalMcpClient, McpServer, McpToolCall, McpToolDescriptor, McpToolResponse, client, factory,
    failover, permission, protocol, remote_client,
};
pub use mode::*;
pub use model::*;
// Export provider types but exclude test_utils to avoid ambiguity
pub use provider::{ChatStream, LLMProvider, Provider, ProviderConfig, ProviderRegistry};
pub use provider::*;
pub use providers::*;
pub use router::*;
pub use sandbox::*;
pub use session::*;
pub use state::*;
pub use theme::*;
pub use tools::*;
pub use validation::*;

337 crates/owlen-core/src/llm/mod.rs Normal file
@@ -0,0 +1,337 @@
//! LLM provider abstractions and registry.
//!
//! This module defines the provider trait hierarchy along with helpers that
//! make it easy to register concrete LLM backends and access them through
//! dynamic dispatch when wiring the application together.

use crate::{Error, Result, types::*};
use anyhow::anyhow;
use futures::{Stream, StreamExt};
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

/// A boxed stream of chat responses produced by a provider.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;

/// Trait implemented by every LLM backend Owlen can speak to.
///
/// Providers expose both one-shot and streaming prompt APIs. Concrete
/// implementations typically live in `crate::providers`.
pub trait LlmProvider: Send + Sync + 'static + Any + Sized {
    /// Stream type returned by [`Self::stream_prompt`].
    type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;

    type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
    where
        Self: 'a;

    type SendPromptFuture<'a>: Future<Output = Result<ChatResponse>> + Send
    where
        Self: 'a;

    type StreamPromptFuture<'a>: Future<Output = Result<Self::Stream>> + Send
    where
        Self: 'a;

    type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
    where
        Self: 'a;

    /// Human-readable provider identifier.
    fn name(&self) -> &str;

    /// Return metadata on all models exposed by this provider.
    fn list_models(&self) -> Self::ListModelsFuture<'_>;

    /// Issue a prompt and wait for the provider to return the full response.
    fn send_prompt(&self, request: ChatRequest) -> Self::SendPromptFuture<'_>;

    /// Issue a prompt and receive responses incrementally as a stream.
    fn stream_prompt(&self, request: ChatRequest) -> Self::StreamPromptFuture<'_>;

    /// Perform a lightweight health check.
    fn health_check(&self) -> Self::HealthCheckFuture<'_>;

    /// Provider-specific configuration schema (optional).
    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }

    /// Access the provider as an `Any` for downcasting.
    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

/// Helper that requests a streamed generation and yields the first chunk as a
/// regular response. This is handy for providers that only implement the
/// streaming API.
pub async fn send_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
where
    P: LlmProvider + 'a,
{
    let stream = provider.stream_prompt(request).await?;
    let mut boxed: ChatStream = Box::pin(stream);
    match boxed.next().await {
        Some(Ok(response)) => Ok(response),
        Some(Err(err)) => Err(err),
        None => Err(Error::Provider(anyhow!(
            "Empty chat stream from provider {}",
            provider.name()
        ))),
    }
}

/// Object-safe wrapper around [`LlmProvider`] for dynamic dispatch scenarios.
#[async_trait::async_trait]
pub trait Provider: Send + Sync {
    fn name(&self) -> &str;

    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse>;

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream>;

    async fn health_check(&self) -> Result<()>;

    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync);
}

#[async_trait::async_trait]
impl<T> Provider for T
where
    T: LlmProvider,
{
    fn name(&self) -> &str {
        LlmProvider::name(self)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        LlmProvider::list_models(self).await
    }

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse> {
        LlmProvider::send_prompt(self, request).await
    }

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream> {
        let stream = LlmProvider::stream_prompt(self, request).await?;
        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> Result<()> {
        LlmProvider::health_check(self).await
    }

    fn config_schema(&self) -> serde_json::Value {
        LlmProvider::config_schema(self)
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        LlmProvider::as_any(self)
    }
}

/// Runtime configuration for a provider instance.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ProviderConfig {
    /// Whether this provider should be activated.
    #[serde(default = "ProviderConfig::default_enabled")]
    pub enabled: bool,
    /// Provider type identifier used to resolve implementations.
    #[serde(default)]
    pub provider_type: String,
    /// Base URL for API calls.
    #[serde(default)]
    pub base_url: Option<String>,
    /// API key or token material.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Environment variable holding the API key.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Additional provider-specific configuration.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}

impl ProviderConfig {
    const fn default_enabled() -> bool {
        true
    }

    /// Merge the current configuration with overrides from `other`.
    pub fn merge_from(&mut self, mut other: ProviderConfig) {
        self.enabled = other.enabled;

        if !other.provider_type.is_empty() {
            self.provider_type = other.provider_type;
        }

        if let Some(base_url) = other.base_url.take() {
            self.base_url = Some(base_url);
        }

        if let Some(api_key) = other.api_key.take() {
            self.api_key = Some(api_key);
        }

        if let Some(api_key_env) = other.api_key_env.take() {
            self.api_key_env = Some(api_key_env);
        }

        if !other.extra.is_empty() {
            self.extra.extend(other.extra);
        }
    }
}

/// Static registry of providers available to the application.
pub struct ProviderRegistry {
    providers: HashMap<String, Arc<dyn Provider>>,
}

impl ProviderRegistry {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    pub fn register<P: LlmProvider + 'static>(&mut self, provider: P) {
        self.register_arc(Arc::new(provider));
    }

    pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
        let name = provider.name().to_string();
        self.providers.insert(name, provider);
    }

    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.providers.get(name).cloned()
    }

    pub fn list_providers(&self) -> Vec<String> {
        self.providers.keys().cloned().collect()
    }

    pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
        let mut all_models = Vec::new();

        for provider in self.providers.values() {
            match provider.list_models().await {
                Ok(mut models) => all_models.append(&mut models),
                Err(_) => {
                    // Ignore failing providers and continue.
                }
            }
        }

        Ok(all_models)
    }
}

impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Test utilities for constructing mock providers.
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use futures::stream;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Simple provider stub that always returns the same response.
    pub struct MockProvider {
        name: String,
        response: ChatResponse,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        pub fn new(name: impl Into<String>, response: ChatResponse) -> Self {
            Self {
                name: name.into(),
                response,
                call_count: AtomicUsize::new(0),
            }
        }

        pub fn call_count(&self) -> usize {
            self.call_count.load(Ordering::Relaxed)
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self::new(
                "mock-provider",
                ChatResponse {
                    message: Message::assistant("mock response".to_string()),
                    usage: None,
                    is_streaming: false,
                    is_final: true,
                },
            )
        }
    }

    impl LlmProvider for MockProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;

        type ListModelsFuture<'a>
            = futures::future::Ready<Result<Vec<ModelInfo>>>
        where
            Self: 'a;

        type SendPromptFuture<'a>
            = futures::future::Ready<Result<ChatResponse>>
        where
            Self: 'a;

        type StreamPromptFuture<'a>
            = futures::future::Ready<Result<Self::Stream>>
        where
            Self: 'a;

        type HealthCheckFuture<'a>
            = futures::future::Ready<Result<()>>
        where
            Self: 'a;

        fn name(&self) -> &str {
            &self.name
        }

        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            futures::future::ready(Ok(vec![]))
        }

        fn send_prompt(&self, _request: ChatRequest) -> Self::SendPromptFuture<'_> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            futures::future::ready(Ok(self.response.clone()))
        }

        fn stream_prompt(&self, _request: ChatRequest) -> Self::StreamPromptFuture<'_> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            let response = self.response.clone();
            futures::future::ready(Ok(stream::iter(vec![Ok(response)])))
        }

        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            futures::future::ready(Ok(()))
        }
    }
}
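How the pieces above fit together: any concrete `LlmProvider` gets the blanket `Provider` implementation, so it can be registered and looked up by name through `ProviderRegistry`. A hedged sketch using the module's own `MockProvider` (compiled only under `#[cfg(test)]`, so treat this as crate-internal test code, not public API), assuming the crate is importable as `owlen_core`:

    use owlen_core::llm::{ProviderRegistry, test_utils::MockProvider};

    fn registry_round_trip() {
        let mut registry = ProviderRegistry::new();
        // The blanket `impl<T: LlmProvider> Provider for T` lets the concrete type be registered directly.
        registry.register(MockProvider::default());

        assert_eq!(registry.list_providers(), vec!["mock-provider".to_string()]);
        // Lookup hands back an `Arc<dyn Provider>`, ready for dynamic dispatch.
        let provider = registry.get("mock-provider").expect("registered above");
        assert_eq!(provider.name(), "mock-provider");
    }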
@@ -1,7 +1,7 @@
use crate::Result;
use crate::mode::Mode;
use crate::tools::registry::ToolRegistry;
use crate::validation::SchemaValidator;
use crate::Result;
use async_trait::async_trait;
pub use client::McpClient;
use serde::{Deserialize, Serialize};
@@ -142,6 +142,11 @@ impl McpClient for LocalMcpClient {
    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        self.server.call_tool(call).await
    }

    async fn set_mode(&self, mode: Mode) -> Result<()> {
        self.server.set_mode(mode).await;
        Ok(())
    }
}

#[cfg(test)]

@@ -1,5 +1,5 @@
use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::Result;
use crate::{Result, mode::Mode};
use async_trait::async_trait;

/// Trait for a client that can interact with an MCP server
@@ -10,6 +10,11 @@ pub trait McpClient: Send + Sync {

    /// Call a tool on the server
    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;

    /// Update the server with the active operating mode.
    async fn set_mode(&self, _mode: Mode) -> Result<()> {
        Ok(())
    }
}

// Re-export the concrete implementation that supports stdio and HTTP transports.

@@ -3,7 +3,10 @@
/// Provides a unified interface for creating MCP clients based on configuration.
/// Supports switching between local (in-process) and remote (STDIO) execution modes.
use super::client::McpClient;
use super::{remote_client::RemoteMcpClient, LocalMcpClient};
use super::{
    LocalMcpClient,
    remote_client::{McpRuntimeSecrets, RemoteMcpClient},
};
use crate::config::{Config, McpMode};
use crate::tools::registry::ToolRegistry;
use crate::validation::SchemaValidator;
@@ -33,6 +36,14 @@ impl McpClientFactory {

    /// Create an MCP client based on the current configuration.
    pub fn create(&self) -> Result<Box<dyn McpClient>> {
        self.create_with_secrets(None)
    }

    /// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
    pub fn create_with_secrets(
        &self,
        runtime: Option<McpRuntimeSecrets>,
    ) -> Result<Box<dyn McpClient>> {
        match self.config.mcp.mode {
            McpMode::Disabled => Err(Error::Config(
                "MCP mode is set to 'disabled'; tooling cannot function in this configuration."
@@ -48,14 +59,14 @@ impl McpClientFactory {
                )))
            }
            McpMode::RemoteOnly => {
                let server_cfg = self.config.mcp_servers.first().ok_or_else(|| {
                let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
                    Error::Config(
                        "MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
                            .to_string(),
                    )
                })?;

                RemoteMcpClient::new_with_config(server_cfg)
                RemoteMcpClient::new_with_runtime(server_cfg, runtime)
                    .map(|client| Box::new(client) as Box<dyn McpClient>)
                    .map_err(|e| {
                        Error::Config(format!(
@@ -65,8 +76,8 @@ impl McpClientFactory {
                    })
            }
            McpMode::RemotePreferred => {
                if let Some(server_cfg) = self.config.mcp_servers.first() {
                    match RemoteMcpClient::new_with_config(server_cfg) {
                if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
                    match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) {
                        Ok(client) => {
                            info!(
                                "Connected to remote MCP server '{}' via {} transport.",
@@ -109,8 +120,8 @@ impl McpClientFactory {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::McpServerConfig;
    use crate::Error;
    use crate::config::McpServerConfig;

    fn build_factory(config: Config) -> McpClientFactory {
        let ui = Arc::new(crate::ui::NoOpUiController);
@@ -125,7 +136,8 @@ mod tests {

    #[test]
    fn test_factory_creates_local_client_when_no_servers_configured() {
        let config = Config::default();
        let mut config = Config::default();
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);

@@ -139,6 +151,7 @@ mod tests {
        let mut config = Config::default();
        config.mcp.mode = McpMode::RemoteOnly;
        config.mcp_servers.clear();
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);
        let result = factory.create();
@@ -156,7 +169,9 @@ mod tests {
            args: Vec::new(),
            transport: "stdio".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
        }];
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);
        let result = factory.create();

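The factory now threads runtime secrets (environment overrides plus an optional HTTP header) into the spawned client. A hedged sketch of injecting an OAuth bearer header, assuming the crate is importable as `owlen_core`, that `McpClientFactory` is public at `owlen_core::mcp::factory`, and that a `factory` value already exists; the header name and token are illustrative.

    use owlen_core::mcp::McpClient;
    use owlen_core::mcp::factory::McpClientFactory;
    use owlen_core::mcp::remote_client::McpRuntimeSecrets;

    // Sketch only: builds the runtime secrets and delegates to create_with_secrets.
    fn client_with_bearer(
        factory: &McpClientFactory,
        access_token: &str,
    ) -> owlen_core::Result<Box<dyn McpClient>> {
        let runtime = McpRuntimeSecrets {
            env_overrides: std::collections::HashMap::new(),
            http_header: Some((
                "Authorization".to_string(),
                format!("Bearer {access_token}"),
            )),
        };
        factory.create_with_secrets(Some(runtime))
    }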
@@ -305,6 +305,7 @@ mod tests {
            args: vec![],
            transport: "http".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
        };

        if let Ok(client) = RemoteMcpClient::new_with_config(&config) {

@@ -4,8 +4,8 @@
/// It wraps MCP clients to filter/whitelist tool calls, log invocations, and prompt for consent.
use super::client::McpClient;
use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::config::Config;
use crate::{Error, Result};
use crate::{config::Config, mode::Mode};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
@@ -145,6 +145,10 @@ impl McpClient for PermissionLayer {

        result
    }

    async fn set_mode(&self, mode: Mode) -> Result<()> {
        self.inner.set_mode(mode).await
    }
}

#[cfg(test)]
@@ -152,13 +156,14 @@ mod tests {
    use super::*;
    use crate::mcp::LocalMcpClient;
    use crate::tools::registry::ToolRegistry;
    use crate::ui::NoOpUiController;
    use crate::validation::SchemaValidator;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[tokio::test]
    async fn test_permission_layer_filters_dangerous_tools() {
        let config = Arc::new(Config::default());
        let ui = Arc::new(crate::ui::NoOpUiController);
        let ui = Arc::new(NoOpUiController);
        let registry = Arc::new(ToolRegistry::new(
            Arc::new(tokio::sync::Mutex::new((*config).clone())),
            ui,
@@ -182,7 +187,7 @@ mod tests {
    #[tokio::test]
    async fn test_consent_callback_is_invoked() {
        let config = Arc::new(Config::default());
        let ui = Arc::new(crate::ui::NoOpUiController);
        let ui = Arc::new(NoOpUiController);
        let registry = Arc::new(ToolRegistry::new(
            Arc::new(tokio::sync::Mutex::new((*config).clone())),
            ui,

@@ -1,24 +1,29 @@
use super::protocol::methods;
use super::protocol::{
    RequestId, RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse, PROTOCOL_VERSION,
    PROTOCOL_VERSION, RequestId, RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse,
};
use super::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::consent::{ConsentManager, ConsentScope};
use crate::tools::{Tool, WebScrapeTool, WebSearchTool};
use crate::types::ModelInfo;
use crate::types::{ChatResponse, Message, Role};
use crate::{provider::chat_via_stream, Error, LLMProvider, Result};
use futures::{future::BoxFuture, stream, StreamExt};
use crate::{
    ChatStream, Error, LlmProvider, Result, facade::llm_client::LlmClient, mode::Mode,
    send_via_stream,
};
use anyhow::anyhow;
use futures::{StreamExt, future::BoxFuture, stream};
use reqwest::Client as HttpClient;
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use tungstenite::protocol::Message as WsMessage;

/// Client that talks to the external `owlen-mcp-server` over STDIO, HTTP, or WebSocket.
@@ -38,6 +43,15 @@ pub struct RemoteMcpClient {
    ws_endpoint: Option<String>,
    // Incrementing request identifier.
    next_id: AtomicU64,
    // Optional HTTP header (name, value) injected into every request.
    http_header: Option<(String, String)>,
}

/// Runtime secrets provided when constructing an MCP client.
#[derive(Debug, Default, Clone)]
pub struct McpRuntimeSecrets {
    pub env_overrides: HashMap<String, String>,
    pub http_header: Option<(String, String)>,
}

impl RemoteMcpClient {
@@ -47,6 +61,14 @@ impl RemoteMcpClient {
    /// Spawn an external MCP server based on a configuration entry.
    /// The server must communicate over STDIO (the only supported transport).
    pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
        Self::new_with_runtime(config, None)
    }

    pub fn new_with_runtime(
        config: &crate::config::McpServerConfig,
        runtime: Option<McpRuntimeSecrets>,
    ) -> Result<Self> {
        let mut runtime = runtime.unwrap_or_default();
        let transport = config.transport.to_lowercase();
        match transport.as_str() {
            "stdio" => {
@@ -63,6 +85,9 @@ impl RemoteMcpClient {
                for (k, v) in config.env.iter() {
                    cmd.env(k, v);
                }
                for (k, v) in runtime.env_overrides.drain() {
                    cmd.env(k, v);
                }

                let mut child = cmd.spawn().map_err(|e| {
                    Error::Io(std::io::Error::new(
@@ -91,6 +116,7 @@ impl RemoteMcpClient {
                    ws_stream: None,
                    ws_endpoint: None,
                    next_id: AtomicU64::new(1),
                    http_header: None,
                })
            }
            "http" => {
@@ -108,6 +134,7 @@ impl RemoteMcpClient {
                    ws_stream: None,
                    ws_endpoint: None,
                    next_id: AtomicU64::new(1),
                    http_header: runtime.http_header.take(),
                })
            }
            "websocket" => {
@@ -131,6 +158,7 @@ impl RemoteMcpClient {
                    ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
                    ws_endpoint: Some(ws_url),
                    next_id: AtomicU64::new(1),
                    http_header: runtime.http_header.take(),
                })
            }
            other => Err(Error::NotImplemented(format!(
@@ -170,6 +198,7 @@ impl RemoteMcpClient {
            args: Vec::new(),
            transport: "stdio".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
        };
        Self::new_with_config(&config)
    }
@@ -192,8 +221,11 @@ impl RemoteMcpClient {
            .http_endpoint
            .as_ref()
            .ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
        let resp = client
            .post(endpoint)
        let mut builder = client.post(endpoint);
        if let Some((ref header_name, ref header_value)) = self.http_header {
            builder = builder.header(header_name, header_value);
        }
        let resp = builder
            .json(&request)
            .send()
            .await
@@ -203,10 +235,10 @@ impl RemoteMcpClient {
            .await
            .map_err(|e| Error::Network(e.to_string()))?;
        // Try to parse as success then error.
        if let Ok(r) = serde_json::from_str::<RpcResponse>(&text) {
            if r.id == id {
                return Ok(r.result);
            }
        if let Ok(r) = serde_json::from_str::<RpcResponse>(&text)
            && r.id == id
        {
            return Ok(r.result);
        }
        let err_resp: RpcErrorResponse =
            serde_json::from_str(&text).map_err(Error::Serialization)?;
@@ -249,10 +281,10 @@ impl RemoteMcpClient {
        };

        // Try to parse as success then error.
        if let Ok(r) = serde_json::from_str::<RpcResponse>(&response_text) {
            if r.id == id {
                return Ok(r.result);
            }
        if let Ok(r) = serde_json::from_str::<RpcResponse>(&response_text)
            && r.id == id
        {
            return Ok(r.result);
        }
        let err_resp: RpcErrorResponse =
            serde_json::from_str(&response_text).map_err(Error::Serialization)?;
@@ -416,7 +448,9 @@ impl McpClient for RemoteMcpClient {
        // Auto‑grant consent for the web_search tool (permanent for this process).
        let consent_manager = std::sync::Arc::new(std::sync::Mutex::new(ConsentManager::new()));
        {
            let mut cm = consent_manager.lock().unwrap();
            let mut cm = consent_manager
                .lock()
                .map_err(|_| Error::Provider(anyhow!("Consent manager mutex poisoned")))?;
            cm.grant_consent_with_scope(
                "web_search",
                Vec::new(),
@@ -459,17 +493,22 @@ impl McpClient for RemoteMcpClient {
        let response: McpToolResponse = serde_json::from_value(result)?;
        Ok(response)
    }

    async fn set_mode(&self, _mode: Mode) -> Result<()> {
        // Remote servers manage their own mode settings; treat as best-effort no-op.
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// Provider implementation – forwards chat requests to the generate_text tool.
// ---------------------------------------------------------------------------

impl LLMProvider for RemoteMcpClient {
impl LlmProvider for RemoteMcpClient {
    type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
    type ListModelsFuture<'a> = BoxFuture<'a, Result<Vec<ModelInfo>>>;
    type ChatFuture<'a> = BoxFuture<'a, Result<ChatResponse>>;
    type ChatStreamFuture<'a> = BoxFuture<'a, Result<Self::Stream>>;
    type SendPromptFuture<'a> = BoxFuture<'a, Result<ChatResponse>>;
    type StreamPromptFuture<'a> = BoxFuture<'a, Result<Self::Stream>>;
    type HealthCheckFuture<'a> = BoxFuture<'a, Result<()>>;

    fn name(&self) -> &str {
@@ -484,11 +523,11 @@ impl LLMProvider for RemoteMcpClient {
        })
    }

    fn chat(&self, request: crate::types::ChatRequest) -> Self::ChatFuture<'_> {
        Box::pin(chat_via_stream(self, request))
    fn send_prompt(&self, request: crate::types::ChatRequest) -> Self::SendPromptFuture<'_> {
        Box::pin(send_via_stream(self, request))
    }

    fn chat_stream(&self, request: crate::types::ChatRequest) -> Self::ChatStreamFuture<'_> {
    fn stream_prompt(&self, request: crate::types::ChatRequest) -> Self::StreamPromptFuture<'_> {
        Box::pin(async move {
            let args = serde_json::json!({
                "messages": request.messages,
@@ -528,3 +567,27 @@ impl LLMProvider for RemoteMcpClient {
        })
    }
}

#[async_trait::async_trait]
impl LlmClient for RemoteMcpClient {
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        <Self as LlmProvider>::list_models(self).await
    }

    async fn send_chat(&self, request: crate::types::ChatRequest) -> Result<ChatResponse> {
        <Self as LlmProvider>::send_prompt(self, request).await
    }

    async fn stream_chat(&self, request: crate::types::ChatRequest) -> Result<ChatStream> {
        let stream = <Self as LlmProvider>::stream_prompt(self, request).await?;
        Ok(Box::pin(stream))
    }

    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        <Self as McpClient>::list_tools(self).await
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        <Self as McpClient>::call_tool(self, call).await
    }
}

@@ -1,5 +1,10 @@
use crate::types::ModelInfo;
pub mod details;

pub use details::{DetailedModelInfo, ModelInfoRetrievalError};

use crate::Result;
use crate::types::ModelInfo;
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -37,10 +42,8 @@ impl ModelManager {
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<Vec<ModelInfo>>>,
    {
        if !force_refresh {
            if let Some(models) = self.cached_if_fresh().await {
                return Ok(models);
            }
        if let (false, Some(models)) = (force_refresh, self.cached_if_fresh().await) {
            return Ok(models);
        }

        let models = fetcher().await?;
@@ -82,3 +85,125 @@ impl ModelManager {
        }
    }
}

#[derive(Default, Debug)]
struct ModelDetailsCacheInner {
    by_key: HashMap<String, DetailedModelInfo>,
    name_to_key: HashMap<String, String>,
    fetched_at: HashMap<String, Instant>,
}

/// Cache for rich model details, indexed by digest when available.
#[derive(Clone, Debug)]
pub struct ModelDetailsCache {
    inner: Arc<RwLock<ModelDetailsCacheInner>>,
    ttl: Duration,
}

impl ModelDetailsCache {
    /// Create a new details cache with the provided TTL.
    pub fn new(ttl: Duration) -> Self {
        Self {
            inner: Arc::new(RwLock::new(ModelDetailsCacheInner::default())),
            ttl,
        }
    }

    /// Try to read cached details for the provided model name.
    pub async fn get(&self, name: &str) -> Option<DetailedModelInfo> {
        let mut inner = self.inner.write().await;
        let key = inner.name_to_key.get(name).cloned()?;
        let stale = inner
            .fetched_at
            .get(&key)
            .is_some_and(|ts| ts.elapsed() >= self.ttl);
        if stale {
            inner.by_key.remove(&key);
            inner.name_to_key.remove(name);
            inner.fetched_at.remove(&key);
            return None;
        }
        inner.by_key.get(&key).cloned()
    }

    /// Cache the provided details, overwriting existing entries.
    pub async fn insert(&self, info: DetailedModelInfo) {
        let key = info.digest.clone().unwrap_or_else(|| info.name.clone());
        let mut inner = self.inner.write().await;

        // Remove prior mappings for this model name (possibly different digest).
        if let Some(previous_key) = inner.name_to_key.get(&info.name).cloned()
            && previous_key != key
        {
            inner.by_key.remove(&previous_key);
            inner.fetched_at.remove(&previous_key);
        }

        inner.fetched_at.insert(key.clone(), Instant::now());
        inner.name_to_key.insert(info.name.clone(), key.clone());
        inner.by_key.insert(key, info);
    }

    /// Remove a specific model from the cache.
    pub async fn invalidate(&self, name: &str) {
        let mut inner = self.inner.write().await;
        if let Some(key) = inner.name_to_key.remove(name) {
            inner.by_key.remove(&key);
            inner.fetched_at.remove(&key);
        }
    }

    /// Clear the entire cache.
    pub async fn invalidate_all(&self) {
        let mut inner = self.inner.write().await;
        inner.by_key.clear();
        inner.name_to_key.clear();
        inner.fetched_at.clear();
    }

    /// Return all cached values regardless of freshness.
    pub async fn cached(&self) -> Vec<DetailedModelInfo> {
        let inner = self.inner.read().await;
        inner.by_key.values().cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    fn sample_details(name: &str) -> DetailedModelInfo {
        DetailedModelInfo {
            name: name.to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn model_details_cache_returns_cached_entry() {
        let cache = ModelDetailsCache::new(Duration::from_millis(50));
        let info = sample_details("llama");
        cache.insert(info.clone()).await;
        let cached = cache.get("llama").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().name, "llama");
    }

    #[tokio::test]
    async fn model_details_cache_expires_based_on_ttl() {
        let cache = ModelDetailsCache::new(Duration::from_millis(10));
        cache.insert(sample_details("phi")).await;
        sleep(Duration::from_millis(30)).await;
        assert!(cache.get("phi").await.is_none());
    }

    #[tokio::test]
    async fn model_details_cache_invalidate_removes_entry() {
        let cache = ModelDetailsCache::new(Duration::from_secs(1));
        cache.insert(sample_details("mistral")).await;
        cache.invalidate("mistral").await;
        assert!(cache.get("mistral").await.is_none());
    }
}

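The cache above is intended to sit in front of a slower provider lookup: serve fresh entries from memory, evict stale ones inside `get`, and repopulate on a miss. A hedged get-or-fetch sketch, assuming the crate is importable as `owlen_core`; `fetch_details` is a hypothetical stand-in for a real provider call and is stubbed here only to keep the example self-contained.

    use std::time::Duration;

    use owlen_core::model::{DetailedModelInfo, ModelDetailsCache};

    // Hypothetical fetch; a real implementation would query the provider (e.g. Ollama).
    async fn fetch_details(name: &str) -> owlen_core::Result<DetailedModelInfo> {
        Ok(DetailedModelInfo {
            name: name.to_string(),
            ..Default::default()
        })
    }

    async fn details_with_cache(
        cache: &ModelDetailsCache, // e.g. ModelDetailsCache::new(Duration::from_secs(300))
        name: &str,
    ) -> owlen_core::Result<DetailedModelInfo> {
        // Fresh entries come straight from memory; stale ones are evicted inside `get`.
        if let Some(hit) = cache.get(name).await {
            return Ok(hit);
        }
        let info = fetch_details(name).await?.with_normalised_strings();
        cache.insert(info.clone()).await;
        Ok(info)
    }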
105 crates/owlen-core/src/model/details.rs Normal file
@@ -0,0 +1,105 @@
//! Detailed model metadata for provider inspection features.
//!
//! These types capture richer information about locally available models
//! than the lightweight [`crate::types::ModelInfo`] listing and back the
//! higher-level inspection UI exposed in the Owlen TUI.

use serde::{Deserialize, Serialize};

/// Rich metadata about an Ollama model.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DetailedModelInfo {
    /// Canonical model name (including tag).
    pub name: String,
    /// Reported architecture or model format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<String>,
    /// Human-readable parameter / quantisation summary.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<String>,
    /// Context window length, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Embedding vector length for embedding-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding_length: Option<u64>,
    /// Quantisation level (e.g., Q4_0, Q5_K_M).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
    /// Primary family identifier (e.g., llama3).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family: Option<String>,
    /// Additional family tags reported by Ollama.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub families: Vec<String>,
    /// Verbose parameter size description (e.g., 70B parameters).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameter_size: Option<String>,
    /// Default prompt template packaged with the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
    /// Default system prompt packaged with the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// License string provided by the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
    /// Raw modelfile contents (if available).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modelfile: Option<String>,
    /// Modification timestamp (ISO-8601) if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_at: Option<String>,
    /// Approximate model size in bytes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<u64>,
    /// Digest / checksum used by Ollama (sha256).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub digest: Option<String>,
}

impl DetailedModelInfo {
    /// Convenience helper that normalises empty strings to `None`.
    pub fn with_normalised_strings(mut self) -> Self {
        if self.architecture.as_ref().is_some_and(String::is_empty) {
            self.architecture = None;
        }
        if self.parameters.as_ref().is_some_and(String::is_empty) {
            self.parameters = None;
        }
        if self.quantization.as_ref().is_some_and(String::is_empty) {
            self.quantization = None;
        }
        if self.family.as_ref().is_some_and(String::is_empty) {
            self.family = None;
        }
        if self.parameter_size.as_ref().is_some_and(String::is_empty) {
            self.parameter_size = None;
        }
        if self.template.as_ref().is_some_and(String::is_empty) {
            self.template = None;
        }
        if self.system.as_ref().is_some_and(String::is_empty) {
            self.system = None;
        }
        if self.license.as_ref().is_some_and(String::is_empty) {
            self.license = None;
        }
        if self.modelfile.as_ref().is_some_and(String::is_empty) {
            self.modelfile = None;
        }
        if self.digest.as_ref().is_some_and(String::is_empty) {
            self.digest = None;
        }
        self
    }
}

/// Error payload returned when model inspection fails for a specific model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfoRetrievalError {
    /// Model that failed to resolve.
    pub model_name: String,
    /// Human-readable description of the failure.
    pub error_message: String,
}
507
crates/owlen-core/src/oauth.rs
Normal file
507
crates/owlen-core/src/oauth.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
use std::time::Duration as StdDuration;
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Error, Result, config::McpOAuthConfig};
|
||||
|
||||
/// Persisted OAuth token set for MCP servers and providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct OAuthToken {
|
||||
/// Bearer access token returned by the authorization server.
|
||||
pub access_token: String,
|
||||
/// Optional refresh token if the provider issues one.
|
||||
#[serde(default)]
|
||||
pub refresh_token: Option<String>,
|
||||
/// Absolute UTC expiration timestamp for the access token.
|
||||
#[serde(default)]
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
/// Optional space-delimited scope string supplied by the provider.
|
||||
#[serde(default)]
|
||||
pub scope: Option<String>,
|
||||
/// Token type reported by the provider (typically `Bearer`).
|
||||
#[serde(default)]
|
||||
pub token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl OAuthToken {
|
||||
/// Returns `true` if the access token has expired at the provided instant.
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if now >= expiry)
|
||||
}
|
||||
|
||||
/// Returns `true` if the token will expire within the supplied duration window.
|
||||
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
|
||||
}
|
||||
}
|
||||
|
||||
/// Active device-authorization session details returned by the authorization server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeviceAuthorization {
|
||||
pub device_code: String,
|
||||
pub user_code: String,
|
||||
pub verification_uri: String,
|
||||
pub verification_uri_complete: Option<String>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub interval: StdDuration,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl DeviceAuthorization {
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
now >= self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of polling the token endpoint during a device-authorization flow.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DevicePollState {
|
||||
Pending { retry_in: StdDuration },
|
||||
Complete(OAuthToken),
|
||||
}
|
||||
|
||||
pub struct OAuthClient {
|
||||
http: Client,
|
||||
config: McpOAuthConfig,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub fn new(config: McpOAuthConfig) -> Result<Self> {
|
||||
let http = Client::builder()
|
||||
.user_agent("OwlenOAuth/1.0")
|
||||
.build()
|
||||
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
|
||||
Ok(Self { http, config })
|
||||
}
|
||||
|
||||
fn scope_value(&self) -> Option<String> {
|
||||
if self.config.scopes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.config.scopes.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
fn token_request_base(&self) -> Vec<(String, String)> {
|
||||
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
|
||||
if let Some(secret) = &self.config.client_secret {
|
||||
params.push(("client_secret".to_string(), secret.clone()));
|
||||
}
|
||||
params
|
||||
}
|
||||
|
||||
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
|
||||
let device_url = self
|
||||
.config
|
||||
.device_authorization_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
Error::Config("Device authorization endpoint is not configured.".to_string())
|
||||
})?;
|
||||
|
||||
let mut params = self.token_request_base();
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(device_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("start device authorization", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let payload = response
|
||||
.json::<DeviceAuthorizationResponse>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse device authorization response (status {status}): {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let expires_at =
|
||||
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
|
||||
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
|
||||
|
||||
Ok(DeviceAuthorization {
|
||||
device_code: payload.device_code,
|
||||
user_code: payload.user_code,
|
||||
verification_uri: payload.verification_uri,
|
||||
verification_uri_complete: payload.verification_uri_complete,
|
||||
expires_at,
|
||||
interval,
|
||||
message: payload.message,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
|
||||
params.push(("device_code".to_string(), auth.device_code.clone()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("poll device token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read token response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth token response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
return Ok(DevicePollState::Complete(oauth_token_from_response(
|
||||
payload,
|
||||
)));
|
||||
}
|
||||
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
|
||||
match error.error.as_str() {
|
||||
"authorization_pending" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval,
|
||||
}),
|
||||
"slow_down" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
|
||||
}),
|
||||
"access_denied" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"User declined authorization".to_string()
|
||||
})))
|
||||
}
|
||||
"expired_token" | "expired_device_code" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"Device authorization expired".to_string()
|
||||
})))
|
||||
}
|
||||
other => Err(Error::Auth(
|
||||
error
|
||||
.error_description
|
||||
.unwrap_or_else(|| format!("OAuth error: {other}")),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), "refresh_token".to_string()));
|
||||
params.push(("refresh_token".to_string(), refresh_token.to_string()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(&params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("refresh OAuth token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read refresh response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth refresh response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
Ok(oauth_token_from_response(payload))
|
||||
} else {
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
format!("OAuth token refresh failed: {}", error.error)
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
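Taken together, `start_device_authorization` and `poll_device_token` implement the RFC 8628 device flow: fetch a user code, show it, then poll the token endpoint, backing off when the server answers `slow_down`. A minimal caller-side sketch (not part of this diff) that assumes a Tokio runtime and the `Error::Auth(String)` variant used elsewhere in this file:

```rust
use tokio::time::sleep;

// Hypothetical driver loop for the device flow: show the code to the user,
// then poll until the token arrives, the user declines, or the code expires.
async fn obtain_token(client: &OAuthClient) -> Result<OAuthToken> {
    let auth = client.start_device_authorization().await?;
    println!("Visit {} and enter code {}", auth.verification_uri, auth.user_code);

    loop {
        if auth.is_expired(Utc::now()) {
            return Err(Error::Auth("Device authorization expired".to_string()));
        }
        match client.poll_device_token(&auth).await? {
            DevicePollState::Complete(token) => return Ok(token),
            // `retry_in` already reflects the slow_down backoff applied above.
            DevicePollState::Pending { retry_in } => sleep(retry_in).await,
        }
    }
}
```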
|
||||
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceAuthorizationResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default)]
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
#[serde(default)]
|
||||
interval: Option<u64>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
scope: Option<String>,
|
||||
#[serde(default)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthErrorResponse {
|
||||
error: String,
|
||||
#[serde(default)]
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
|
||||
let expires_at = payload
|
||||
.expires_in
|
||||
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
|
||||
.map(|seconds| Utc::now() + Duration::seconds(seconds));
|
||||
|
||||
OAuthToken {
|
||||
access_token: payload.access_token,
|
||||
refresh_token: payload.refresh_token,
|
||||
expires_at,
|
||||
scope: payload.scope,
|
||||
token_type: payload.token_type,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
|
||||
if err.is_timeout() {
|
||||
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
|
||||
} else if err.is_connect() {
|
||||
Error::Network(format!("OAuth {action} connection error: {err}"))
|
||||
} else {
|
||||
Error::Network(format!("OAuth {action} request failed: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn config_for(server: &MockServer) -> McpOAuthConfig {
|
||||
McpOAuthConfig {
|
||||
client_id: "test-client".to_string(),
|
||||
client_secret: None,
|
||||
authorize_url: server.url("/authorize"),
|
||||
token_url: server.url("/token"),
|
||||
device_authorization_url: Some(server.url("/device")),
|
||||
redirect_url: None,
|
||||
scopes: vec!["repo".to_string(), "user".to_string()],
|
||||
token_env: None,
|
||||
header: None,
|
||||
header_prefix: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_device_authorization() -> DeviceAuthorization {
|
||||
DeviceAuthorization {
|
||||
device_code: "device-123".to_string(),
|
||||
user_code: "ABCD-EFGH".to_string(),
|
||||
verification_uri: "https://example.test/activate".to_string(),
|
||||
verification_uri_complete: Some(
|
||||
"https://example.test/activate?user_code=ABCD-EFGH".to_string(),
|
||||
),
|
||||
expires_at: Utc::now() + Duration::minutes(10),
|
||||
interval: StdDuration::from_secs(5),
|
||||
message: Some("Open the verification URL and enter the code.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_device_authorization_returns_payload() {
|
||||
let server = MockServer::start_async().await;
|
||||
let device_mock = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-123",
|
||||
"user_code": "ABCD-EFGH",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
|
||||
"expires_in": 600,
|
||||
"interval": 7,
|
||||
"message": "Open the verification URL and enter the code."
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = OAuthClient::new(config_for(&server)).expect("client");
|
||||
let auth = client
|
||||
.start_device_authorization()
|
||||
.await
|
||||
.expect("device authorization payload");
|
||||
|
||||
assert_eq!(auth.user_code, "ABCD-EFGH");
|
||||
assert_eq!(auth.interval, StdDuration::from_secs(7));
|
||||
assert!(auth.expires_at > Utc::now());
|
||||
device_mock.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_reports_pending() {
|
||||
let server = MockServer::start_async().await;
|
||||
let pending = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains(
|
||||
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
|
||||
)
|
||||
.body_contains("device_code=device-123");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "authorization_pending"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(5));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
pending.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_applies_slow_down_backoff() {
|
||||
let server = MockServer::start_async().await;
|
||||
let slow = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "slow_down"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(10));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
slow.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_returns_token_when_authorized() {
|
||||
let server = MockServer::start_async().await;
|
||||
let token = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-abc",
|
||||
"refresh_token": "refresh-xyz",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "repo user"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
let token_info = match result {
|
||||
DevicePollState::Complete(token) => token,
|
||||
other => panic!("expected completion, got {other:?}"),
|
||||
};
|
||||
|
||||
assert_eq!(token_info.access_token, "token-abc");
|
||||
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
|
||||
assert!(token_info.expires_at.is_some());
|
||||
token.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_token_roundtrip() {
|
||||
let server = MockServer::start_async().await;
|
||||
let refresh = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains("grant_type=refresh_token")
|
||||
.body_contains("refresh_token=old-refresh");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-new",
|
||||
"refresh_token": "refresh-new",
|
||||
"expires_in": 1200,
|
||||
"token_type": "Bearer"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let token = client
|
||||
.refresh_token("old-refresh")
|
||||
.await
|
||||
.expect("refresh response");
|
||||
|
||||
assert_eq!(token.access_token, "token-new");
|
||||
assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
|
||||
assert!(token.expires_at.is_some());
|
||||
refresh.assert_async().await;
|
||||
}
|
||||
}
|
||||
@@ -1,369 +0,0 @@
|
||||
//! Provider traits and registries.
|
||||
|
||||
use crate::{types::*, Error, Result};
|
||||
use anyhow::anyhow;
|
||||
use futures::{Stream, StreamExt};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A stream of chat responses
|
||||
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
|
||||
|
||||
/// Trait for LLM providers (Ollama, OpenAI, Anthropic, etc.) with zero-cost static dispatch.
|
||||
pub trait LLMProvider: Send + Sync + 'static {
|
||||
type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;
|
||||
|
||||
type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
type ChatFuture<'a>: Future<Output = Result<ChatResponse>> + Send
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
type ChatStreamFuture<'a>: Future<Output = Result<Self::Stream>> + Send
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
fn name(&self) -> &str;
|
||||
|
||||
fn list_models(&self) -> Self::ListModelsFuture<'_>;
|
||||
fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_>;
|
||||
fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_>;
|
||||
fn health_check(&self) -> Self::HealthCheckFuture<'_>;
|
||||
|
||||
fn config_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that implements [`LLMProvider::chat`] in terms of [`LLMProvider::chat_stream`].
|
||||
pub async fn chat_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
|
||||
where
|
||||
P: LLMProvider + 'a,
|
||||
{
|
||||
let stream = provider.chat_stream(request).await?;
|
||||
let mut boxed: ChatStream = Box::pin(stream);
|
||||
match boxed.next().await {
|
||||
Some(Ok(response)) => Ok(response),
|
||||
Some(Err(err)) => Err(err),
|
||||
None => Err(Error::Provider(anyhow!(
|
||||
"Empty chat stream from provider {}",
|
||||
provider.name()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Object-safe wrapper trait for runtime-configurable provider usage.
|
||||
#[async_trait::async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Get the name of this provider.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// List available models from this provider.
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Send a chat completion request.
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
|
||||
|
||||
/// Send a streaming chat completion request.
|
||||
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream>;
|
||||
|
||||
/// Check if the provider is available/healthy.
|
||||
async fn health_check(&self) -> Result<()>;
|
||||
|
||||
/// Get provider-specific configuration schema.
|
||||
fn config_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<T> Provider for T
|
||||
where
|
||||
T: LLMProvider,
|
||||
{
|
||||
fn name(&self) -> &str {
|
||||
LLMProvider::name(self)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
LLMProvider::list_models(self).await
|
||||
}
|
||||
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
||||
LLMProvider::chat(self, request).await
|
||||
}
|
||||
|
||||
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
||||
let stream = LLMProvider::chat_stream(self, request).await?;
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<()> {
|
||||
LLMProvider::health_check(self).await
|
||||
}
|
||||
|
||||
fn config_schema(&self) -> serde_json::Value {
|
||||
LLMProvider::config_schema(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a provider
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ProviderConfig {
|
||||
/// Provider type identifier
|
||||
pub provider_type: String,
|
||||
/// Base URL for API calls
|
||||
pub base_url: Option<String>,
|
||||
/// API key or token
|
||||
pub api_key: Option<String>,
|
||||
/// Additional provider-specific configuration
|
||||
#[serde(flatten)]
|
||||
pub extra: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A registry of providers
|
||||
pub struct ProviderRegistry {
|
||||
providers: std::collections::HashMap<String, Arc<dyn Provider>>,
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
/// Create a new provider registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a provider using static dispatch.
|
||||
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
|
||||
self.register_arc(Arc::new(provider));
|
||||
}
|
||||
|
||||
/// Register an already wrapped provider
|
||||
pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
|
||||
let name = provider.name().to_string();
|
||||
self.providers.insert(name, provider);
|
||||
}
|
||||
|
||||
/// Get a provider by name
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
|
||||
/// List all registered provider names
|
||||
pub fn list_providers(&self) -> Vec<String> {
|
||||
self.providers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get all models from all providers
|
||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let mut all_models = Vec::new();
|
||||
|
||||
for provider in self.providers.values() {
|
||||
match provider.list_models().await {
|
||||
Ok(mut models) => all_models.append(&mut models),
|
||||
Err(_) => {
|
||||
// Continue with other providers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_models)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProviderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use super::*;
|
||||
use crate::types::{ChatRequest, ChatResponse, Message, ModelInfo, Role};
|
||||
use futures::stream;
|
||||
use std::future::{ready, Ready};
|
||||
|
||||
/// Mock provider for testing
|
||||
#[derive(Default)]
|
||||
pub struct MockProvider;
|
||||
|
||||
impl LLMProvider for MockProvider {
|
||||
type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
|
||||
type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
|
||||
type ChatFuture<'a> = Ready<Result<ChatResponse>>;
|
||||
type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
|
||||
type HealthCheckFuture<'a> = Ready<Result<()>>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
||||
ready(Ok(vec![ModelInfo {
|
||||
id: "mock-model".to_string(),
|
||||
provider: "mock".to_string(),
|
||||
name: "mock-model".to_string(),
|
||||
description: None,
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}]))
|
||||
}
|
||||
|
||||
fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
|
||||
ready(Ok(self.build_response(&request)))
|
||||
}
|
||||
|
||||
fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
|
||||
let response = self.build_response(&request);
|
||||
ready(Ok(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
||||
ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
fn build_response(&self, request: &ChatRequest) -> ChatResponse {
|
||||
let content = format!(
|
||||
"Mock response to: {}",
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.map(|m| m.content.clone())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::new(Role::Assistant, content),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::test_utils::MockProvider;
|
||||
use super::*;
|
||||
use crate::types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role};
|
||||
use futures::stream;
|
||||
use std::future::{ready, Ready};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct StreamingProvider;
|
||||
|
||||
impl LLMProvider for StreamingProvider {
|
||||
type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
|
||||
type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
|
||||
type ChatFuture<'a> = Ready<Result<ChatResponse>>;
|
||||
type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
|
||||
type HealthCheckFuture<'a> = Ready<Result<()>>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"streaming"
|
||||
}
|
||||
|
||||
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
||||
ready(Ok(vec![ModelInfo {
|
||||
id: "stream-model".to_string(),
|
||||
provider: "streaming".to_string(),
|
||||
name: "stream-model".to_string(),
|
||||
description: None,
|
||||
context_window: None,
|
||||
capabilities: vec!["chat".to_string()],
|
||||
supports_tools: false,
|
||||
}]))
|
||||
}
|
||||
|
||||
fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
|
||||
ready(Ok(self.response(&request)))
|
||||
}
|
||||
|
||||
fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
|
||||
let response = self.response(&request);
|
||||
ready(Ok(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
||||
ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingProvider {
|
||||
fn response(&self, request: &ChatRequest) -> ChatResponse {
|
||||
let reply = format!(
|
||||
"echo:{}",
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.map(|m| m.content.clone())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
ChatResponse {
|
||||
message: Message::new(Role::Assistant, reply),
|
||||
usage: None,
|
||||
is_streaming: true,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_chat_reads_from_stream() {
|
||||
let provider = StreamingProvider;
|
||||
let request = ChatRequest {
|
||||
model: "stream-model".to_string(),
|
||||
messages: vec![Message::new(Role::User, "ping".to_string())],
|
||||
parameters: ChatParameters::default(),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = LLMProvider::chat(&provider, request)
|
||||
.await
|
||||
.expect("chat succeeded");
|
||||
assert_eq!(response.message.content, "echo:ping");
|
||||
assert!(response.is_final);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registry_registers_static_provider() {
|
||||
let mut registry = ProviderRegistry::new();
|
||||
registry.register(StreamingProvider);
|
||||
|
||||
let provider = registry.get("streaming").expect("provider registered");
|
||||
let models = provider.list_models().await.expect("models listed");
|
||||
assert_eq!(models[0].id, "stream-model");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registry_accepts_dynamic_provider() {
|
||||
let mut registry = ProviderRegistry::new();
|
||||
let provider: Arc<dyn Provider> = Arc::new(MockProvider::default());
|
||||
registry.register_arc(provider.clone());
|
||||
|
||||
let fetched = registry.get("mock").expect("mock provider present");
|
||||
let request = ChatRequest {
|
||||
model: "mock-model".to_string(),
|
||||
messages: vec![Message::new(Role::User, "hi".to_string())],
|
||||
parameters: ChatParameters::default(),
|
||||
tools: None,
|
||||
};
|
||||
let response = Provider::chat(fetched.as_ref(), request)
|
||||
.await
|
||||
.expect("chat succeeded");
|
||||
assert_eq!(response.message.content, "Mock response to: hi");
|
||||
}
|
||||
}
|
||||
227
crates/owlen-core/src/provider/manager.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use log::{debug, warn};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
|
||||
|
||||
/// Model information annotated with the originating provider metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnnotatedModelInfo {
|
||||
pub provider_id: String,
|
||||
pub provider_status: ProviderStatus,
|
||||
pub model: ModelInfo,
|
||||
}
|
||||
|
||||
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
|
||||
/// health state.
|
||||
pub struct ProviderManager {
|
||||
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
|
||||
status_cache: RwLock<HashMap<String, ProviderStatus>>,
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
/// Construct a new manager using the supplied configuration. Providers
|
||||
/// defined in the configuration start with a `RequiresSetup` status so
|
||||
/// that frontends can surface incomplete configuration to users.
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let mut status_cache = HashMap::new();
|
||||
for provider_id in config.providers.keys() {
|
||||
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
|
||||
}
|
||||
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(status_cache),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a provider instance with the manager.
|
||||
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
|
||||
let provider_id = provider.metadata().id.clone();
|
||||
debug!("registering provider {}", provider_id);
|
||||
|
||||
self.providers
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.clone(), provider);
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id, ProviderStatus::Unavailable);
|
||||
}
|
||||
|
||||
/// Return a stream by routing the request to the designated provider.
|
||||
pub async fn generate(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStream> {
|
||||
let provider = {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
.ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;
|
||||
|
||||
match provider.generate_stream(request).await {
|
||||
Ok(stream) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Available);
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Unavailable);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// List models across all providers, updating provider status along the way.
|
||||
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let log_id = provider_id.clone();
|
||||
let mut status = ProviderStatus::Unavailable;
|
||||
let mut models = Vec::new();
|
||||
|
||||
match provider.health_check().await {
|
||||
Ok(health) => {
|
||||
status = health;
|
||||
if matches!(status, ProviderStatus::Available) {
|
||||
match provider.list_models().await {
|
||||
Ok(list) => {
|
||||
models = list;
|
||||
}
|
||||
Err(err) => {
|
||||
status = ProviderStatus::Unavailable;
|
||||
warn!("listing models failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
|
||||
(provider_id, status, models)
|
||||
});
|
||||
}
|
||||
|
||||
let mut annotated = Vec::new();
|
||||
let mut status_updates = HashMap::new();
|
||||
|
||||
while let Some((provider_id, status, models)) = tasks.next().await {
|
||||
status_updates.insert(provider_id.clone(), status);
|
||||
for model in models {
|
||||
annotated.push(AnnotatedModelInfo {
|
||||
provider_id: provider_id.clone(),
|
||||
provider_status: status,
|
||||
model,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in status_updates {
|
||||
guard.insert(provider_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(annotated)
|
||||
}
|
||||
|
||||
/// Refresh the health of all registered providers in parallel, returning
|
||||
/// the latest status snapshot.
|
||||
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let status = match provider.health_check().await {
|
||||
Ok(status) => status,
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", provider_id, err);
|
||||
ProviderStatus::Unavailable
|
||||
}
|
||||
};
|
||||
(provider_id, status)
|
||||
});
|
||||
}
|
||||
|
||||
let mut updates = HashMap::new();
|
||||
while let Some((provider_id, status)) = tasks.next().await {
|
||||
updates.insert(provider_id, status);
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in &updates {
|
||||
guard.insert(provider_id.clone(), *status);
|
||||
}
|
||||
}
|
||||
|
||||
updates
|
||||
}
|
||||
|
||||
/// Return the provider instance for an identifier.
|
||||
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
|
||||
/// List the registered provider identifiers.
|
||||
pub async fn provider_ids(&self) -> Vec<String> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Retrieve the last known status for a provider.
|
||||
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.get(provider_id).copied()
|
||||
}
|
||||
|
||||
/// Snapshot the currently cached statuses.
|
||||
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
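As a usage sketch (not part of this diff), a frontend constructs the manager from the loaded `Config`, registers concrete backends, and then relies on `refresh_health` / `list_all_models` to keep the status cache current. `OllamaProvider` below is hypothetical; any `ModelProvider` implementation is registered the same way:

```rust
use std::sync::Arc;

// Illustrative bootstrap; assumes `OllamaProvider: ModelProvider` exists elsewhere.
async fn bootstrap(config: &Config) -> Result<()> {
    let manager = ProviderManager::new(config);
    manager.register_provider(Arc::new(OllamaProvider::new())).await;

    // Health checks run in parallel; the snapshot maps provider id -> status.
    for (id, status) in manager.refresh_health().await {
        log::debug!("provider {id}: {status:?}");
    }

    // Models come back annotated with the provider that supplied them.
    for info in manager.list_all_models().await? {
        log::debug!("{} via {}", info.model.name, info.provider_id);
    }
    Ok(())
}
```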
36
crates/owlen-core/src/provider/mod.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
//! Unified provider abstraction layer.
|
||||
//!
|
||||
//! This module defines the async [`ModelProvider`] trait that all model
|
||||
//! backends implement, together with a small suite of shared data structures
|
||||
//! used for model discovery and streaming generation. The [`ProviderManager`]
|
||||
//! orchestrates multiple providers and coordinates their health state.
|
||||
|
||||
mod manager;
|
||||
mod types;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
pub use self::{manager::*, types::*};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
/// Convenience alias for the stream type yielded by [`ModelProvider::generate_stream`].
|
||||
pub type GenerateStream = Pin<Box<dyn Stream<Item = Result<GenerateChunk>> + Send + 'static>>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ModelProvider: Send + Sync {
|
||||
/// Returns descriptive metadata about the provider.
|
||||
fn metadata(&self) -> &ProviderMetadata;
|
||||
|
||||
/// Check the current health state for the provider.
|
||||
async fn health_check(&self) -> Result<ProviderStatus>;
|
||||
|
||||
/// List all models available through the provider.
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Acquire a streaming response for a generation request.
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream>;
|
||||
}
|
||||
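A minimal backend only has to supply metadata, a health probe, model discovery, and a chunk stream. The sketch below is not part of this diff and assumes this module's imports; `EchoProvider` is hypothetical and simply echoes the prompt back as one chunk followed by the terminal marker:

```rust
struct EchoProvider {
    meta: ProviderMetadata,
}

#[async_trait]
impl ModelProvider for EchoProvider {
    fn metadata(&self) -> &ProviderMetadata {
        &self.meta
    }

    async fn health_check(&self) -> Result<ProviderStatus> {
        Ok(ProviderStatus::Available)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(Vec::new())
    }

    async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream> {
        let text = request.prompt.unwrap_or_default();
        let chunks = vec![
            Ok(GenerateChunk::from_text(text)),
            Ok(GenerateChunk::final_chunk()),
        ];
        Ok(Box::pin(futures::stream::iter(chunks)))
    }
}
```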
124
crates/owlen-core/src/provider/types.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
//! Shared types used by the unified provider abstraction layer.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Categorises providers so the UI can distinguish between local and hosted
|
||||
/// backends.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ProviderType {
|
||||
Local,
|
||||
Cloud,
|
||||
}
|
||||
|
||||
/// Represents the current availability state for a provider.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ProviderStatus {
|
||||
Available,
|
||||
Unavailable,
|
||||
RequiresSetup,
|
||||
}
|
||||
|
||||
/// Describes core metadata for a provider implementation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct ProviderMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub provider_type: ProviderType,
|
||||
pub requires_auth: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata {
|
||||
/// Construct a new metadata instance for a provider.
|
||||
pub fn new(
|
||||
id: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
requires_auth: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
provider_type,
|
||||
requires_auth,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a model that can be displayed to users.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub size_bytes: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub provider: ProviderMetadata,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Unified request for streaming text generation across providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateRequest {
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
pub context: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub parameters: HashMap<String, Value>,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
/// Helper for building a request from the minimum required fields.
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
prompt: None,
|
||||
context: Vec::new(),
|
||||
parameters: HashMap::new(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streamed chunk of generation output from a model.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateChunk {
|
||||
#[serde(default)]
|
||||
pub text: Option<String>,
|
||||
#[serde(default)]
|
||||
pub is_final: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateChunk {
|
||||
/// Construct a new chunk with the provided text payload.
|
||||
pub fn from_text(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: Some(text.into()),
|
||||
is_final: false,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a chunk that marks the end of a stream.
|
||||
pub fn final_chunk() -> Self {
|
||||
Self {
|
||||
text: None,
|
||||
is_final: true,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
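On the consuming side, frontends drain the stream and stop at the chunk flagged `is_final`. A brief sketch (not part of this diff), assuming `futures::StreamExt` is in scope:

```rust
use futures::StreamExt;

// Accumulate streamed text until the provider emits its terminal chunk.
async fn collect_text(mut stream: GenerateStream) -> Result<String> {
    let mut output = String::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(text) = &chunk.text {
            output.push_str(text);
        }
        if chunk.is_final {
            break;
        }
    }
    Ok(output)
}
```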
File diff suppressed because it is too large
@@ -1,6 +1,7 @@
|
||||
//! Router for managing multiple providers and routing requests
|
||||
|
||||
use crate::{provider::*, types::*, Result};
|
||||
use crate::{Result, llm::*, types::*};
|
||||
use anyhow::anyhow;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A router that can distribute requests across multiple providers
|
||||
@@ -32,7 +33,7 @@ impl Router {
|
||||
}
|
||||
|
||||
/// Register a provider with the router
|
||||
pub fn register_provider<P: LLMProvider + 'static>(&mut self, provider: P) {
|
||||
pub fn register_provider<P: LlmProvider + 'static>(&mut self, provider: P) {
|
||||
self.registry.register(provider);
|
||||
}
|
||||
|
||||
@@ -52,13 +53,13 @@ impl Router {
|
||||
/// Route a request to the appropriate provider
|
||||
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
||||
let provider = self.find_provider_for_model(&request.model)?;
|
||||
provider.chat(request).await
|
||||
provider.send_prompt(request).await
|
||||
}
|
||||
|
||||
/// Route a streaming request to the appropriate provider
|
||||
pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
||||
let provider = self.find_provider_for_model(&request.model)?;
|
||||
provider.chat_stream(request).await
|
||||
provider.stream_prompt(request).await
|
||||
}
|
||||
|
||||
/// List all available models from all providers
|
||||
@@ -70,18 +71,21 @@ impl Router {
|
||||
fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
|
||||
// Check routing rules first
|
||||
for rule in &self.routing_rules {
|
||||
if self.matches_pattern(&rule.model_pattern, model) {
|
||||
if let Some(provider) = self.registry.get(&rule.provider) {
|
||||
return Ok(provider);
|
||||
}
|
||||
if !self.matches_pattern(&rule.model_pattern, model) {
|
||||
continue;
|
||||
}
|
||||
if let Some(provider) = self.registry.get(&rule.provider) {
|
||||
return Ok(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default provider
|
||||
if let Some(default) = &self.default_provider {
|
||||
if let Some(provider) = self.registry.get(default) {
|
||||
return Ok(provider);
|
||||
}
|
||||
if let Some(provider) = self
|
||||
.default_provider
|
||||
.as_ref()
|
||||
.and_then(|default| self.registry.get(default))
|
||||
{
|
||||
return Ok(provider);
|
||||
}
|
||||
|
||||
// If no default, try to find any provider that has this model
|
||||
@@ -92,7 +96,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
Err(crate::Error::Provider(anyhow::anyhow!(
|
||||
Err(crate::Error::Provider(anyhow!(
|
||||
"No provider found for model: {}",
|
||||
model
|
||||
)))
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Configuration options for sandboxed process execution.
|
||||
@@ -185,16 +185,20 @@ impl SandboxedProcess {
|
||||
if let Ok(output) = output {
|
||||
let version_str = String::from_utf8_lossy(&output.stdout);
|
||||
// Parse version like "bubblewrap 0.11.0" or "0.11.0"
|
||||
if let Some(version_part) = version_str.split_whitespace().last() {
|
||||
if let Some((major, rest)) = version_part.split_once('.') {
|
||||
if let Some((minor, _patch)) = rest.split_once('.') {
|
||||
if let (Ok(maj), Ok(min)) = (major.parse::<u32>(), minor.parse::<u32>()) {
|
||||
// --rlimit-as was added in 0.12.0
|
||||
return maj > 0 || (maj == 0 && min >= 12);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return version_str
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.and_then(|part| {
|
||||
part.split_once('.').and_then(|(major, rest)| {
|
||||
rest.split_once('.').and_then(|(minor, _)| {
|
||||
let maj = major.parse::<u32>().ok()?;
|
||||
let min = minor.parse::<u32>().ok()?;
|
||||
Some((maj, min))
|
||||
})
|
||||
})
|
||||
})
|
||||
.map(|(maj, min)| maj > 0 || (maj == 0 && min >= 12))
|
||||
.unwrap_or(false);
|
||||
}
|
||||
|
||||
// If we can't determine the version, assume it doesn't support it (safer default)
|
||||
|
||||
File diff suppressed because it is too large
199
crates/owlen-core/src/state/mod.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
//! Shared application state types used across TUI frontends.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// High-level application state reported by the UI loop.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum AppState {
|
||||
Running,
|
||||
Quit,
|
||||
}
|
||||
|
||||
/// Vim-style input modes supported by the TUI.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum InputMode {
|
||||
Normal,
|
||||
Editing,
|
||||
ProviderSelection,
|
||||
ModelSelection,
|
||||
Help,
|
||||
Visual,
|
||||
Command,
|
||||
SessionBrowser,
|
||||
ThemeBrowser,
|
||||
RepoSearch,
|
||||
SymbolSearch,
|
||||
}
|
||||
|
||||
impl fmt::Display for InputMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let label = match self {
|
||||
InputMode::Normal => "Normal",
|
||||
InputMode::Editing => "Editing",
|
||||
InputMode::ModelSelection => "Model",
|
||||
InputMode::ProviderSelection => "Provider",
|
||||
InputMode::Help => "Help",
|
||||
InputMode::Visual => "Visual",
|
||||
InputMode::Command => "Command",
|
||||
InputMode::SessionBrowser => "Sessions",
|
||||
InputMode::ThemeBrowser => "Themes",
|
||||
InputMode::RepoSearch => "Search",
|
||||
InputMode::SymbolSearch => "Symbols",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents which panel is currently focused in the TUI layout.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum FocusedPanel {
|
||||
Files,
|
||||
Chat,
|
||||
Thinking,
|
||||
Input,
|
||||
Code,
|
||||
}
|
||||
|
||||
/// Auto-scroll state manager for scrollable panels.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoScroll {
|
||||
pub scroll: usize,
|
||||
pub content_len: usize,
|
||||
pub stick_to_bottom: bool,
|
||||
}
|
||||
|
||||
impl Default for AutoScroll {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
scroll: 0,
|
||||
content_len: 0,
|
||||
stick_to_bottom: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoScroll {
|
||||
/// Update scroll position based on viewport height.
|
||||
pub fn on_viewport(&mut self, viewport_h: usize) {
|
||||
let max = self.content_len.saturating_sub(viewport_h);
|
||||
if self.stick_to_bottom {
|
||||
self.scroll = max;
|
||||
} else {
|
||||
self.scroll = self.scroll.min(max);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle user scroll input.
|
||||
pub fn on_user_scroll(&mut self, delta: isize, viewport_h: usize) {
|
||||
let max = self.content_len.saturating_sub(viewport_h) as isize;
|
||||
let s = (self.scroll as isize + delta).clamp(0, max) as usize;
|
||||
self.scroll = s;
|
||||
self.stick_to_bottom = s as isize == max;
|
||||
}
|
||||
|
||||
pub fn scroll_half_page_down(&mut self, viewport_h: usize) {
|
||||
let delta = (viewport_h / 2) as isize;
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
pub fn scroll_half_page_up(&mut self, viewport_h: usize) {
|
||||
let delta = -((viewport_h / 2) as isize);
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
pub fn scroll_full_page_down(&mut self, viewport_h: usize) {
|
||||
let delta = viewport_h as isize;
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
pub fn scroll_full_page_up(&mut self, viewport_h: usize) {
|
||||
let delta = -(viewport_h as isize);
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
pub fn jump_to_top(&mut self) {
|
||||
self.scroll = 0;
|
||||
self.stick_to_bottom = false;
|
||||
}
|
||||
|
||||
pub fn jump_to_bottom(&mut self, viewport_h: usize) {
|
||||
self.stick_to_bottom = true;
|
||||
self.on_viewport(viewport_h);
|
||||
}
|
||||
}
|
||||
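The intended call pattern: update `content_len` as messages arrive, let `on_viewport` re-clamp every frame, and route user input through `on_user_scroll` so `stick_to_bottom` stays consistent. A small illustrative sketch (the 120-line history and 20-row viewport are arbitrary):

```rust
fn demo_auto_scroll() {
    let mut scroll = AutoScroll::default();
    scroll.content_len = 120;       // lines of rendered chat history
    scroll.on_viewport(20);         // 20-row viewport; sticks to the bottom
    assert_eq!(scroll.scroll, 100); // 120 - 20

    scroll.on_user_scroll(-30, 20); // scrolling up detaches from the bottom
    assert_eq!(scroll.scroll, 70);
    assert!(!scroll.stick_to_bottom);

    scroll.jump_to_bottom(20);      // re-attach and snap back down
    assert_eq!(scroll.scroll, 100);
}
```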
|
||||
/// Visual selection state for text selection.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct VisualSelection {
|
||||
pub start: Option<(usize, usize)>,
|
||||
pub end: Option<(usize, usize)>,
|
||||
}
|
||||
|
||||
impl VisualSelection {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn start_at(&mut self, pos: (usize, usize)) {
|
||||
self.start = Some(pos);
|
||||
self.end = Some(pos);
|
||||
}
|
||||
|
||||
pub fn extend_to(&mut self, pos: (usize, usize)) {
|
||||
self.end = Some(pos);
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.start = None;
|
||||
self.end = None;
|
||||
}
|
||||
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.start.is_some() && self.end.is_some()
|
||||
}
|
||||
|
||||
pub fn get_normalized(&self) -> Option<((usize, usize), (usize, usize))> {
|
||||
if let (Some(s), Some(e)) = (self.start, self.end) {
|
||||
if s.0 < e.0 || (s.0 == e.0 && s.1 <= e.1) {
|
||||
Some((s, e))
|
||||
} else {
|
||||
Some((e, s))
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
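`get_normalized` guarantees the returned pair is ordered start-before-end regardless of the direction the selection was dragged, which keeps downstream copy/highlight code simple. A tiny sketch (not part of this diff):

```rust
fn demo_visual_selection() {
    let mut sel = VisualSelection::new();
    sel.start_at((5, 10));
    sel.extend_to((2, 3)); // dragged upwards and to the left
    // Normalization swaps the endpoints so the earlier position comes first.
    assert_eq!(sel.get_normalized(), Some(((2, 3), (5, 10))));
}
```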
|
||||
/// Cursor position helper for navigating scrollable content.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct CursorPosition {
|
||||
pub row: usize,
|
||||
pub col: usize,
|
||||
}
|
||||
|
||||
impl CursorPosition {
|
||||
pub fn new(row: usize, col: usize) -> Self {
|
||||
Self { row, col }
|
||||
}
|
||||
|
||||
pub fn move_up(&mut self, amount: usize) {
|
||||
self.row = self.row.saturating_sub(amount);
|
||||
}
|
||||
|
||||
pub fn move_down(&mut self, amount: usize, max: usize) {
|
||||
self.row = (self.row + amount).min(max);
|
||||
}
|
||||
|
||||
pub fn move_left(&mut self, amount: usize) {
|
||||
self.col = self.col.saturating_sub(amount);
|
||||
}
|
||||
|
||||
pub fn move_right(&mut self, amount: usize, max: usize) {
|
||||
self.col = (self.col + amount).min(max);
|
||||
}
|
||||
|
||||
pub fn as_tuple(&self) -> (usize, usize) {
|
||||
(self.row, self.col)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub type ThemePalette = Theme;
|
||||
|
||||
/// A complete theme definition for OWLEN TUI
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Theme {
|
||||
@@ -34,6 +36,42 @@ pub struct Theme {
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub unfocused_panel_border: Color,
|
||||
|
||||
/// Foreground color for the active pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_focus_beacon_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub focus_beacon_fg: Color,
|
||||
|
||||
/// Background color for the active pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_focus_beacon_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub focus_beacon_bg: Color,
|
||||
|
||||
/// Foreground color for the inactive pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_unfocused_beacon_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub unfocused_beacon_fg: Color,
|
||||
|
||||
/// Title color for active pane headers
|
||||
#[serde(default = "Theme::default_pane_header_active")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_header_active: Color,
|
||||
|
||||
/// Title color for inactive pane headers
|
||||
#[serde(default = "Theme::default_pane_header_inactive")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_header_inactive: Color,
|
||||
|
||||
/// Hint text color used within pane headers
|
||||
#[serde(default = "Theme::default_pane_hint_text")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_hint_text: Color,
|
||||
|
||||
/// Color for user message role indicator
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
@@ -114,6 +152,42 @@ pub struct Theme {
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub cursor: Color,
|
||||
|
||||
/// Code block background color
|
||||
#[serde(default = "Theme::default_code_block_background")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_background: Color,
|
||||
|
||||
/// Code block border color
|
||||
#[serde(default = "Theme::default_code_block_border")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_border: Color,
|
||||
|
||||
/// Code block text color
|
||||
#[serde(default = "Theme::default_code_block_text")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_text: Color,
|
||||
|
||||
/// Code block keyword color
|
||||
#[serde(default = "Theme::default_code_block_keyword")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_keyword: Color,
|
||||
|
||||
/// Code block string literal color
|
||||
#[serde(default = "Theme::default_code_block_string")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_string: Color,
|
||||
|
||||
/// Code block comment color
|
||||
#[serde(default = "Theme::default_code_block_comment")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub code_block_comment: Color,
|
||||
|
||||
/// Placeholder text color
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
@@ -128,6 +202,84 @@ pub struct Theme {
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub info: Color,
|
||||
|
||||
/// Agent action coloring (ReAct THOUGHT)
|
||||
#[serde(default = "Theme::default_agent_thought")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_thought: Color,
|
||||
|
||||
/// Agent action coloring (ReAct ACTION)
|
||||
#[serde(default = "Theme::default_agent_action")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_action: Color,
|
||||
|
||||
/// Agent action coloring (ReAct ACTION_INPUT)
|
||||
#[serde(default = "Theme::default_agent_action_input")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_action_input: Color,
|
||||
|
||||
/// Agent action coloring (ReAct OBSERVATION)
|
||||
#[serde(default = "Theme::default_agent_observation")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_observation: Color,
|
||||
|
||||
/// Agent action coloring (ReAct FINAL_ANSWER)
|
||||
#[serde(default = "Theme::default_agent_final_answer")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_final_answer: Color,
|
||||
|
||||
/// Status badge foreground when agent is running
|
||||
#[serde(default = "Theme::default_agent_badge_running_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_badge_running_fg: Color,
|
||||
|
||||
/// Status badge background when agent is running
|
||||
#[serde(default = "Theme::default_agent_badge_running_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_badge_running_bg: Color,
|
||||
|
||||
/// Status badge foreground when agent mode is idle
|
||||
#[serde(default = "Theme::default_agent_badge_idle_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_badge_idle_fg: Color,
|
||||
|
||||
/// Status badge background when agent mode is idle
|
||||
#[serde(default = "Theme::default_agent_badge_idle_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub agent_badge_idle_bg: Color,
|
||||
|
||||
/// Operating mode badge foreground (Chat)
|
||||
#[serde(default = "Theme::default_operating_chat_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub operating_chat_fg: Color,
|
||||
|
||||
/// Operating mode badge background (Chat)
|
||||
#[serde(default = "Theme::default_operating_chat_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub operating_chat_bg: Color,
|
||||
|
||||
/// Operating mode badge foreground (Code)
|
||||
#[serde(default = "Theme::default_operating_code_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub operating_code_fg: Color,
|
||||
|
||||
/// Operating mode badge background (Code)
|
||||
#[serde(default = "Theme::default_operating_code_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub operating_code_bg: Color,
|
||||
}
|
||||
|
||||
impl Default for Theme {
|
||||
@@ -136,6 +288,108 @@ impl Default for Theme {
|
||||
}
|
||||
}
|
||||
|
||||
impl Theme {
|
||||
const fn default_code_block_background() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_code_block_border() -> Color {
|
||||
Color::Gray
|
||||
}
|
||||
|
||||
const fn default_code_block_text() -> Color {
|
||||
Color::White
|
||||
}
|
||||
|
||||
const fn default_code_block_keyword() -> Color {
|
||||
Color::Yellow
|
||||
}
|
||||
|
||||
const fn default_code_block_string() -> Color {
|
||||
Color::LightGreen
|
||||
}
|
||||
|
||||
const fn default_code_block_comment() -> Color {
|
||||
Color::DarkGray
|
||||
}
|
||||
|
||||
const fn default_agent_thought() -> Color {
|
||||
Color::LightBlue
|
||||
}
|
||||
|
||||
const fn default_agent_action() -> Color {
|
||||
Color::Yellow
|
||||
}
|
||||
|
||||
const fn default_agent_action_input() -> Color {
|
||||
Color::LightCyan
|
||||
}
|
||||
|
||||
const fn default_agent_observation() -> Color {
|
||||
Color::LightGreen
|
||||
}
|
||||
|
||||
const fn default_agent_final_answer() -> Color {
|
||||
Color::Magenta
|
||||
}
|
||||
|
||||
const fn default_agent_badge_running_fg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_agent_badge_running_bg() -> Color {
|
||||
Color::Yellow
|
||||
}
|
||||
|
||||
const fn default_agent_badge_idle_fg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_agent_badge_idle_bg() -> Color {
|
||||
Color::Cyan
|
||||
}
|
||||
|
||||
const fn default_focus_beacon_fg() -> Color {
|
||||
Color::LightMagenta
|
||||
}
|
||||
|
||||
const fn default_focus_beacon_bg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_unfocused_beacon_fg() -> Color {
|
||||
Color::DarkGray
|
||||
}
|
||||
|
||||
const fn default_pane_header_active() -> Color {
|
||||
Color::White
|
||||
}
|
||||
|
||||
const fn default_pane_header_inactive() -> Color {
|
||||
Color::Gray
|
||||
}
|
||||
|
||||
const fn default_pane_hint_text() -> Color {
|
||||
Color::DarkGray
|
||||
}
|
||||
|
||||
const fn default_operating_chat_fg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_operating_chat_bg() -> Color {
|
||||
Color::Blue
|
||||
}
|
||||
|
||||
const fn default_operating_code_fg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_operating_code_bg() -> Color {
|
||||
Color::Magenta
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default themes directory path
|
||||
pub fn default_themes_dir() -> PathBuf {
|
||||
let config_dir = PathBuf::from(shellexpand::tilde(crate::config::DEFAULT_CONFIG_PATH).as_ref())
|
||||
@@ -213,6 +467,10 @@ pub fn built_in_themes() -> HashMap<String, Theme> {
|
||||
"ansi_basic",
|
||||
include_str!("../../../themes/ansi-basic.toml"),
|
||||
),
|
||||
(
|
||||
"grayscale-high-contrast",
|
||||
include_str!("../../../themes/grayscale-high-contrast.toml"),
|
||||
),
|
||||
("gruvbox", include_str!("../../../themes/gruvbox.toml")),
|
||||
("dracula", include_str!("../../../themes/dracula.toml")),
|
||||
("solarized", include_str!("../../../themes/solarized.toml")),
|
||||
@@ -263,6 +521,7 @@ fn get_fallback_theme(name: &str) -> Option<Theme> {
|
||||
"monokai" => Some(monokai()),
|
||||
"material-dark" => Some(material_dark()),
|
||||
"material-light" => Some(material_light()),
|
||||
"grayscale-high-contrast" => Some(grayscale_high_contrast()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -273,27 +532,52 @@ fn default_dark() -> Theme {
|
||||
name: "default_dark".to_string(),
|
||||
text: Color::White,
|
||||
background: Color::Black,
|
||||
focused_panel_border: Color::LightMagenta,
|
||||
unfocused_panel_border: Color::Rgb(95, 20, 135),
|
||||
focused_panel_border: Color::Rgb(216, 160, 255),
|
||||
unfocused_panel_border: Color::Rgb(137, 82, 204),
|
||||
focus_beacon_fg: Color::Rgb(248, 229, 255),
|
||||
focus_beacon_bg: Color::Rgb(38, 10, 58),
|
||||
unfocused_beacon_fg: Color::Rgb(130, 130, 130),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Color::Rgb(210, 210, 210),
|
||||
pane_hint_text: Color::Rgb(210, 210, 210),
|
||||
user_message_role: Color::LightBlue,
|
||||
assistant_message_role: Color::Yellow,
|
||||
tool_output: Color::Gray,
|
||||
thinking_panel_title: Color::LightMagenta,
|
||||
command_bar_background: Color::Black,
|
||||
status_background: Color::Black,
|
||||
mode_normal: Color::LightBlue,
|
||||
mode_editing: Color::LightGreen,
|
||||
mode_model_selection: Color::LightYellow,
|
||||
mode_provider_selection: Color::LightCyan,
|
||||
mode_help: Color::LightMagenta,
|
||||
mode_visual: Color::Magenta,
|
||||
mode_command: Color::Yellow,
|
||||
selection_bg: Color::LightBlue,
|
||||
tool_output: Color::Rgb(200, 200, 200),
|
||||
thinking_panel_title: Color::Rgb(234, 182, 255),
|
||||
command_bar_background: Color::Rgb(10, 10, 10),
|
||||
status_background: Color::Rgb(12, 12, 12),
|
||||
mode_normal: Color::Rgb(117, 200, 255),
|
||||
mode_editing: Color::Rgb(144, 242, 170),
|
||||
mode_model_selection: Color::Rgb(255, 226, 140),
|
||||
mode_provider_selection: Color::Rgb(164, 235, 255),
|
||||
mode_help: Color::Rgb(234, 182, 255),
|
||||
mode_visual: Color::Rgb(255, 170, 255),
|
||||
mode_command: Color::Rgb(255, 220, 120),
|
||||
selection_bg: Color::Rgb(56, 140, 240),
|
||||
selection_fg: Color::Black,
|
||||
cursor: Color::Magenta,
|
||||
placeholder: Color::DarkGray,
|
||||
cursor: Color::Rgb(255, 196, 255),
|
||||
code_block_background: Color::Rgb(25, 25, 25),
|
||||
code_block_border: Color::Rgb(216, 160, 255),
|
||||
code_block_text: Color::White,
|
||||
code_block_keyword: Color::Rgb(255, 220, 120),
|
||||
code_block_string: Color::Rgb(144, 242, 170),
|
||||
code_block_comment: Color::Rgb(170, 170, 170),
|
||||
placeholder: Color::Rgb(180, 180, 180),
|
||||
error: Color::Red,
|
||||
info: Color::LightGreen,
|
||||
info: Color::Rgb(144, 242, 170),
|
||||
agent_thought: Color::Rgb(117, 200, 255),
|
||||
agent_action: Color::Rgb(255, 220, 120),
|
||||
agent_action_input: Color::Rgb(164, 235, 255),
|
||||
agent_observation: Color::Rgb(144, 242, 170),
|
||||
agent_final_answer: Color::Rgb(255, 170, 255),
|
||||
agent_badge_running_fg: Color::Black,
|
||||
agent_badge_running_bg: Color::Yellow,
|
||||
agent_badge_idle_fg: Color::Black,
|
||||
agent_badge_idle_bg: Color::Cyan,
|
||||
operating_chat_fg: Color::Black,
|
||||
operating_chat_bg: Color::Rgb(117, 200, 255),
|
||||
operating_code_fg: Color::Black,
|
||||
operating_code_bg: Color::Rgb(255, 170, 255),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,6 +589,12 @@ fn default_light() -> Theme {
|
||||
background: Color::White,
|
||||
focused_panel_border: Color::Rgb(74, 144, 226),
|
||||
unfocused_panel_border: Color::Rgb(221, 221, 221),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(0, 85, 164),
|
||||
assistant_message_role: Color::Rgb(142, 68, 173),
|
||||
tool_output: Color::Gray,
|
||||
@@ -321,9 +611,28 @@ fn default_light() -> Theme {
|
||||
selection_bg: Color::Rgb(164, 200, 240),
|
||||
selection_fg: Color::Black,
|
||||
cursor: Color::Rgb(217, 95, 2),
|
||||
code_block_background: Color::Rgb(245, 245, 245),
|
||||
code_block_border: Color::Rgb(142, 68, 173),
|
||||
code_block_text: Color::Black,
|
||||
code_block_keyword: Color::Rgb(181, 137, 0),
|
||||
code_block_string: Color::Rgb(46, 139, 87),
|
||||
code_block_comment: Color::Gray,
|
||||
placeholder: Color::Gray,
|
||||
error: Color::Rgb(192, 57, 43),
|
||||
info: Color::Green,
|
||||
agent_thought: Color::Rgb(0, 85, 164),
|
||||
agent_action: Color::Rgb(181, 137, 0),
|
||||
agent_action_input: Color::Rgb(0, 139, 139),
|
||||
agent_observation: Color::Rgb(46, 139, 87),
|
||||
agent_final_answer: Color::Rgb(142, 68, 173),
|
||||
agent_badge_running_fg: Color::White,
|
||||
agent_badge_running_bg: Color::Rgb(241, 196, 15),
|
||||
agent_badge_idle_fg: Color::White,
|
||||
agent_badge_idle_bg: Color::Rgb(0, 150, 136),
|
||||
operating_chat_fg: Color::White,
|
||||
operating_chat_bg: Color::Rgb(0, 85, 164),
|
||||
operating_code_fg: Color::White,
|
||||
operating_code_bg: Color::Rgb(142, 68, 173),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,7 +644,13 @@ fn gruvbox() -> Theme {
|
||||
background: Color::Rgb(40, 40, 40), // #282828
|
||||
focused_panel_border: Color::Rgb(254, 128, 25), // #fe8019 (orange)
|
||||
unfocused_panel_border: Color::Rgb(124, 111, 100), // #7c6f64
|
||||
user_message_role: Color::Rgb(184, 187, 38), // #b8bb26 (green)
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(184, 187, 38), // #b8bb26 (green)
|
||||
assistant_message_role: Color::Rgb(131, 165, 152), // #83a598 (blue)
|
||||
tool_output: Color::Rgb(146, 131, 116),
|
||||
thinking_panel_title: Color::Rgb(211, 134, 155), // #d3869b (purple)
|
||||
@@ -351,9 +666,28 @@ fn gruvbox() -> Theme {
|
||||
selection_bg: Color::Rgb(80, 73, 69),
|
||||
selection_fg: Color::Rgb(235, 219, 178),
|
||||
cursor: Color::Rgb(254, 128, 25),
|
||||
code_block_background: Color::Rgb(60, 56, 54),
|
||||
code_block_border: Color::Rgb(124, 111, 100),
|
||||
code_block_text: Color::Rgb(235, 219, 178),
|
||||
code_block_keyword: Color::Rgb(250, 189, 47),
|
||||
code_block_string: Color::Rgb(142, 192, 124),
|
||||
code_block_comment: Color::Rgb(124, 111, 100),
|
||||
placeholder: Color::Rgb(102, 92, 84),
|
||||
error: Color::Rgb(251, 73, 52), // #fb4934
|
||||
info: Color::Rgb(184, 187, 38),
|
||||
agent_thought: Color::Rgb(131, 165, 152),
|
||||
agent_action: Color::Rgb(250, 189, 47),
|
||||
agent_action_input: Color::Rgb(142, 192, 124),
|
||||
agent_observation: Color::Rgb(184, 187, 38),
|
||||
agent_final_answer: Color::Rgb(211, 134, 155),
|
||||
agent_badge_running_fg: Color::Rgb(40, 40, 40),
|
||||
agent_badge_running_bg: Color::Rgb(250, 189, 47),
|
||||
agent_badge_idle_fg: Color::Rgb(40, 40, 40),
|
||||
agent_badge_idle_bg: Color::Rgb(131, 165, 152),
|
||||
operating_chat_fg: Color::Rgb(40, 40, 40),
|
||||
operating_chat_bg: Color::Rgb(131, 165, 152),
|
||||
operating_code_fg: Color::Rgb(40, 40, 40),
|
||||
operating_code_bg: Color::Rgb(211, 134, 155),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,11 +695,17 @@ fn gruvbox() -> Theme {
|
||||
fn dracula() -> Theme {
|
||||
Theme {
|
||||
name: "dracula".to_string(),
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(40, 42, 54), // #282a36
|
||||
focused_panel_border: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
unfocused_panel_border: Color::Rgb(68, 71, 90), // #44475a
|
||||
user_message_role: Color::Rgb(139, 233, 253), // #8be9fd (cyan)
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(40, 42, 54), // #282a36
|
||||
focused_panel_border: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
unfocused_panel_border: Color::Rgb(68, 71, 90), // #44475a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(139, 233, 253), // #8be9fd (cyan)
|
||||
assistant_message_role: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
tool_output: Color::Rgb(98, 114, 164),
|
||||
thinking_panel_title: Color::Rgb(189, 147, 249), // #bd93f9 (purple)
|
||||
@@ -381,9 +721,28 @@ fn dracula() -> Theme {
|
||||
selection_bg: Color::Rgb(68, 71, 90),
|
||||
selection_fg: Color::Rgb(248, 248, 242),
|
||||
cursor: Color::Rgb(255, 121, 198),
|
||||
code_block_background: Color::Rgb(68, 71, 90),
|
||||
code_block_border: Color::Rgb(189, 147, 249),
|
||||
code_block_text: Color::Rgb(248, 248, 242),
|
||||
code_block_keyword: Color::Rgb(255, 121, 198),
|
||||
code_block_string: Color::Rgb(80, 250, 123),
|
||||
code_block_comment: Color::Rgb(98, 114, 164),
|
||||
placeholder: Color::Rgb(98, 114, 164),
|
||||
error: Color::Rgb(255, 85, 85), // #ff5555
|
||||
info: Color::Rgb(80, 250, 123),
|
||||
agent_thought: Color::Rgb(139, 233, 253),
|
||||
agent_action: Color::Rgb(241, 250, 140),
|
||||
agent_action_input: Color::Rgb(189, 147, 249),
|
||||
agent_observation: Color::Rgb(80, 250, 123),
|
||||
agent_final_answer: Color::Rgb(255, 121, 198),
|
||||
agent_badge_running_fg: Color::Rgb(40, 42, 54),
|
||||
agent_badge_running_bg: Color::Rgb(241, 250, 140),
|
||||
agent_badge_idle_fg: Color::Rgb(40, 42, 54),
|
||||
agent_badge_idle_bg: Color::Rgb(139, 233, 253),
|
||||
operating_chat_fg: Color::Rgb(40, 42, 54),
|
||||
operating_chat_bg: Color::Rgb(139, 233, 253),
|
||||
operating_code_fg: Color::Rgb(40, 42, 54),
|
||||
operating_code_bg: Color::Rgb(189, 147, 249),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -395,6 +754,12 @@ fn solarized() -> Theme {
|
||||
background: Color::Rgb(0, 43, 54), // #002b36 (base03)
|
||||
focused_panel_border: Color::Rgb(38, 139, 210), // #268bd2 (blue)
|
||||
unfocused_panel_border: Color::Rgb(7, 54, 66), // #073642 (base02)
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(42, 161, 152), // #2aa198 (cyan)
|
||||
assistant_message_role: Color::Rgb(203, 75, 22), // #cb4b16 (orange)
|
||||
tool_output: Color::Rgb(101, 123, 131),
|
||||
@@ -411,9 +776,28 @@ fn solarized() -> Theme {
|
||||
selection_bg: Color::Rgb(7, 54, 66),
|
||||
selection_fg: Color::Rgb(147, 161, 161),
|
||||
cursor: Color::Rgb(211, 54, 130),
|
||||
code_block_background: Color::Rgb(7, 54, 66),
|
||||
code_block_border: Color::Rgb(38, 139, 210),
|
||||
code_block_text: Color::Rgb(147, 161, 161),
|
||||
code_block_keyword: Color::Rgb(181, 137, 0),
|
||||
code_block_string: Color::Rgb(133, 153, 0),
|
||||
code_block_comment: Color::Rgb(88, 110, 117),
|
||||
placeholder: Color::Rgb(88, 110, 117),
|
||||
error: Color::Rgb(220, 50, 47), // #dc322f (red)
|
||||
info: Color::Rgb(133, 153, 0),
|
||||
agent_thought: Color::Rgb(42, 161, 152),
|
||||
agent_action: Color::Rgb(181, 137, 0),
|
||||
agent_action_input: Color::Rgb(38, 139, 210),
|
||||
agent_observation: Color::Rgb(133, 153, 0),
|
||||
agent_final_answer: Color::Rgb(108, 113, 196),
|
||||
agent_badge_running_fg: Color::Rgb(0, 43, 54),
|
||||
agent_badge_running_bg: Color::Rgb(181, 137, 0),
|
||||
agent_badge_idle_fg: Color::Rgb(0, 43, 54),
|
||||
agent_badge_idle_bg: Color::Rgb(42, 161, 152),
|
||||
operating_chat_fg: Color::Rgb(0, 43, 54),
|
||||
operating_chat_bg: Color::Rgb(42, 161, 152),
|
||||
operating_code_fg: Color::Rgb(0, 43, 54),
|
||||
operating_code_bg: Color::Rgb(108, 113, 196),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,6 +809,12 @@ fn midnight_ocean() -> Theme {
|
||||
background: Color::Rgb(13, 17, 23),
|
||||
focused_panel_border: Color::Rgb(88, 166, 255),
|
||||
unfocused_panel_border: Color::Rgb(48, 54, 61),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(121, 192, 255),
|
||||
assistant_message_role: Color::Rgb(137, 221, 255),
|
||||
tool_output: Color::Rgb(84, 110, 122),
|
||||
@@ -441,9 +831,28 @@ fn midnight_ocean() -> Theme {
|
||||
selection_bg: Color::Rgb(56, 139, 253),
|
||||
selection_fg: Color::Rgb(13, 17, 23),
|
||||
cursor: Color::Rgb(246, 140, 245),
|
||||
code_block_background: Color::Rgb(22, 27, 34),
|
||||
code_block_border: Color::Rgb(88, 166, 255),
|
||||
code_block_text: Color::Rgb(192, 202, 245),
|
||||
code_block_keyword: Color::Rgb(255, 212, 59),
|
||||
code_block_string: Color::Rgb(158, 206, 106),
|
||||
code_block_comment: Color::Rgb(110, 118, 129),
|
||||
placeholder: Color::Rgb(110, 118, 129),
|
||||
error: Color::Rgb(248, 81, 73),
|
||||
info: Color::Rgb(158, 206, 106),
|
||||
agent_thought: Color::Rgb(121, 192, 255),
|
||||
agent_action: Color::Rgb(255, 212, 59),
|
||||
agent_action_input: Color::Rgb(137, 221, 255),
|
||||
agent_observation: Color::Rgb(158, 206, 106),
|
||||
agent_final_answer: Color::Rgb(246, 140, 245),
|
||||
agent_badge_running_fg: Color::Rgb(13, 17, 23),
|
||||
agent_badge_running_bg: Color::Rgb(255, 212, 59),
|
||||
agent_badge_idle_fg: Color::Rgb(13, 17, 23),
|
||||
agent_badge_idle_bg: Color::Rgb(137, 221, 255),
|
||||
operating_chat_fg: Color::Rgb(13, 17, 23),
|
||||
operating_chat_bg: Color::Rgb(121, 192, 255),
|
||||
operating_code_fg: Color::Rgb(13, 17, 23),
|
||||
operating_code_bg: Color::Rgb(246, 140, 245),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,11 +860,17 @@ fn midnight_ocean() -> Theme {
|
||||
fn rose_pine() -> Theme {
|
||||
Theme {
|
||||
name: "rose-pine".to_string(),
|
||||
text: Color::Rgb(224, 222, 244), // #e0def4
|
||||
background: Color::Rgb(25, 23, 36), // #191724
|
||||
focused_panel_border: Color::Rgb(235, 111, 146), // #eb6f92 (love)
|
||||
unfocused_panel_border: Color::Rgb(38, 35, 58), // #26233a
|
||||
user_message_role: Color::Rgb(49, 116, 143), // #31748f (foam)
|
||||
text: Color::Rgb(224, 222, 244), // #e0def4
|
||||
background: Color::Rgb(25, 23, 36), // #191724
|
||||
focused_panel_border: Color::Rgb(235, 111, 146), // #eb6f92 (love)
|
||||
unfocused_panel_border: Color::Rgb(38, 35, 58), // #26233a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(49, 116, 143), // #31748f (foam)
|
||||
assistant_message_role: Color::Rgb(156, 207, 216), // #9ccfd8 (foam light)
|
||||
tool_output: Color::Rgb(110, 106, 134),
|
||||
thinking_panel_title: Color::Rgb(196, 167, 231), // #c4a7e7 (iris)
|
||||
@@ -471,9 +886,28 @@ fn rose_pine() -> Theme {
|
||||
selection_bg: Color::Rgb(64, 61, 82),
|
||||
selection_fg: Color::Rgb(224, 222, 244),
|
||||
cursor: Color::Rgb(235, 111, 146),
|
||||
code_block_background: Color::Rgb(38, 35, 58),
|
||||
code_block_border: Color::Rgb(235, 111, 146),
|
||||
code_block_text: Color::Rgb(224, 222, 244),
|
||||
code_block_keyword: Color::Rgb(246, 193, 119),
|
||||
code_block_string: Color::Rgb(156, 207, 216),
|
||||
code_block_comment: Color::Rgb(110, 106, 134),
|
||||
placeholder: Color::Rgb(110, 106, 134),
|
||||
error: Color::Rgb(235, 111, 146),
|
||||
info: Color::Rgb(156, 207, 216),
|
||||
agent_thought: Color::Rgb(156, 207, 216),
|
||||
agent_action: Color::Rgb(246, 193, 119),
|
||||
agent_action_input: Color::Rgb(196, 167, 231),
|
||||
agent_observation: Color::Rgb(235, 188, 186),
|
||||
agent_final_answer: Color::Rgb(235, 111, 146),
|
||||
agent_badge_running_fg: Color::Rgb(25, 23, 36),
|
||||
agent_badge_running_bg: Color::Rgb(246, 193, 119),
|
||||
agent_badge_idle_fg: Color::Rgb(25, 23, 36),
|
||||
agent_badge_idle_bg: Color::Rgb(156, 207, 216),
|
||||
operating_chat_fg: Color::Rgb(25, 23, 36),
|
||||
operating_chat_bg: Color::Rgb(156, 207, 216),
|
||||
operating_code_fg: Color::Rgb(25, 23, 36),
|
||||
operating_code_bg: Color::Rgb(196, 167, 231),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -481,11 +915,17 @@ fn rose_pine() -> Theme {
|
||||
fn monokai() -> Theme {
|
||||
Theme {
|
||||
name: "monokai".to_string(),
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(39, 40, 34), // #272822
|
||||
focused_panel_border: Color::Rgb(249, 38, 114), // #f92672 (pink)
|
||||
unfocused_panel_border: Color::Rgb(117, 113, 94), // #75715e
|
||||
user_message_role: Color::Rgb(102, 217, 239), // #66d9ef (cyan)
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(39, 40, 34), // #272822
|
||||
focused_panel_border: Color::Rgb(249, 38, 114), // #f92672 (pink)
|
||||
unfocused_panel_border: Color::Rgb(117, 113, 94), // #75715e
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(102, 217, 239), // #66d9ef (cyan)
|
||||
assistant_message_role: Color::Rgb(174, 129, 255), // #ae81ff (purple)
|
||||
tool_output: Color::Rgb(117, 113, 94),
|
||||
thinking_panel_title: Color::Rgb(230, 219, 116), // #e6db74 (yellow)
|
||||
@@ -501,9 +941,28 @@ fn monokai() -> Theme {
|
||||
selection_bg: Color::Rgb(117, 113, 94),
|
||||
selection_fg: Color::Rgb(248, 248, 242),
|
||||
cursor: Color::Rgb(249, 38, 114),
|
||||
code_block_background: Color::Rgb(50, 51, 46),
|
||||
code_block_border: Color::Rgb(249, 38, 114),
|
||||
code_block_text: Color::Rgb(248, 248, 242),
|
||||
code_block_keyword: Color::Rgb(230, 219, 116),
|
||||
code_block_string: Color::Rgb(166, 226, 46),
|
||||
code_block_comment: Color::Rgb(117, 113, 94),
|
||||
placeholder: Color::Rgb(117, 113, 94),
|
||||
error: Color::Rgb(249, 38, 114),
|
||||
info: Color::Rgb(166, 226, 46),
|
||||
agent_thought: Color::Rgb(102, 217, 239),
|
||||
agent_action: Color::Rgb(230, 219, 116),
|
||||
agent_action_input: Color::Rgb(174, 129, 255),
|
||||
agent_observation: Color::Rgb(166, 226, 46),
|
||||
agent_final_answer: Color::Rgb(249, 38, 114),
|
||||
agent_badge_running_fg: Color::Rgb(39, 40, 34),
|
||||
agent_badge_running_bg: Color::Rgb(230, 219, 116),
|
||||
agent_badge_idle_fg: Color::Rgb(39, 40, 34),
|
||||
agent_badge_idle_bg: Color::Rgb(102, 217, 239),
|
||||
operating_chat_fg: Color::Rgb(39, 40, 34),
|
||||
operating_chat_bg: Color::Rgb(102, 217, 239),
|
||||
operating_code_fg: Color::Rgb(39, 40, 34),
|
||||
operating_code_bg: Color::Rgb(174, 129, 255),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -511,11 +970,17 @@ fn monokai() -> Theme {
|
||||
fn material_dark() -> Theme {
|
||||
Theme {
|
||||
name: "material-dark".to_string(),
|
||||
text: Color::Rgb(238, 255, 255), // #eeffff
|
||||
background: Color::Rgb(38, 50, 56), // #263238
|
||||
focused_panel_border: Color::Rgb(128, 203, 196), // #80cbc4 (cyan)
|
||||
unfocused_panel_border: Color::Rgb(84, 110, 122), // #546e7a
|
||||
user_message_role: Color::Rgb(130, 170, 255), // #82aaff (blue)
|
||||
text: Color::Rgb(238, 255, 255), // #eeffff
|
||||
background: Color::Rgb(38, 50, 56), // #263238
|
||||
focused_panel_border: Color::Rgb(128, 203, 196), // #80cbc4 (cyan)
|
||||
unfocused_panel_border: Color::Rgb(84, 110, 122), // #546e7a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(130, 170, 255), // #82aaff (blue)
|
||||
assistant_message_role: Color::Rgb(199, 146, 234), // #c792ea (purple)
|
||||
tool_output: Color::Rgb(84, 110, 122),
|
||||
thinking_panel_title: Color::Rgb(255, 203, 107), // #ffcb6b (yellow)
|
||||
@@ -531,9 +996,28 @@ fn material_dark() -> Theme {
|
||||
selection_bg: Color::Rgb(84, 110, 122),
|
||||
selection_fg: Color::Rgb(238, 255, 255),
|
||||
cursor: Color::Rgb(255, 204, 0),
|
||||
code_block_background: Color::Rgb(33, 43, 48),
|
||||
code_block_border: Color::Rgb(128, 203, 196),
|
||||
code_block_text: Color::Rgb(238, 255, 255),
|
||||
code_block_keyword: Color::Rgb(255, 203, 107),
|
||||
code_block_string: Color::Rgb(195, 232, 141),
|
||||
code_block_comment: Color::Rgb(84, 110, 122),
|
||||
placeholder: Color::Rgb(84, 110, 122),
|
||||
error: Color::Rgb(240, 113, 120),
|
||||
info: Color::Rgb(195, 232, 141),
|
||||
agent_thought: Color::Rgb(128, 203, 196),
|
||||
agent_action: Color::Rgb(255, 203, 107),
|
||||
agent_action_input: Color::Rgb(199, 146, 234),
|
||||
agent_observation: Color::Rgb(195, 232, 141),
|
||||
agent_final_answer: Color::Rgb(240, 113, 120),
|
||||
agent_badge_running_fg: Color::Rgb(38, 50, 56),
|
||||
agent_badge_running_bg: Color::Rgb(255, 203, 107),
|
||||
agent_badge_idle_fg: Color::Rgb(38, 50, 56),
|
||||
agent_badge_idle_bg: Color::Rgb(128, 203, 196),
|
||||
operating_chat_fg: Color::Rgb(38, 50, 56),
|
||||
operating_chat_bg: Color::Rgb(130, 170, 255),
|
||||
operating_code_fg: Color::Rgb(38, 50, 56),
|
||||
operating_code_bg: Color::Rgb(199, 146, 234),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,6 +1029,12 @@ fn material_light() -> Theme {
|
||||
background: Color::Rgb(236, 239, 241),
|
||||
focused_panel_border: Color::Rgb(0, 150, 136),
|
||||
unfocused_panel_border: Color::Rgb(176, 190, 197),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(68, 138, 255),
|
||||
assistant_message_role: Color::Rgb(124, 77, 255),
|
||||
tool_output: Color::Rgb(144, 164, 174),
|
||||
@@ -561,9 +1051,83 @@ fn material_light() -> Theme {
|
||||
selection_bg: Color::Rgb(176, 190, 197),
|
||||
selection_fg: Color::Rgb(33, 33, 33),
|
||||
cursor: Color::Rgb(194, 24, 91),
|
||||
code_block_background: Color::Rgb(248, 249, 250),
|
||||
code_block_border: Color::Rgb(0, 150, 136),
|
||||
code_block_text: Color::Rgb(33, 33, 33),
|
||||
code_block_keyword: Color::Rgb(245, 124, 0),
|
||||
code_block_string: Color::Rgb(56, 142, 60),
|
||||
code_block_comment: Color::Rgb(144, 164, 174),
|
||||
placeholder: Color::Rgb(144, 164, 174),
|
||||
error: Color::Rgb(211, 47, 47),
|
||||
info: Color::Rgb(56, 142, 60),
|
||||
agent_thought: Color::Rgb(68, 138, 255),
|
||||
agent_action: Color::Rgb(245, 124, 0),
|
||||
agent_action_input: Color::Rgb(124, 77, 255),
|
||||
agent_observation: Color::Rgb(56, 142, 60),
|
||||
agent_final_answer: Color::Rgb(211, 47, 47),
|
||||
agent_badge_running_fg: Color::White,
|
||||
agent_badge_running_bg: Color::Rgb(245, 124, 0),
|
||||
agent_badge_idle_fg: Color::White,
|
||||
agent_badge_idle_bg: Color::Rgb(0, 150, 136),
|
||||
operating_chat_fg: Color::White,
|
||||
operating_chat_bg: Color::Rgb(68, 138, 255),
|
||||
operating_code_fg: Color::White,
|
||||
operating_code_bg: Color::Rgb(124, 77, 255),
|
||||
}
|
||||
}
|
||||
|
||||
/// Grayscale high-contrast theme
|
||||
fn grayscale_high_contrast() -> Theme {
|
||||
Theme {
|
||||
name: "grayscale_high_contrast".to_string(),
|
||||
text: Color::Rgb(247, 247, 247),
|
||||
background: Color::Black,
|
||||
focused_panel_border: Color::White,
|
||||
unfocused_panel_border: Color::Rgb(76, 76, 76),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(240, 240, 240),
|
||||
assistant_message_role: Color::Rgb(214, 214, 214),
|
||||
tool_output: Color::Rgb(189, 189, 189),
|
||||
thinking_panel_title: Color::Rgb(224, 224, 224),
|
||||
command_bar_background: Color::Black,
|
||||
status_background: Color::Rgb(15, 15, 15),
|
||||
mode_normal: Color::White,
|
||||
mode_editing: Color::Rgb(230, 230, 230),
|
||||
mode_model_selection: Color::Rgb(204, 204, 204),
|
||||
mode_provider_selection: Color::Rgb(179, 179, 179),
|
||||
mode_help: Color::Rgb(153, 153, 153),
|
||||
mode_visual: Color::Rgb(242, 242, 242),
|
||||
mode_command: Color::Rgb(208, 208, 208),
|
||||
selection_bg: Color::Rgb(240, 240, 240),
|
||||
selection_fg: Color::Black,
|
||||
cursor: Color::White,
|
||||
code_block_background: Color::Rgb(15, 15, 15),
|
||||
code_block_border: Color::White,
|
||||
code_block_text: Color::Rgb(247, 247, 247),
|
||||
code_block_keyword: Color::Rgb(204, 204, 204),
|
||||
code_block_string: Color::Rgb(214, 214, 214),
|
||||
code_block_comment: Color::Rgb(122, 122, 122),
|
||||
placeholder: Color::Rgb(122, 122, 122),
|
||||
error: Color::White,
|
||||
info: Color::Rgb(200, 200, 200),
|
||||
agent_thought: Color::Rgb(230, 230, 230),
|
||||
agent_action: Color::Rgb(204, 204, 204),
|
||||
agent_action_input: Color::Rgb(176, 176, 176),
|
||||
agent_observation: Color::Rgb(153, 153, 153),
|
||||
agent_final_answer: Color::White,
|
||||
agent_badge_running_fg: Color::Black,
|
||||
agent_badge_running_bg: Color::Rgb(247, 247, 247),
|
||||
agent_badge_idle_fg: Color::Black,
|
||||
agent_badge_idle_bg: Color::Rgb(189, 189, 189),
|
||||
operating_chat_fg: Color::Black,
|
||||
operating_chat_bg: Color::Rgb(242, 242, 242),
|
||||
operating_code_fg: Color::Black,
|
||||
operating_code_bg: Color::Rgb(191, 191, 191),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,16 +1150,16 @@ where
|
||||
}
|
||||
|
||||
fn parse_color(s: &str) -> Result<Color, String> {
|
||||
if let Some(hex) = s.strip_prefix('#') {
|
||||
if hex.len() == 6 {
|
||||
let r = u8::from_str_radix(&hex[0..2], 16)
|
||||
.map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
let g = u8::from_str_radix(&hex[2..4], 16)
|
||||
.map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
let b = u8::from_str_radix(&hex[4..6], 16)
|
||||
.map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
return Ok(Color::Rgb(r, g, b));
|
||||
}
|
||||
if let Some(hex) = s.strip_prefix('#')
|
||||
&& hex.len() == 6
|
||||
{
|
||||
let r =
|
||||
u8::from_str_radix(&hex[0..2], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
let g =
|
||||
u8::from_str_radix(&hex[2..4], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
let b =
|
||||
u8::from_str_radix(&hex[4..6], 16).map_err(|_| format!("Invalid hex color: {}", s))?;
|
||||
return Ok(Color::Rgb(r, g, b));
|
||||
}
|
||||
|
||||
// Try named colors
|
||||
@@ -660,5 +1224,6 @@ mod tests {
|
||||
assert!(themes.contains_key("default_dark"));
|
||||
assert!(themes.contains_key("gruvbox"));
|
||||
assert!(themes.contains_key("dracula"));
|
||||
assert!(themes.contains_key("grayscale-high-contrast"));
|
||||
}
|
||||
}
|
||||
|
||||
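A minimal usage sketch of the theme APIs touched above: `parse_color` now accepts `#rrggbb` hex strings as well as named colors, and `built_in_themes()` now also bundles "ansi_basic" and "grayscale-high-contrast". The `owlen_core::ui::themes` module path and the `main` wiring below are assumptions; only the function, field, and theme names shown in the hunks are taken from the diff.

// Sketch only: the module path and `main` wiring are assumptions; `built_in_themes`,
// `Theme`, and the bundled theme names come from the hunks above.
use owlen_core::ui::themes::{Theme, built_in_themes};

/// Look up a bundled theme by name, falling back to the default dark theme.
fn pick_theme(name: &str) -> Theme {
    let mut themes = built_in_themes();
    if let Some(theme) = themes.remove(name) {
        return theme;
    }
    themes
        .remove("default_dark")
        .expect("default_dark is always bundled")
}

fn main() {
    // "grayscale-high-contrast" is one of the names registered in this change set.
    let theme = pick_theme("grayscale-high-contrast");
    println!("using theme: {}", theme.name);
}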
@@ -13,7 +13,7 @@ pub mod web_search;
pub mod web_search_detailed;

use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;

@@ -2,9 +2,9 @@ use std::sync::Arc;
use std::time::Instant;

use crate::Result;
use anyhow::{anyhow, Context};
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};

use super::{Tool, ToolResult};
use crate::sandbox::{SandboxConfig, SandboxedProcess};

@@ -2,7 +2,7 @@ use super::{Tool, ToolResult};
use crate::Result;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};

/// Tool that fetches the raw HTML content for a list of URLs.
///

@@ -4,7 +4,7 @@ use std::time::Instant;
use crate::Result;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};

use super::{Tool, ToolResult};
use crate::consent::ConsentManager;

@@ -4,7 +4,7 @@ use std::time::Instant;
use crate::Result;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};

use super::{Tool, ToolResult};
use crate::consent::ConsentManager;
@@ -86,7 +86,9 @@ impl Tool for WebSearchDetailedTool {
.expect("Consent manager mutex poisoned");

if !consent.has_consent(self.name()) {
return Ok(ToolResult::error("Consent not granted for detailed web search. This should have been handled by the TUI."));
return Ok(ToolResult::error(
"Consent not granted for detailed web search. This should have been handled by the TUI.",
));
}
}

@@ -3,170 +3,30 @@
|
||||
//! This module contains reusable UI components that can be shared between
|
||||
//! different TUI applications (chat, code, etc.)
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Application state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AppState {
|
||||
Running,
|
||||
Quit,
|
||||
}
|
||||
pub use crate::state::AppState;
|
||||
|
||||
/// Input modes for TUI applications
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum InputMode {
|
||||
Normal,
|
||||
Editing,
|
||||
ProviderSelection,
|
||||
ModelSelection,
|
||||
Help,
|
||||
Visual,
|
||||
Command,
|
||||
SessionBrowser,
|
||||
ThemeBrowser,
|
||||
}
|
||||
|
||||
impl fmt::Display for InputMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let label = match self {
|
||||
InputMode::Normal => "Normal",
|
||||
InputMode::Editing => "Editing",
|
||||
InputMode::ModelSelection => "Model",
|
||||
InputMode::ProviderSelection => "Provider",
|
||||
InputMode::Help => "Help",
|
||||
InputMode::Visual => "Visual",
|
||||
InputMode::Command => "Command",
|
||||
InputMode::SessionBrowser => "Sessions",
|
||||
InputMode::ThemeBrowser => "Themes",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
pub use crate::state::InputMode;
|
||||
|
||||
/// Represents which panel is currently focused
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FocusedPanel {
|
||||
Chat,
|
||||
Thinking,
|
||||
Input,
|
||||
}
|
||||
pub use crate::state::FocusedPanel;
|
||||
|
||||
/// Auto-scroll state manager for scrollable panels
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoScroll {
|
||||
pub scroll: usize,
|
||||
pub content_len: usize,
|
||||
pub stick_to_bottom: bool,
|
||||
}
|
||||
|
||||
impl Default for AutoScroll {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
scroll: 0,
|
||||
content_len: 0,
|
||||
stick_to_bottom: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoScroll {
|
||||
/// Update scroll position based on viewport height
|
||||
pub fn on_viewport(&mut self, viewport_h: usize) {
|
||||
let max = self.content_len.saturating_sub(viewport_h);
|
||||
if self.stick_to_bottom {
|
||||
self.scroll = max;
|
||||
} else {
|
||||
self.scroll = self.scroll.min(max);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle user scroll input
|
||||
pub fn on_user_scroll(&mut self, delta: isize, viewport_h: usize) {
|
||||
let max = self.content_len.saturating_sub(viewport_h) as isize;
|
||||
let s = (self.scroll as isize + delta).clamp(0, max) as usize;
|
||||
self.scroll = s;
|
||||
self.stick_to_bottom = s as isize == max;
|
||||
}
|
||||
|
||||
/// Scroll down half page
|
||||
pub fn scroll_half_page_down(&mut self, viewport_h: usize) {
|
||||
let delta = (viewport_h / 2) as isize;
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
/// Scroll up half page
|
||||
pub fn scroll_half_page_up(&mut self, viewport_h: usize) {
|
||||
let delta = -((viewport_h / 2) as isize);
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
/// Scroll down full page
|
||||
pub fn scroll_full_page_down(&mut self, viewport_h: usize) {
|
||||
let delta = viewport_h as isize;
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
/// Scroll up full page
|
||||
pub fn scroll_full_page_up(&mut self, viewport_h: usize) {
|
||||
let delta = -(viewport_h as isize);
|
||||
self.on_user_scroll(delta, viewport_h);
|
||||
}
|
||||
|
||||
/// Jump to top
|
||||
pub fn jump_to_top(&mut self) {
|
||||
self.scroll = 0;
|
||||
self.stick_to_bottom = false;
|
||||
}
|
||||
|
||||
/// Jump to bottom
|
||||
pub fn jump_to_bottom(&mut self, viewport_h: usize) {
|
||||
self.stick_to_bottom = true;
|
||||
self.on_viewport(viewport_h);
|
||||
}
|
||||
}
|
||||
pub use crate::state::AutoScroll;
|
||||
|
||||
/// Visual selection state for text selection
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct VisualSelection {
|
||||
pub start: Option<(usize, usize)>, // (row, col)
|
||||
pub end: Option<(usize, usize)>, // (row, col)
|
||||
}
|
||||
pub use crate::state::VisualSelection;
|
||||
|
||||
impl VisualSelection {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub fn start_at(&mut self, pos: (usize, usize)) {
|
||||
self.start = Some(pos);
|
||||
self.end = Some(pos);
|
||||
}
|
||||
|
||||
pub fn extend_to(&mut self, pos: (usize, usize)) {
|
||||
self.end = Some(pos);
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.start = None;
|
||||
self.end = None;
|
||||
}
|
||||
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.start.is_some() && self.end.is_some()
|
||||
}
|
||||
|
||||
pub fn get_normalized(&self) -> Option<((usize, usize), (usize, usize))> {
|
||||
if let (Some(s), Some(e)) = (self.start, self.end) {
|
||||
// Normalize selection so start is always before end
|
||||
if s.0 < e.0 || (s.0 == e.0 && s.1 <= e.1) {
|
||||
Some((s, e))
|
||||
} else {
|
||||
Some((e, s))
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
/// How role labels should be rendered alongside chat messages.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum RoleLabelDisplay {
|
||||
Inline,
|
||||
Above,
|
||||
None,
|
||||
}
|
||||
|
||||
/// Extract text from a selection range in a list of lines
|
||||
@@ -235,37 +95,7 @@ pub fn extract_text_from_selection(
|
||||
}
|
||||
|
||||
/// Cursor position for navigating scrollable content
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct CursorPosition {
|
||||
pub row: usize,
|
||||
pub col: usize,
|
||||
}
|
||||
|
||||
impl CursorPosition {
|
||||
pub fn new(row: usize, col: usize) -> Self {
|
||||
Self { row, col }
|
||||
}
|
||||
|
||||
pub fn move_up(&mut self, amount: usize) {
|
||||
self.row = self.row.saturating_sub(amount);
|
||||
}
|
||||
|
||||
pub fn move_down(&mut self, amount: usize, max: usize) {
|
||||
self.row = (self.row + amount).min(max);
|
||||
}
|
||||
|
||||
pub fn move_left(&mut self, amount: usize) {
|
||||
self.col = self.col.saturating_sub(amount);
|
||||
}
|
||||
|
||||
pub fn move_right(&mut self, amount: usize, max: usize) {
|
||||
self.col = (self.col + amount).min(max);
|
||||
}
|
||||
|
||||
pub fn as_tuple(&self) -> (usize, usize) {
|
||||
(self.row, self.col)
|
||||
}
|
||||
}
|
||||
pub use crate::state::CursorPosition;
|
||||
|
||||
/// Word boundary detection for navigation
|
||||
pub fn find_next_word_boundary(line: &str, col: usize) -> Option<usize> {
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;

use anyhow::{Context, Result};
use jsonschema::{JSONSchema, ValidationError};
use serde_json::{json, Value};
use serde_json::{Value, json};

pub struct SchemaValidator {
schemas: HashMap<String, JSONSchema>,

310
crates/owlen-core/tests/agent_tool_flow.rs
Normal file
@@ -0,0 +1,310 @@
|
||||
use std::{any::Any, collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use owlen_core::{
|
||||
Config, Error, Mode, Provider,
|
||||
config::McpMode,
|
||||
consent::ConsentScope,
|
||||
mcp::{
|
||||
McpClient, McpToolCall, McpToolDescriptor, McpToolResponse,
|
||||
failover::{FailoverMcpClient, ServerEntry},
|
||||
},
|
||||
session::{ControllerEvent, SessionController, SessionOutcome},
|
||||
storage::StorageManager,
|
||||
types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall},
|
||||
ui::NoOpUiController,
|
||||
};
|
||||
use tempfile::tempdir;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct StreamingToolProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StreamingToolProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock-streaming-provider"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: "mock-model".into(),
|
||||
name: "Mock Model".into(),
|
||||
description: Some("A mock model that emits tool calls".into()),
|
||||
provider: self.name().into(),
|
||||
context_window: Some(4096),
|
||||
capabilities: vec!["chat".into(), "tools".into()],
|
||||
supports_tools: true,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result<ChatResponse> {
|
||||
let mut message = Message::assistant("tool-call".to_string());
|
||||
message.tool_calls = Some(vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "resources/write".to_string(),
|
||||
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
||||
}]);
|
||||
|
||||
Ok(ChatResponse {
|
||||
message,
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream_prompt(
|
||||
&self,
|
||||
_request: ChatRequest,
|
||||
) -> owlen_core::Result<owlen_core::ChatStream> {
|
||||
let mut first_chunk = Message::assistant(
|
||||
"Thought: need to update README.\nAction: resources/write".to_string(),
|
||||
);
|
||||
first_chunk.tool_calls = Some(vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "resources/write".to_string(),
|
||||
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
||||
}]);
|
||||
|
||||
let chunk = ChatResponse {
|
||||
message: first_chunk,
|
||||
usage: None,
|
||||
is_streaming: true,
|
||||
is_final: false,
|
||||
};
|
||||
|
||||
Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> owlen_core::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "web_search".to_string(),
|
||||
description: "search".to_string(),
|
||||
input_schema: serde_json::json!({"type": "object"}),
|
||||
requires_network: true,
|
||||
requires_filesystem: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
struct TimeoutClient;
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for TimeoutClient {
|
||||
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
||||
Ok(vec![tool_descriptor()])
|
||||
}
|
||||
|
||||
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
||||
Err(Error::Network(
|
||||
"timeout while contacting remote web search endpoint".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedResponseClient {
|
||||
response: Arc<McpToolResponse>,
|
||||
}
|
||||
|
||||
impl CachedResponseClient {
|
||||
fn new() -> Self {
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("source".to_string(), "cache".to_string());
|
||||
metadata.insert("cached".to_string(), "true".to_string());
|
||||
|
||||
let response = McpToolResponse {
|
||||
name: "web_search".to_string(),
|
||||
success: true,
|
||||
output: serde_json::json!({
|
||||
"query": "rust",
|
||||
"results": [
|
||||
{"title": "Rust Programming Language", "url": "https://www.rust-lang.org"}
|
||||
],
|
||||
"note": "cached result"
|
||||
}),
|
||||
metadata,
|
||||
duration_ms: 0,
|
||||
};
|
||||
|
||||
Self {
|
||||
response: Arc::new(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for CachedResponseClient {
|
||||
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
||||
Ok(vec![tool_descriptor()])
|
||||
}
|
||||
|
||||
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
||||
Ok((*self.response).clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn streaming_file_write_consent_denied_returns_resolution() {
|
||||
let temp_dir = tempdir().expect("temp dir");
|
||||
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
|
||||
.await
|
||||
.expect("storage");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.general.enable_streaming = true;
|
||||
config.privacy.encrypt_local_data = false;
|
||||
config.privacy.require_consent_per_session = true;
|
||||
config.general.default_model = Some("mock-model".into());
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
config
|
||||
.refresh_mcp_servers(None)
|
||||
.expect("refresh MCP servers");
|
||||
|
||||
let provider: Arc<dyn Provider> = Arc::new(StreamingToolProvider);
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
||||
|
||||
let mut session = SessionController::new(
|
||||
provider,
|
||||
config,
|
||||
Arc::new(storage),
|
||||
ui,
|
||||
true,
|
||||
Some(event_tx),
|
||||
)
|
||||
.await
|
||||
.expect("session controller");
|
||||
|
||||
session
|
||||
.set_operating_mode(Mode::Code)
|
||||
.await
|
||||
.expect("code mode");
|
||||
|
||||
let outcome = session
|
||||
.send_message(
|
||||
"Please write to README".to_string(),
|
||||
ChatParameters {
|
||||
stream: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("send message");
|
||||
|
||||
let (response_id, mut stream) = if let SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
} = outcome
|
||||
{
|
||||
(response_id, stream)
|
||||
} else {
|
||||
panic!("expected streaming outcome");
|
||||
};
|
||||
|
||||
session
|
||||
.mark_stream_placeholder(response_id, "▌")
|
||||
.expect("placeholder");
|
||||
|
||||
let chunk = stream
|
||||
.next()
|
||||
.await
|
||||
.expect("stream chunk")
|
||||
.expect("chunk result");
|
||||
session
|
||||
.apply_stream_chunk(response_id, &chunk)
|
||||
.expect("apply chunk");
|
||||
|
||||
let tool_calls = session
|
||||
.check_streaming_tool_calls(response_id)
|
||||
.expect("tool calls");
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].name, "resources/write");
|
||||
|
||||
let event = event_rx.recv().await.expect("controller event");
|
||||
let request_id = match event {
|
||||
ControllerEvent::ToolRequested {
|
||||
request_id,
|
||||
tool_name,
|
||||
data_types,
|
||||
endpoints,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_name, "resources/write");
|
||||
assert!(data_types.iter().any(|t| t.contains("file")));
|
||||
assert!(endpoints.iter().any(|e| e.contains("filesystem")));
|
||||
request_id
|
||||
}
|
||||
};
|
||||
|
||||
let resolution = session
|
||||
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
||||
.expect("resolution");
|
||||
assert_eq!(resolution.scope, ConsentScope::Denied);
|
||||
assert_eq!(resolution.tool_name, "resources/write");
|
||||
assert_eq!(resolution.tool_calls.len(), tool_calls.len());
|
||||
|
||||
let err = session
|
||||
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
||||
.expect_err("second resolution should fail");
|
||||
matches!(err, Error::InvalidInput(_));
|
||||
|
||||
let conversation = session.conversation().clone();
|
||||
let assistant = conversation
|
||||
.messages
|
||||
.iter()
|
||||
.find(|message| message.role == Role::Assistant)
|
||||
.expect("assistant message present");
|
||||
assert!(
|
||||
assistant
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.and_then(|calls| calls.first())
|
||||
.is_some_and(|call| call.name == "resources/write"),
|
||||
"stream chunk should capture the tool call on the assistant message"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn web_tool_timeout_fails_over_to_cached_result() {
|
||||
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
||||
let cached = CachedResponseClient::new();
|
||||
let backup: Arc<dyn McpClient> = Arc::new(cached.clone());
|
||||
|
||||
let client = FailoverMcpClient::with_servers(vec![
|
||||
ServerEntry::new("primary".into(), primary, 1),
|
||||
ServerEntry::new("cache".into(), backup, 2),
|
||||
]);
|
||||
|
||||
let call = McpToolCall {
|
||||
name: "web_search".to_string(),
|
||||
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
||||
};
|
||||
|
||||
let response = client.call_tool(call.clone()).await.expect("fallback");
|
||||
|
||||
assert_eq!(response.name, "web_search");
|
||||
assert_eq!(
|
||||
response.metadata.get("source").map(String::as_str),
|
||||
Some("cache")
|
||||
);
|
||||
assert_eq!(
|
||||
response.output.get("note").and_then(|value| value.as_str()),
|
||||
Some("cached result")
|
||||
);
|
||||
|
||||
let statuses = client.get_server_status().await;
|
||||
assert!(statuses.iter().any(|(name, health)| name == "primary"
|
||||
&& !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
||||
assert!(statuses.iter().any(|(name, health)| name == "cache"
|
||||
&& matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
||||
}
|
||||
@@ -1,5 +1,5 @@
use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::McpToolCall;
use owlen_core::mcp::remote_client::RemoteMcpClient;
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;

@@ -1,5 +1,5 @@
use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::McpToolCall;
use owlen_core::mcp::remote_client::RemoteMcpClient;
use tempfile::tempdir;

#[tokio::test]

@@ -5,8 +5,8 @@
use owlen_core::mcp::failover::{FailoverConfig, FailoverMcpClient, ServerEntry, ServerHealth};
use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor};
use owlen_core::{Error, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

/// Mock MCP client for testing failover behavior

@@ -1,9 +1,9 @@
//! Integration test for the MCP prompt rendering server.

use owlen_core::Result;
use owlen_core::config::McpServerConfig;
use owlen_core::mcp::client::RemoteMcpClient;
use owlen_core::mcp::{McpToolCall, McpToolResponse};
use owlen_core::Result;
use serde_json::json;
use std::path::PathBuf;

@@ -44,6 +44,7 @@ async fn test_render_prompt_via_external_server() -> Result<()> {
args: Vec::new(),
transport: "stdio".into(),
env: std::collections::HashMap::new(),
oauth: None,
};

let client = match RemoteMcpClient::new_with_config(&config) {

@@ -1,6 +1,6 @@
#![allow(non_snake_case)]

use owlen_core::wrap_cursor::{build_cursor_map, ScreenPos};
use owlen_core::wrap_cursor::{ScreenPos, build_cursor_map};

fn assert_cursor_pos(map: &[ScreenPos], byte_idx: usize, expected: ScreenPos) {
assert_eq!(map[byte_idx], expected, "Mismatch at byte {}", byte_idx);

10
crates/owlen-markdown/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "owlen-markdown"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Lightweight markdown to ratatui::Text renderer for OWLEN"
|
||||
|
||||
[dependencies]
|
||||
ratatui = { workspace = true }
|
||||
unicode-width = "0.1"
|
||||
270
crates/owlen-markdown/src/lib.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::text::{Line, Span, Text};
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
|
||||
/// Convert a markdown string into a `ratatui::Text`.
|
||||
///
|
||||
/// This lightweight renderer supports common constructs (headings, lists, bold,
|
||||
/// italics, and inline code) and is designed to keep dependencies minimal for
|
||||
/// the OWLEN project.
|
||||
pub fn from_str(input: &str) -> Text<'static> {
|
||||
let mut lines = Vec::new();
|
||||
let mut in_code_block = false;
|
||||
|
||||
for raw_line in input.lines() {
|
||||
let line = raw_line.trim_end_matches('\r');
|
||||
let trimmed = line.trim_start();
|
||||
let indent = &line[..line.len() - trimmed.len()];
|
||||
|
||||
if trimmed.starts_with("```") {
|
||||
in_code_block = !in_code_block;
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_code_block {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
trimmed.to_string(),
|
||||
Style::default()
|
||||
.fg(Color::LightYellow)
|
||||
.add_modifier(Modifier::DIM),
|
||||
));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if trimmed.is_empty() {
|
||||
lines.push(Line::from(Vec::<Span<'static>>::new()));
|
||||
continue;
|
||||
}
|
||||
|
||||
if trimmed.starts_with('#') {
|
||||
let level = trimmed.chars().take_while(|c| *c == '#').count().min(6);
|
||||
let content = trimmed[level..].trim_start();
|
||||
let mut style = Style::default().add_modifier(Modifier::BOLD);
|
||||
style = match level {
|
||||
1 => style.fg(Color::LightCyan),
|
||||
2 => style.fg(Color::Cyan),
|
||||
_ => style.fg(Color::LightBlue),
|
||||
};
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(content.to_string(), style));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = trimmed.strip_prefix("- ") {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
"• ".to_string(),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = trimmed.strip_prefix("* ") {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
"• ".to_string(),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((number, rest)) = parse_ordered_item(trimmed) {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
format!("{number}. "),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.extend(parse_inline(trimmed));
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
if input.is_empty() {
|
||||
lines.push(Line::from(Vec::<Span<'static>>::new()));
|
||||
}
|
||||
|
||||
Text::from(lines)
|
||||
}
|
||||
|
||||
fn parse_ordered_item(line: &str) -> Option<(u32, &str)> {
|
||||
let mut parts = line.splitn(2, '.');
|
||||
let number = parts.next()?.trim();
|
||||
let rest = parts.next()?;
|
||||
if number.chars().all(|c| c.is_ascii_digit()) {
|
||||
let value = number.parse().ok()?;
|
||||
let rest = rest.trim_start();
|
||||
Some((value, rest))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_inline(text: &str) -> Vec<Span<'static>> {
|
||||
let mut spans = Vec::new();
|
||||
let bytes = text.as_bytes();
|
||||
let mut i = 0;
|
||||
let len = bytes.len();
|
||||
let mut plain_start = 0;
|
||||
|
||||
while i < len {
|
||||
if bytes[i] == b'`' {
|
||||
if let Some(offset) = text[i + 1..].find('`') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default()
|
||||
.fg(Color::LightYellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if bytes[i] == b'*' {
|
||||
if i + 1 < len && bytes[i + 1] == b'*' {
|
||||
if let Some(offset) = text[i + 2..].find("**") {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 2..i + 2 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 4;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
} else if let Some(offset) = text[i + 1..].find('*') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::ITALIC),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if bytes[i] == b'_' {
|
||||
if i + 1 < len && bytes[i + 1] == b'_' {
|
||||
if let Some(offset) = text[i + 2..].find("__") {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 2..i + 2 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 4;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
} else if let Some(offset) = text[i + 1..].find('_') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::ITALIC),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if plain_start < len {
|
||||
spans.push(Span::raw(text[plain_start..].to_string()));
|
||||
}
|
||||
|
||||
if spans.is_empty() {
|
||||
spans.push(Span::raw(String::new()));
|
||||
}
|
||||
|
||||
spans
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn visual_length(spans: &[Span<'_>]) -> usize {
|
||||
spans
|
||||
.iter()
|
||||
.map(|span| UnicodeWidthStr::width(span.content.as_ref()))
|
||||
.sum()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn headings_are_bold() {
|
||||
let text = from_str("# Heading");
|
||||
assert_eq!(text.lines.len(), 1);
|
||||
let line = &text.lines[0];
|
||||
assert!(
|
||||
line.spans
|
||||
.iter()
|
||||
.any(|span| span.style.contains(Modifier::BOLD))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_code_styled() {
|
||||
let text = from_str("Use `code` inline.");
|
||||
let styled = text
|
||||
.lines
|
||||
.iter()
|
||||
.flat_map(|line| &line.spans)
|
||||
.find(|span| span.content.as_ref() == "code")
|
||||
.cloned()
|
||||
.unwrap();
|
||||
assert!(styled.style.contains(Modifier::BOLD));
|
||||
}
|
||||
}
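A hedged companion sketch (same conventions as the tests module above, not part of the original file): one more unit test showing how parse_inline splits a line that mixes bold emphasis with an inline code span.

#[cfg(test)]
mod inline_mix_tests {
    use super::*;

    #[test]
    fn bold_and_code_spans_are_split() {
        // "**bold**" -> bold span, " and " -> raw span, "`code`" -> highlighted span.
        let spans = parse_inline("**bold** and `code`");
        assert_eq!(spans.len(), 3);
        assert!(spans[0].style.contains(Modifier::BOLD));
        assert_eq!(spans[1].content.as_ref(), " and ");
        assert!(spans[2].style.contains(Modifier::BOLD));
    }
}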
|
||||
20
crates/owlen-providers/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "owlen-providers"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
description = "Provider implementations for OWLEN"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
reqwest = { package = "reqwest", version = "0.11", features = ["json", "stream"] }
|
||||
3
crates/owlen-providers/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Provider implementations for OWLEN.
|
||||
|
||||
pub mod ollama;
|
||||
108
crates/owlen-providers/src/ollama/cloud.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::{env, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::{
|
||||
Error as CoreError, Result as CoreResult,
|
||||
config::OLLAMA_CLOUD_BASE_URL,
|
||||
provider::{
|
||||
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
},
|
||||
};
|
||||
use serde_json::{Number, Value};
|
||||
|
||||
use super::OllamaClient;
|
||||
|
||||
const API_KEY_ENV: &str = "OLLAMA_CLOUD_API_KEY";
|
||||
|
||||
/// ModelProvider implementation for the hosted Ollama Cloud service.
|
||||
pub struct OllamaCloudProvider {
|
||||
client: OllamaClient,
|
||||
}
|
||||
|
||||
impl OllamaCloudProvider {
|
||||
/// Construct a new cloud provider. An API key must be supplied either
|
||||
/// directly or via the `OLLAMA_CLOUD_API_KEY` environment variable.
|
||||
pub fn new(
|
||||
base_url: Option<String>,
|
||||
api_key: Option<String>,
|
||||
request_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let (api_key, key_source) = resolve_api_key(api_key)?;
|
||||
let base_url = base_url.unwrap_or_else(|| OLLAMA_CLOUD_BASE_URL.to_string());
|
||||
|
||||
let mut metadata =
|
||||
ProviderMetadata::new("ollama_cloud", "Ollama (Cloud)", ProviderType::Cloud, true);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
metadata.metadata.insert(
|
||||
"api_key_source".into(),
|
||||
Value::String(key_source.to_string()),
|
||||
);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("api_key_env".into(), Value::String(API_KEY_ENV.to_string()));
|
||||
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
let client = OllamaClient::new(&base_url, Some(api_key), metadata, request_timeout)?;
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for OllamaCloudProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
self.client.metadata()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
match self.client.health_check().await {
|
||||
Ok(status) => Ok(status),
|
||||
Err(CoreError::Auth(_)) => Ok(ProviderStatus::RequiresSetup),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
self.client.list_models().await
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
self.client.generate_stream(request).await
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_api_key(api_key: Option<String>) -> CoreResult<(String, &'static str)> {
|
||||
let key_from_config = api_key
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_string);
|
||||
|
||||
if let Some(key) = key_from_config {
|
||||
return Ok((key, "config"));
|
||||
}
|
||||
|
||||
let key_from_env = env::var(API_KEY_ENV)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty());
|
||||
|
||||
if let Some(key) = key_from_env {
|
||||
return Ok((key, "env"));
|
||||
}
|
||||
|
||||
Err(CoreError::Config(
|
||||
"Ollama Cloud API key not configured. Set OLLAMA_CLOUD_API_KEY or configure an API key."
|
||||
.into(),
|
||||
))
|
||||
}
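A hedged usage sketch (the helper name is illustrative, not crate API): building the cloud provider with the default base URL, letting the API key come from OLLAMA_CLOUD_API_KEY, and setting an explicit request timeout.

fn build_cloud_provider() -> CoreResult<OllamaCloudProvider> {
    OllamaCloudProvider::new(
        None,                          // default OLLAMA_CLOUD_BASE_URL
        None,                          // fall back to the OLLAMA_CLOUD_API_KEY env var
        Some(Duration::from_secs(30)), // per-request timeout
    )
}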
|
||||
80
crates/owlen-providers/src/ollama/local.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::provider::{
|
||||
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata, ProviderStatus,
|
||||
ProviderType,
|
||||
};
|
||||
use owlen_core::{Error as CoreError, Result as CoreResult};
|
||||
use serde_json::{Number, Value};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::OllamaClient;
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
|
||||
const DEFAULT_HEALTH_TIMEOUT_SECS: u64 = 5;
|
||||
|
||||
/// ModelProvider implementation for a local Ollama daemon.
|
||||
pub struct OllamaLocalProvider {
|
||||
client: OllamaClient,
|
||||
health_timeout: Duration,
|
||||
}
|
||||
|
||||
impl OllamaLocalProvider {
|
||||
/// Construct a new local provider using the shared [`OllamaClient`].
|
||||
pub fn new(
|
||||
base_url: Option<String>,
|
||||
request_timeout: Option<Duration>,
|
||||
health_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let base_url = base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
|
||||
let health_timeout =
|
||||
health_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_HEALTH_TIMEOUT_SECS));
|
||||
|
||||
let mut metadata =
|
||||
ProviderMetadata::new("ollama_local", "Ollama (Local)", ProviderType::Local, false);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
let client = OllamaClient::new(&base_url, None, metadata, request_timeout)?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
health_timeout,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for OllamaLocalProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
self.client.metadata()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
match timeout(self.health_timeout, self.client.health_check()).await {
|
||||
Ok(Ok(status)) => Ok(status),
|
||||
Ok(Err(CoreError::Network(_))) | Ok(Err(CoreError::Timeout(_))) => {
|
||||
Ok(ProviderStatus::Unavailable)
|
||||
}
|
||||
Ok(Err(err)) => Err(err),
|
||||
Err(_) => Ok(ProviderStatus::Unavailable),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
self.client.list_models().await
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
self.client.generate_stream(request).await
|
||||
}
|
||||
}
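A hedged usage sketch (the helper name is illustrative): constructing the local provider with defaults and probing the daemon with a short health budget, using only items imported in this file.

async fn probe_local_ollama() -> CoreResult<ProviderStatus> {
    let provider = OllamaLocalProvider::new(
        None,                         // http://localhost:11434
        None,                         // default request timeout
        Some(Duration::from_secs(2)), // fail fast when the daemon is down
    )?;
    provider.health_check().await
}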
|
||||
7
crates/owlen-providers/src/ollama/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod cloud;
|
||||
pub mod local;
|
||||
pub mod shared;
|
||||
|
||||
pub use cloud::OllamaCloudProvider;
|
||||
pub use local::OllamaLocalProvider;
|
||||
pub use shared::OllamaClient;
|
||||
389
crates/owlen-providers/src/ollama/shared.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::StreamExt;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ProviderMetadata, ProviderStatus,
|
||||
};
|
||||
use owlen_core::{Error as CoreError, Result as CoreResult};
|
||||
use reqwest::{Client, Method, StatusCode, Url};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map as JsonMap, Value};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 60;
|
||||
|
||||
/// Shared Ollama HTTP client used by both local and cloud providers.
|
||||
#[derive(Clone)]
|
||||
pub struct OllamaClient {
|
||||
http: Client,
|
||||
base_url: Url,
|
||||
api_key: Option<String>,
|
||||
provider_metadata: ProviderMetadata,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Create a new client with the given base URL and optional API key.
|
||||
pub fn new(
|
||||
base_url: impl AsRef<str>,
|
||||
api_key: Option<String>,
|
||||
provider_metadata: ProviderMetadata,
|
||||
request_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let base_url = Url::parse(base_url.as_ref())
|
||||
.map_err(|err| CoreError::Config(format!("invalid base url: {}", err)))?;
|
||||
|
||||
let timeout = request_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_TIMEOUT_SECS));
|
||||
let http = Client::builder()
|
||||
.timeout(timeout)
|
||||
.build()
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
Ok(Self {
|
||||
http,
|
||||
base_url,
|
||||
api_key,
|
||||
provider_metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Provider metadata associated with this client.
|
||||
pub fn metadata(&self) -> &ProviderMetadata {
|
||||
&self.provider_metadata
|
||||
}
|
||||
|
||||
/// Perform a basic health check to determine provider availability.
|
||||
pub async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
let url = self.endpoint("api/tags")?;
|
||||
|
||||
let response = self
|
||||
.request(Method::GET, url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
match response.status() {
|
||||
status if status.is_success() => Ok(ProviderStatus::Available),
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Ok(ProviderStatus::RequiresSetup),
|
||||
_ => Ok(ProviderStatus::Unavailable),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the available models from the Ollama API.
|
||||
pub async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
let url = self.endpoint("api/tags")?;
|
||||
|
||||
let response = self
|
||||
.request(Method::GET, url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
let status = response.status();
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(map_http_error("tags", status, &bytes));
|
||||
}
|
||||
|
||||
let payload: TagsResponse =
|
||||
serde_json::from_slice(&bytes).map_err(CoreError::Serialization)?;
|
||||
|
||||
let models = payload
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|model| self.parse_model_info(model))
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
/// Request a streaming generation session from Ollama.
|
||||
pub async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
let url = self.endpoint("api/generate")?;
|
||||
|
||||
let body = self.build_generate_body(request);
|
||||
|
||||
let response = self
|
||||
.request(Method::POST, url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
return Err(map_http_error("generate", status, &bytes));
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
let (tx, rx) = mpsc::channel::<CoreResult<GenerateChunk>>(32);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
buffer.extend_from_slice(&bytes);
|
||||
while let Some(pos) = buffer.iter().position(|byte| *byte == b'\n') {
|
||||
let line_bytes: Vec<u8> = buffer.drain(..=pos).collect();
|
||||
let line = String::from_utf8_lossy(&line_bytes).trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match parse_stream_line(&line) {
|
||||
Ok(item) => {
|
||||
if tx.send(Ok(item)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(err)).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(map_reqwest_error(err))).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !buffer.is_empty() {
|
||||
let line = String::from_utf8_lossy(&buffer).trim().to_string();
|
||||
if !line.is_empty() {
|
||||
match parse_stream_line(&line) {
|
||||
Ok(item) => {
|
||||
let _ = tx.send(Ok(item)).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(err)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn request(&self, method: Method, url: Url) -> reqwest::RequestBuilder {
|
||||
let mut builder = self.http.request(method, url);
|
||||
if let Some(api_key) = &self.api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
}
|
||||
builder
|
||||
}
|
||||
|
||||
fn endpoint(&self, path: &str) -> CoreResult<Url> {
|
||||
self.base_url
|
||||
.join(path)
|
||||
.map_err(|err| CoreError::Config(format!("invalid endpoint '{}': {}", path, err)))
|
||||
}
|
||||
|
||||
fn build_generate_body(&self, request: GenerateRequest) -> Value {
|
||||
let GenerateRequest {
|
||||
model,
|
||||
prompt,
|
||||
context,
|
||||
parameters,
|
||||
metadata,
|
||||
} = request;
|
||||
|
||||
let mut body = JsonMap::new();
|
||||
body.insert("model".into(), Value::String(model));
|
||||
body.insert("stream".into(), Value::Bool(true));
|
||||
|
||||
if let Some(prompt) = prompt {
|
||||
body.insert("prompt".into(), Value::String(prompt));
|
||||
}
|
||||
|
||||
if !context.is_empty() {
|
||||
let items = context.into_iter().map(Value::String).collect();
|
||||
body.insert("context".into(), Value::Array(items));
|
||||
}
|
||||
|
||||
if !parameters.is_empty() {
|
||||
body.insert("options".into(), Value::Object(to_json_map(parameters)));
|
||||
}
|
||||
|
||||
if !metadata.is_empty() {
|
||||
body.insert("metadata".into(), Value::Object(to_json_map(metadata)));
|
||||
}
|
||||
|
||||
Value::Object(body)
|
||||
}
|
||||
|
||||
fn parse_model_info(&self, model: OllamaModel) -> ModelInfo {
|
||||
let mut metadata = HashMap::new();
|
||||
|
||||
if let Some(digest) = model.digest {
|
||||
metadata.insert("digest".to_string(), Value::String(digest));
|
||||
}
|
||||
if let Some(modified) = model.modified_at {
|
||||
metadata.insert("modified_at".to_string(), Value::String(modified));
|
||||
}
|
||||
if let Some(details) = model.details {
|
||||
let mut details_map = JsonMap::new();
|
||||
if let Some(format) = details.format {
|
||||
details_map.insert("format".into(), Value::String(format));
|
||||
}
|
||||
if let Some(family) = details.family {
|
||||
details_map.insert("family".into(), Value::String(family));
|
||||
}
|
||||
if let Some(parameter_size) = details.parameter_size {
|
||||
details_map.insert("parameter_size".into(), Value::String(parameter_size));
|
||||
}
|
||||
if let Some(quantisation) = details.quantization_level {
|
||||
details_map.insert("quantization_level".into(), Value::String(quantisation));
|
||||
}
|
||||
|
||||
if !details_map.is_empty() {
|
||||
metadata.insert("details".to_string(), Value::Object(details_map));
|
||||
}
|
||||
}
|
||||
|
||||
ModelInfo {
|
||||
name: model.name,
|
||||
size_bytes: model.size,
|
||||
capabilities: Vec::new(),
|
||||
description: None,
|
||||
provider: self.provider_metadata.clone(),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TagsResponse {
|
||||
#[serde(default)]
|
||||
models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModel {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
size: Option<u64>,
|
||||
#[serde(default)]
|
||||
digest: Option<String>,
|
||||
#[serde(default)]
|
||||
modified_at: Option<String>,
|
||||
#[serde(default)]
|
||||
details: Option<OllamaModelDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModelDetails {
|
||||
#[serde(default)]
|
||||
format: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
#[serde(default)]
|
||||
parameter_size: Option<String>,
|
||||
#[serde(default)]
|
||||
quantization_level: Option<String>,
|
||||
}
|
||||
|
||||
fn to_json_map(source: HashMap<String, Value>) -> JsonMap<String, Value> {
|
||||
source.into_iter().collect()
|
||||
}
|
||||
|
||||
fn to_metadata_map(value: &Value) -> HashMap<String, Value> {
|
||||
let mut metadata = HashMap::new();
|
||||
|
||||
if let Value::Object(obj) = value {
|
||||
for (key, item) in obj {
|
||||
if key == "response" || key == "done" {
|
||||
continue;
|
||||
}
|
||||
metadata.insert(key.clone(), item.clone());
|
||||
}
|
||||
}
|
||||
|
||||
metadata
|
||||
}
|
||||
|
||||
fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
|
||||
let value: Value = serde_json::from_str(line).map_err(CoreError::Serialization)?;
|
||||
|
||||
if let Some(error) = value.get("error").and_then(Value::as_str) {
|
||||
return Err(CoreError::Provider(anyhow::anyhow!(
|
||||
"ollama generation error: {}",
|
||||
error
|
||||
)));
|
||||
}
|
||||
|
||||
let mut chunk = GenerateChunk {
|
||||
text: value
|
||||
.get("response")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string),
|
||||
is_final: value.get("done").and_then(Value::as_bool).unwrap_or(false),
|
||||
metadata: to_metadata_map(&value),
|
||||
};
|
||||
|
||||
if chunk.is_final && chunk.text.is_none() && chunk.metadata.is_empty() {
|
||||
chunk
|
||||
.metadata
|
||||
.insert("status".into(), Value::String("done".into()));
|
||||
}
|
||||
|
||||
Ok(chunk)
|
||||
}
|
||||
|
||||
fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
|
||||
match status {
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => CoreError::Auth(format!(
|
||||
"Ollama {} request unauthorized (status {})",
|
||||
endpoint, status
|
||||
)),
|
||||
StatusCode::TOO_MANY_REQUESTS => CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request rate limited (status {})",
|
||||
endpoint,
|
||||
status
|
||||
)),
|
||||
_ => {
|
||||
let snippet = truncated_body(body);
|
||||
CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request failed: HTTP {} - {}",
|
||||
endpoint,
|
||||
status,
|
||||
snippet
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn truncated_body(body: &[u8]) -> String {
|
||||
const MAX_CHARS: usize = 512;
|
||||
let text = String::from_utf8_lossy(body);
|
||||
let mut value = String::new();
|
||||
for (idx, ch) in text.chars().enumerate() {
|
||||
if idx >= MAX_CHARS {
|
||||
value.push('…');
|
||||
return value;
|
||||
}
|
||||
value.push(ch);
|
||||
}
|
||||
value
|
||||
}
|
||||
|
||||
fn map_reqwest_error(err: reqwest::Error) -> CoreError {
|
||||
if err.is_timeout() {
|
||||
CoreError::Timeout(err.to_string())
|
||||
} else if err.is_connect() || err.is_request() {
|
||||
CoreError::Network(err.to_string())
|
||||
} else {
|
||||
CoreError::Provider(err.into())
|
||||
}
|
||||
}
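A hedged sketch (not in the original file) of the NDJSON contract parse_stream_line expects: one JSON object per line, "response"/"done" mapped onto the chunk, every other field preserved as metadata, and "error" surfaced as a provider error.

#[cfg(test)]
mod stream_line_tests {
    use super::*;

    #[test]
    fn response_line_becomes_chunk() {
        let line = r#"{"response":"Hel","done":false,"model":"llama3"}"#;
        let chunk = parse_stream_line(line).expect("well-formed stream line");
        assert_eq!(chunk.text.as_deref(), Some("Hel"));
        assert!(!chunk.is_final);
        // Fields other than "response"/"done" are kept as metadata.
        assert!(chunk.metadata.contains_key("model"));
    }

    #[test]
    fn error_line_is_surfaced() {
        let line = r#"{"error":"model not found"}"#;
        assert!(parse_stream_line(line).is_err());
    }
}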
|
||||
106
crates/owlen-providers/tests/common/mock_provider.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use owlen_core::Result as CoreResult;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
};
|
||||
|
||||
pub struct MockProvider {
|
||||
metadata: ProviderMetadata,
|
||||
models: Vec<ModelInfo>,
|
||||
status: ProviderStatus,
|
||||
#[allow(clippy::type_complexity)]
|
||||
generate_handler: Option<Arc<dyn Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync>>,
|
||||
generate_error: Option<Arc<dyn Fn() -> owlen_core::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
pub fn new(id: &str) -> Self {
|
||||
let metadata = ProviderMetadata::new(
|
||||
id,
|
||||
format!("Mock Provider ({})", id),
|
||||
ProviderType::Local,
|
||||
false,
|
||||
);
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
models: vec![ModelInfo {
|
||||
name: format!("{}-primary", id),
|
||||
size_bytes: None,
|
||||
capabilities: vec!["chat".into()],
|
||||
description: Some("Mock model".into()),
|
||||
provider: ProviderMetadata::new(id, "Mock", ProviderType::Local, false),
|
||||
metadata: Default::default(),
|
||||
}],
|
||||
status: ProviderStatus::Available,
|
||||
generate_handler: None,
|
||||
generate_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_models(mut self, models: Vec<ModelInfo>) -> Self {
|
||||
self.models = models;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_status(mut self, status: ProviderStatus) -> Self {
|
||||
self.status = status;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generate_handler<F>(mut self, handler: F) -> Self
|
||||
where
|
||||
F: Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync + 'static,
|
||||
{
|
||||
self.generate_handler = Some(Arc::new(handler));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generate_error<F>(mut self, factory: F) -> Self
|
||||
where
|
||||
F: Fn() -> owlen_core::Error + Send + Sync + 'static,
|
||||
{
|
||||
self.generate_error = Some(Arc::new(factory));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for MockProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
Ok(self.status)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
Ok(self.models.clone())
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
if let Some(factory) = &self.generate_error {
|
||||
return Err(factory());
|
||||
}
|
||||
|
||||
let chunks = if let Some(handler) = &self.generate_handler {
|
||||
(handler)(request)
|
||||
} else {
|
||||
vec![GenerateChunk::final_chunk()]
|
||||
};
|
||||
|
||||
let stream = stream::iter(chunks.into_iter().map(Ok)).boxed();
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MockProvider> for Arc<dyn ModelProvider> {
|
||||
fn from(provider: MockProvider) -> Self {
|
||||
Arc::new(provider)
|
||||
}
|
||||
}
|
||||
1
crates/owlen-providers/tests/common/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod mock_provider;
|
||||
117
crates/owlen-providers/tests/integration_test.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
mod common;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
|
||||
use common::mock_provider::MockProvider;
|
||||
use owlen_core::config::Config;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, ModelInfo, ProviderManager, ProviderType,
|
||||
};
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn base_config() -> Config {
|
||||
Config {
|
||||
providers: Default::default(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_model(name: &str, provider: &str) -> ModelInfo {
|
||||
ModelInfo {
|
||||
name: name.into(),
|
||||
size_bytes: None,
|
||||
capabilities: vec!["chat".into()],
|
||||
description: Some("mock".into()),
|
||||
provider: owlen_core::provider::ProviderMetadata::new(
|
||||
provider,
|
||||
provider,
|
||||
ProviderType::Local,
|
||||
false,
|
||||
),
|
||||
metadata: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registers_providers_and_lists_ids() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider: Arc<dyn owlen_core::provider::ModelProvider> = MockProvider::new("mock-a").into();
|
||||
|
||||
manager.register_provider(provider).await;
|
||||
let ids = manager.provider_ids().await;
|
||||
|
||||
assert_eq!(ids, vec!["mock-a".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn aggregates_models_across_providers() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider_a = MockProvider::new("mock-a").with_models(vec![make_model("alpha", "mock-a")]);
|
||||
let provider_b = MockProvider::new("mock-b").with_models(vec![make_model("beta", "mock-b")]);
|
||||
|
||||
manager.register_provider(provider_a.into()).await;
|
||||
manager.register_provider(provider_b.into()).await;
|
||||
|
||||
let models = manager.list_all_models().await.unwrap();
|
||||
assert_eq!(models.len(), 2);
|
||||
assert!(models.iter().any(|m| m.model.name == "alpha"));
|
||||
assert!(models.iter().any(|m| m.model.name == "beta"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn routes_generation_to_specific_provider() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider = MockProvider::new("mock-gen").with_generate_handler(|_req| {
|
||||
vec![
|
||||
GenerateChunk::from_text("hello"),
|
||||
GenerateChunk::final_chunk(),
|
||||
]
|
||||
});
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
|
||||
let request = GenerateRequest::new("mock-gen::primary");
|
||||
let mut stream = manager.generate("mock-gen", request).await.unwrap();
|
||||
let mut collected = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
collected.push(chunk.unwrap());
|
||||
}
|
||||
|
||||
assert_eq!(collected.len(), 2);
|
||||
assert_eq!(collected[0].text.as_deref(), Some("hello"));
|
||||
assert!(collected[1].is_final);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn marks_provider_unavailable_on_error() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider = MockProvider::new("flaky")
|
||||
.with_generate_error(|| owlen_core::Error::Network("boom".into()));
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
let request = GenerateRequest::new("flaky::model");
|
||||
let result = manager.generate("flaky", request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
let status = manager.provider_status("flaky").await.unwrap();
|
||||
assert!(matches!(
|
||||
status,
|
||||
owlen_core::provider::ProviderStatus::Unavailable
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_refresh_updates_status_cache() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider =
|
||||
MockProvider::new("healthy").with_status(owlen_core::provider::ProviderStatus::Available);
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
let statuses = manager.refresh_health().await;
|
||||
assert_eq!(
|
||||
statuses.get("healthy"),
|
||||
Some(&owlen_core::provider::ProviderStatus::Available)
|
||||
);
|
||||
}
|
||||
@@ -18,7 +18,20 @@ crossterm = { workspace = true }
|
||||
tui-textarea = { workspace = true }
|
||||
textwrap = { workspace = true }
|
||||
unicode-width = "0.1"
|
||||
unicode-segmentation = "1.11"
|
||||
async-trait = "0.1"
|
||||
globset = "0.4"
|
||||
ignore = "0.4"
|
||||
pathdiff = "0.2"
|
||||
tree-sitter = "0.20"
|
||||
tree-sitter-rust = "0.20"
|
||||
dirs = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
syntect = "5.3"
|
||||
once_cell = "1.19"
|
||||
owlen-markdown = { path = "../owlen-markdown" }
|
||||
shellexpand = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
@@ -29,6 +42,9 @@ futures-util = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
chrono = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
|
||||
99
crates/owlen-tui/keymap.toml
Normal file
@@ -0,0 +1,99 @@
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["m"]
|
||||
command = "model.open_all"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+L"]
|
||||
command = "model.open_local"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+C"]
|
||||
command = "model.open_cloud"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+P"]
|
||||
command = "model.open_available"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+P"]
|
||||
command = "palette.open"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["Ctrl+P"]
|
||||
command = "palette.open"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Tab"]
|
||||
command = "focus.next"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Shift+Tab"]
|
||||
command = "focus.prev"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+1"]
|
||||
command = "focus.files"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+2"]
|
||||
command = "focus.chat"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+3"]
|
||||
command = "focus.code"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+4"]
|
||||
command = "focus.thinking"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+5"]
|
||||
command = "focus.input"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["Enter"]
|
||||
command = "composer.submit"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+;"]
|
||||
command = "mode.command"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "visual"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "command"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "help"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
77
crates/owlen-tui/src/app/generation.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures_util::StreamExt;
|
||||
use owlen_core::provider::GenerateRequest;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ActiveGeneration, App, AppMessage};
|
||||
|
||||
impl App {
|
||||
/// Kick off a new generation task on the supplied provider.
|
||||
pub fn start_generation(
|
||||
&mut self,
|
||||
provider_id: impl Into<String>,
|
||||
request: GenerateRequest,
|
||||
) -> Result<Uuid> {
|
||||
let provider_id = provider_id.into();
|
||||
let request_id = Uuid::new_v4();
|
||||
|
||||
// Cancel any existing task so we don't interleave output.
|
||||
if let Some(active) = self.active_generation.take() {
|
||||
active.abort();
|
||||
}
|
||||
|
||||
self.message_tx
|
||||
.send(AppMessage::GenerateStart {
|
||||
request_id,
|
||||
provider_id: provider_id.clone(),
|
||||
request: request.clone(),
|
||||
})
|
||||
.map_err(|err| anyhow!("failed to queue generation start: {err:?}"))?;
|
||||
|
||||
let manager = Arc::clone(&self.provider_manager);
|
||||
let message_tx = self.message_tx.clone();
|
||||
let provider_for_task = provider_id.clone();
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let mut stream = match manager.generate(&provider_for_task, request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
let _ = message_tx.send(AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: err.to_string(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if message_tx
|
||||
.send(AppMessage::GenerateChunk { request_id, chunk })
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = message_tx.send(AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: err.to_string(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = message_tx.send(AppMessage::GenerateComplete { request_id });
|
||||
});
|
||||
|
||||
let generation = ActiveGeneration::new(request_id, provider_id, join_handle);
|
||||
self.active_generation = Some(generation);
|
||||
|
||||
Ok(request_id)
|
||||
}
|
||||
}
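A hedged sketch (free function, not crate code): kicking off a generation from the UI layer. The provider id "ollama_local" matches the metadata id set by OllamaLocalProvider and the model name "llama3" is illustrative; this must run inside a Tokio runtime because start_generation spawns a task.

fn demo_generation(app: &mut App) -> Result<Uuid> {
    let request = GenerateRequest::new("llama3");
    // Chunks for the returned request id arrive later as
    // AppMessage::GenerateChunk events on the app's message channel.
    app.start_generation("ollama_local", request)
}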
|
||||
135
crates/owlen-tui/src/app/handler.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use super::{App, messages::AppMessage};
|
||||
use log::warn;
|
||||
use owlen_core::{
|
||||
provider::{GenerateChunk, GenerateRequest, ProviderStatus},
|
||||
state::AppState,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Trait implemented by UI state containers to react to [`AppMessage`] events.
|
||||
pub trait MessageState {
|
||||
/// Called when a generation request is about to start.
|
||||
#[allow(unused_variables)]
|
||||
fn start_generation(
|
||||
&mut self,
|
||||
request_id: Uuid,
|
||||
provider_id: &str,
|
||||
request: &GenerateRequest,
|
||||
) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called for every streamed generation chunk.
|
||||
#[allow(unused_variables)]
|
||||
fn append_chunk(&mut self, request_id: Uuid, chunk: &GenerateChunk) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a generation finishes successfully.
|
||||
#[allow(unused_variables)]
|
||||
fn generation_complete(&mut self, request_id: Uuid) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a generation fails.
|
||||
#[allow(unused_variables)]
|
||||
fn generation_failed(&mut self, request_id: Option<Uuid>, message: &str) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when refreshed model metadata is available.
|
||||
fn update_model_list(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a models refresh has been requested.
|
||||
fn refresh_model_list(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when provider status updates arrive.
|
||||
#[allow(unused_variables)]
|
||||
fn update_provider_status(&mut self, provider_id: &str, status: ProviderStatus) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a resize event occurs.
|
||||
#[allow(unused_variables)]
|
||||
fn handle_resize(&mut self, width: u16, height: u16) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called on periodic ticks.
|
||||
fn handle_tick(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
}
|
||||
|
||||
impl App {
|
||||
/// Dispatch a message to the provided [`MessageState`]. Returns `true` when the
|
||||
/// state indicates the UI should exit.
|
||||
pub fn handle_message<State>(&mut self, state: &mut State, message: AppMessage) -> bool
|
||||
where
|
||||
State: MessageState,
|
||||
{
|
||||
use AppMessage::*;
|
||||
|
||||
let outcome = match message {
|
||||
KeyPress(_) => AppState::Running,
|
||||
Resize { width, height } => state.handle_resize(width, height),
|
||||
Tick => state.handle_tick(),
|
||||
GenerateStart {
|
||||
request_id,
|
||||
provider_id,
|
||||
request,
|
||||
} => state.start_generation(request_id, &provider_id, &request),
|
||||
GenerateChunk { request_id, chunk } => state.append_chunk(request_id, &chunk),
|
||||
GenerateComplete { request_id } => {
|
||||
self.clear_active_generation(request_id);
|
||||
state.generation_complete(request_id)
|
||||
}
|
||||
GenerateError {
|
||||
request_id,
|
||||
message,
|
||||
} => {
|
||||
self.clear_active_generation_optional(request_id);
|
||||
state.generation_failed(request_id, &message)
|
||||
}
|
||||
ModelsRefresh => state.refresh_model_list(),
|
||||
ModelsUpdated => state.update_model_list(),
|
||||
ProviderStatus {
|
||||
provider_id,
|
||||
status,
|
||||
} => state.update_provider_status(&provider_id, status),
|
||||
};
|
||||
|
||||
matches!(outcome, AppState::Quit)
|
||||
}
|
||||
|
||||
fn clear_active_generation(&mut self, request_id: Uuid) {
|
||||
if self
|
||||
.active_generation
|
||||
.as_ref()
|
||||
.map(|active| active.request_id() == request_id)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.active_generation = None;
|
||||
} else {
|
||||
warn!(
|
||||
"received completion for unknown request {}, ignoring",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_active_generation_optional(&mut self, request_id: Option<Uuid>) {
|
||||
match request_id {
|
||||
Some(id) => self.clear_active_generation(id),
|
||||
None => {
|
||||
if self.active_generation.is_some() {
|
||||
self.active_generation = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
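A hedged sketch (the type is illustrative, not part of the crate): the smallest useful MessageState, accumulating streamed text and otherwise relying on the default no-op hooks defined above.

struct StreamLog {
    buffer: String,
}

impl MessageState for StreamLog {
    fn append_chunk(&mut self, _request_id: Uuid, chunk: &GenerateChunk) -> AppState {
        if let Some(text) = &chunk.text {
            self.buffer.push_str(text);
        }
        AppState::Running
    }

    fn generation_failed(&mut self, _request_id: Option<Uuid>, message: &str) -> AppState {
        self.buffer.push_str("\n[error] ");
        self.buffer.push_str(message);
        AppState::Running
    }
}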
|
||||
41
crates/owlen-tui/src/app/messages.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crossterm::event::KeyEvent;
|
||||
use owlen_core::provider::{GenerateChunk, GenerateRequest, ProviderStatus};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Messages exchanged between the UI event loop and background workers.
|
||||
#[derive(Debug)]
|
||||
pub enum AppMessage {
|
||||
/// User input event bubbled up from the terminal layer.
|
||||
KeyPress(KeyEvent),
|
||||
/// Terminal resize notification.
|
||||
Resize { width: u16, height: u16 },
|
||||
/// Periodic tick used to drive animations.
|
||||
Tick,
|
||||
/// Initiate a new text generation request.
|
||||
GenerateStart {
|
||||
request_id: Uuid,
|
||||
provider_id: String,
|
||||
request: GenerateRequest,
|
||||
},
|
||||
/// Streamed response chunk from the active generation task.
|
||||
GenerateChunk {
|
||||
request_id: Uuid,
|
||||
chunk: GenerateChunk,
|
||||
},
|
||||
/// Generation finished successfully.
|
||||
GenerateComplete { request_id: Uuid },
|
||||
/// Generation failed or was aborted.
|
||||
GenerateError {
|
||||
request_id: Option<Uuid>,
|
||||
message: String,
|
||||
},
|
||||
/// Trigger a background refresh of available models.
|
||||
ModelsRefresh,
|
||||
/// New model list data is ready.
|
||||
ModelsUpdated,
|
||||
/// Provider health status update.
|
||||
ProviderStatus {
|
||||
provider_id: String,
|
||||
status: ProviderStatus,
|
||||
},
|
||||
}
|
||||
240
crates/owlen-tui/src/app/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
mod generation;
|
||||
mod handler;
|
||||
pub mod mvu;
|
||||
mod worker;
|
||||
|
||||
pub mod messages;
|
||||
pub use worker::background_worker;
|
||||
|
||||
use std::{
|
||||
io,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use crossterm::event::{self, KeyEventKind};
|
||||
use owlen_core::{provider::ProviderManager, state::AppState};
|
||||
use ratatui::{Terminal, backend::CrosstermBackend};
|
||||
use tokio::{
|
||||
sync::mpsc::{self, error::TryRecvError},
|
||||
task::{AbortHandle, JoinHandle, yield_now},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Event, SessionEvent, events};
|
||||
|
||||
pub use handler::MessageState;
|
||||
pub use messages::AppMessage;
|
||||
|
||||
#[async_trait]
|
||||
pub trait UiRuntime: MessageState {
|
||||
async fn handle_ui_event(&mut self, event: Event) -> Result<AppState>;
|
||||
async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()>;
|
||||
async fn process_pending_llm_request(&mut self) -> Result<()>;
|
||||
async fn process_pending_tool_execution(&mut self) -> Result<()>;
|
||||
fn poll_controller_events(&mut self) -> Result<()>;
|
||||
fn advance_loading_animation(&mut self);
|
||||
fn streaming_count(&self) -> usize;
|
||||
}
|
||||
|
||||
/// High-level application state driving the non-blocking TUI.
|
||||
pub struct App {
|
||||
provider_manager: Arc<ProviderManager>,
|
||||
message_tx: mpsc::UnboundedSender<AppMessage>,
|
||||
message_rx: Option<mpsc::UnboundedReceiver<AppMessage>>,
|
||||
active_generation: Option<ActiveGeneration>,
|
||||
}
|
||||
|
||||
impl App {
|
||||
/// Construct a new application instance with an associated message channel.
|
||||
pub fn new(provider_manager: Arc<ProviderManager>) -> Self {
|
||||
let (message_tx, message_rx) = mpsc::unbounded_channel();
|
||||
|
||||
Self {
|
||||
provider_manager,
|
||||
message_tx,
|
||||
message_rx: Some(message_rx),
|
||||
active_generation: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cloneable sender handle for pushing messages into the application loop.
|
||||
pub fn message_sender(&self) -> mpsc::UnboundedSender<AppMessage> {
|
||||
self.message_tx.clone()
|
||||
}
|
||||
|
||||
/// Whether a generation task is currently in flight.
|
||||
pub fn has_active_generation(&self) -> bool {
|
||||
self.active_generation.is_some()
|
||||
}
|
||||
|
||||
/// Abort any in-flight generation task.
|
||||
pub fn abort_active_generation(&mut self) {
|
||||
if let Some(active) = self.active_generation.take() {
|
||||
active.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch the background worker responsible for provider health checks.
|
||||
pub fn spawn_background_worker(&self) -> JoinHandle<()> {
|
||||
let manager = Arc::clone(&self.provider_manager);
|
||||
let sender = self.message_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
worker::background_worker(manager, sender).await;
|
||||
})
|
||||
}
|
||||
|
||||
/// Drive the main UI loop, handling terminal events, background messages, and
|
||||
/// provider status updates without blocking rendering.
|
||||
pub async fn run<State, RenderFn>(
|
||||
&mut self,
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
state: &mut State,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
mut render: RenderFn,
|
||||
) -> Result<AppState>
|
||||
where
|
||||
State: UiRuntime,
|
||||
RenderFn: FnMut(&mut Terminal<CrosstermBackend<io::Stdout>>, &mut State) -> Result<()>,
|
||||
{
|
||||
let mut message_rx = self
|
||||
.message_rx
|
||||
.take()
|
||||
.expect("App::run called without an available message receiver");
|
||||
|
||||
let poll_interval = Duration::from_millis(16);
|
||||
let mut last_frame = Instant::now();
|
||||
let frame_interval = Duration::from_millis(16);
|
||||
|
||||
let mut worker_handle = Some(self.spawn_background_worker());
|
||||
|
||||
let exit_state = AppState::Quit;
|
||||
'main: loop {
|
||||
state.advance_loading_animation();
|
||||
|
||||
state.process_pending_llm_request().await?;
|
||||
state.process_pending_tool_execution().await?;
|
||||
state.poll_controller_events()?;
|
||||
|
||||
loop {
|
||||
match session_rx.try_recv() {
|
||||
Ok(session_event) => {
|
||||
state.handle_session_event(session_event).await?;
|
||||
}
|
||||
Err(TryRecvError::Empty) => break,
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match message_rx.try_recv() {
|
||||
Ok(message) => {
|
||||
if self.handle_message(state, message) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
if last_frame.elapsed() >= frame_interval {
|
||||
render(terminal, state)?;
|
||||
last_frame = Instant::now();
|
||||
}
|
||||
|
||||
if self.handle_message(state, AppMessage::Tick) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
|
||||
match event::poll(poll_interval) {
|
||||
Ok(true) => match event::read() {
|
||||
Ok(raw_event) => {
|
||||
if let Some(ui_event) = events::from_crossterm_event(raw_event) {
|
||||
if let Event::Key(key) = &ui_event {
|
||||
if key.kind == KeyEventKind::Press {
|
||||
let _ = self.message_tx.send(AppMessage::KeyPress(*key));
|
||||
}
|
||||
} else if let Event::Resize(width, height) = &ui_event {
|
||||
let _ = self.message_tx.send(AppMessage::Resize {
|
||||
width: *width,
|
||||
height: *height,
|
||||
});
|
||||
}
|
||||
|
||||
if matches!(state.handle_ui_event(ui_event).await?, AppState::Quit) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
},
|
||||
Ok(false) => {}
|
||||
Err(err) => {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
|
||||
yield_now().await;
|
||||
}
|
||||
|
||||
if let Some(handle) = worker_handle {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
self.message_rx = Some(message_rx);
|
||||
|
||||
Ok(exit_state)
|
||||
}
|
||||
}
|
||||
|
||||
struct ActiveGeneration {
|
||||
request_id: Uuid,
|
||||
#[allow(dead_code)]
|
||||
provider_id: String,
|
||||
abort_handle: AbortHandle,
|
||||
#[allow(dead_code)]
|
||||
join_handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl ActiveGeneration {
|
||||
fn new(request_id: Uuid, provider_id: String, join_handle: JoinHandle<()>) -> Self {
|
||||
let abort_handle = join_handle.abort_handle();
|
||||
Self {
|
||||
request_id,
|
||||
provider_id,
|
||||
abort_handle,
|
||||
join_handle,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self) {
|
||||
self.abort_handle.abort();
|
||||
}
|
||||
|
||||
fn request_id(&self) -> Uuid {
|
||||
self.request_id
|
||||
}
|
||||
}
|
||||
165
crates/owlen-tui/src/app/mvu.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use owlen_core::{consent::ConsentScope, ui::InputMode};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AppModel {
|
||||
pub composer: ComposerModel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComposerModel {
|
||||
pub draft: String,
|
||||
pub pending_submit: bool,
|
||||
pub mode: InputMode,
|
||||
}
|
||||
|
||||
impl Default for ComposerModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
draft: String::new(),
|
||||
pending_submit: false,
|
||||
mode: InputMode::Normal,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AppEvent {
|
||||
Composer(ComposerEvent),
|
||||
ToolPermission {
|
||||
request_id: Uuid,
|
||||
scope: ConsentScope,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ComposerEvent {
|
||||
DraftChanged { content: String },
|
||||
ModeChanged { mode: InputMode },
|
||||
Submit,
|
||||
SubmissionHandled { result: SubmissionOutcome },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SubmissionOutcome {
|
||||
MessageSent,
|
||||
CommandExecuted,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AppEffect {
|
||||
SetStatus(String),
|
||||
RequestSubmit,
|
||||
ResolveToolConsent {
|
||||
request_id: Uuid,
|
||||
scope: ConsentScope,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn update(model: &mut AppModel, event: AppEvent) -> Vec<AppEffect> {
|
||||
match event {
|
||||
AppEvent::Composer(event) => update_composer(&mut model.composer, event),
|
||||
AppEvent::ToolPermission { request_id, scope } => {
|
||||
vec![AppEffect::ResolveToolConsent { request_id, scope }]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_composer(model: &mut ComposerModel, event: ComposerEvent) -> Vec<AppEffect> {
|
||||
match event {
|
||||
ComposerEvent::DraftChanged { content } => {
|
||||
model.draft = content;
|
||||
Vec::new()
|
||||
}
|
||||
ComposerEvent::ModeChanged { mode } => {
|
||||
model.mode = mode;
|
||||
Vec::new()
|
||||
}
|
||||
ComposerEvent::Submit => {
|
||||
if model.draft.trim().is_empty() {
|
||||
return vec![AppEffect::SetStatus(
|
||||
"Cannot send empty message".to_string(),
|
||||
)];
|
||||
}
|
||||
|
||||
model.pending_submit = true;
|
||||
vec![AppEffect::RequestSubmit]
|
||||
}
|
||||
ComposerEvent::SubmissionHandled { result } => {
|
||||
model.pending_submit = false;
|
||||
match result {
|
||||
SubmissionOutcome::MessageSent | SubmissionOutcome::CommandExecuted => {
|
||||
model.draft.clear();
|
||||
if model.mode == InputMode::Editing {
|
||||
model.mode = InputMode::Normal;
|
||||
}
|
||||
}
|
||||
SubmissionOutcome::Failed => {}
|
||||
}
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn submit_with_empty_draft_sets_error() {
|
||||
let mut model = AppModel::default();
|
||||
let effects = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
|
||||
assert!(!model.composer.pending_submit);
|
||||
assert_eq!(effects.len(), 1);
|
||||
match &effects[0] {
|
||||
AppEffect::SetStatus(message) => {
|
||||
assert!(message.contains("Cannot send empty message"));
|
||||
}
|
||||
other => panic!("unexpected effect: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn submit_with_content_requests_processing() {
|
||||
let mut model = AppModel::default();
|
||||
let _ = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::DraftChanged {
|
||||
content: "hello world".into(),
|
||||
}),
|
||||
);
|
||||
|
||||
let effects = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
|
||||
assert!(model.composer.pending_submit);
|
||||
assert_eq!(effects.len(), 1);
|
||||
assert!(matches!(effects[0], AppEffect::RequestSubmit));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn submission_success_clears_draft_and_mode() {
|
||||
let mut model = AppModel::default();
|
||||
let _ = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::DraftChanged {
|
||||
content: "hello world".into(),
|
||||
}),
|
||||
);
|
||||
let _ = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
assert!(model.composer.pending_submit);
|
||||
|
||||
let effects = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::SubmissionHandled {
|
||||
result: SubmissionOutcome::MessageSent,
|
||||
}),
|
||||
);
|
||||
|
||||
assert!(effects.is_empty());
|
||||
assert!(!model.composer.pending_submit);
|
||||
assert!(model.composer.draft.is_empty());
|
||||
assert_eq!(model.composer.mode, InputMode::Normal);
|
||||
}
|
||||
}
|
||||
52
crates/owlen-tui/src/app/worker.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::{sync::mpsc, time};
|
||||
|
||||
use owlen_core::provider::ProviderManager;
|
||||
|
||||
use super::AppMessage;
|
||||
|
||||
const HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Periodically refresh provider health and emit status updates into the app's
|
||||
/// message channel. Exits automatically once the receiver side of the channel
|
||||
/// is dropped.
|
||||
pub async fn background_worker(
|
||||
provider_manager: Arc<ProviderManager>,
|
||||
message_tx: mpsc::UnboundedSender<AppMessage>,
|
||||
) {
|
||||
let mut interval = time::interval(HEALTH_CHECK_INTERVAL);
|
||||
let mut last_statuses = provider_manager.provider_statuses().await;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if message_tx.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
let statuses = provider_manager.refresh_health().await;
|
||||
|
||||
for (provider_id, status) in statuses {
|
||||
let changed = match last_statuses.get(&provider_id) {
|
||||
Some(previous) => previous != &status,
|
||||
None => true,
|
||||
};
|
||||
|
||||
last_statuses.insert(provider_id.clone(), status);
|
||||
|
||||
if changed
|
||||
&& message_tx
|
||||
.send(AppMessage::ProviderStatus {
|
||||
provider_id,
|
||||
status,
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
// Receiver dropped; terminate worker.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -1,5 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use owlen_core::session::SessionController;
|
||||
use owlen_core::session::{ControllerEvent, SessionController};
|
||||
use owlen_core::ui::{AppState, InputMode};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -16,11 +16,12 @@ pub struct CodeApp {
|
||||
impl CodeApp {
|
||||
pub async fn new(
|
||||
mut controller: SessionController,
|
||||
controller_event_rx: mpsc::UnboundedReceiver<ControllerEvent>,
|
||||
) -> Result<(Self, mpsc::UnboundedReceiver<SessionEvent>)> {
|
||||
controller
|
||||
.conversation_mut()
|
||||
.push_system_message(DEFAULT_SYSTEM_PROMPT.to_string());
|
||||
let (inner, rx) = ChatApp::new(controller).await?;
|
||||
let (inner, rx) = ChatApp::new(controller, controller_event_rx).await?;
|
||||
Ok((Self { inner }, rx))
|
||||
}
|
||||
|
||||
@@ -28,8 +29,8 @@ impl CodeApp {
|
||||
self.inner.handle_event(event).await
|
||||
}
|
||||
|
||||
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event)
|
||||
pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event).await
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> InputMode {
|
||||
|
||||
349
crates/owlen-tui/src/commands/mod.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
pub mod registry;
|
||||
pub use registry::{AppCommand, CommandRegistry};
|
||||
|
||||
// Command catalog and lookup utilities for the command palette.
|
||||
|
||||
/// Metadata describing a single command keyword.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CommandSpec {
|
||||
pub keyword: &'static str,
|
||||
pub description: &'static str,
|
||||
}
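A hedged helper sketch (not an existing function in the crate): matching palette input against the COMMANDS table defined below by exact keyword.

pub fn find_command(keyword: &str) -> Option<&'static CommandSpec> {
    COMMANDS.iter().find(|spec| spec.keyword == keyword)
}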
|
||||
|
||||
const COMMANDS: &[CommandSpec] = &[
|
||||
CommandSpec {
|
||||
keyword: "quit",
|
||||
description: "Exit the application",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "q",
|
||||
description: "Close the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "w",
|
||||
description: "Save the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "write",
|
||||
description: "Alias for w",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "clear",
|
||||
description: "Clear the conversation",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "c",
|
||||
description: "Alias for clear",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "save",
|
||||
description: "Alias for w",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "wq",
|
||||
description: "Save and close the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "x",
|
||||
description: "Alias for wq",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "load",
|
||||
description: "Load a saved conversation",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "o",
|
||||
description: "Alias for load",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "open",
|
||||
description: "Open a file in the code view",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "create",
|
||||
description: "Create a file (creates missing directories)",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "close",
|
||||
description: "Close the active code view",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "mode",
|
||||
description: "Switch operating mode (chat/code)",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "code",
|
||||
description: "Switch to code mode",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "chat",
|
||||
description: "Switch to chat mode",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "tools",
|
||||
description: "List available tools in current mode",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "sessions",
|
||||
description: "List saved sessions",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "session save",
|
||||
description: "Save the current conversation",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "help",
|
||||
description: "Open the help overlay",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "h",
|
||||
description: "Alias for help",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "model",
|
||||
description: "Select a model",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "provider",
|
||||
description: "Switch provider or set its mode",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud setup",
|
||||
description: "Configure Ollama Cloud credentials",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud status",
|
||||
description: "Check Ollama Cloud connectivity",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud models",
|
||||
description: "List models available in Ollama Cloud",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud logout",
|
||||
description: "Remove stored Ollama Cloud credentials",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "model info",
|
||||
description: "Show detailed information for a model",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "model refresh",
|
||||
description: "Refresh cached model information",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "model details",
|
||||
description: "Show details for the active model",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "m",
|
||||
description: "Alias for model",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models info",
|
||||
description: "Prefetch detailed information for all models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --local",
|
||||
description: "Open model picker focused on local models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --cloud",
|
||||
description: "Open model picker focused on cloud models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --available",
|
||||
description: "Open model picker showing available models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "new",
|
||||
description: "Start a new conversation",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "n",
|
||||
description: "Alias for new",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "theme",
|
||||
description: "Switch theme",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "themes",
|
||||
description: "List available themes",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "tutorial",
|
||||
description: "Show keybinding tutorial",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "reload",
|
||||
description: "Reload configuration and themes",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "markdown",
|
||||
description: "Toggle markdown rendering",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "e",
|
||||
description: "Edit a file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "edit",
|
||||
description: "Alias for edit",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "ls",
|
||||
description: "List directory contents",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "privacy-enable",
|
||||
description: "Enable a privacy-sensitive tool",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "privacy-disable",
|
||||
description: "Disable a privacy-sensitive tool",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "privacy-clear",
|
||||
description: "Clear stored secure data",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent",
|
||||
description: "Enable agent mode for autonomous task execution",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "stop-agent",
|
||||
description: "Stop the running agent",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent status",
|
||||
description: "Show current agent status",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent start",
|
||||
description: "Arm the agent for the next request",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent stop",
|
||||
description: "Stop the running agent",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "layout save",
|
||||
description: "Persist the current pane layout",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "layout load",
|
||||
description: "Restore the last saved pane layout",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "files",
|
||||
description: "Toggle the files panel",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "explorer",
|
||||
description: "Alias for files",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "debug log",
|
||||
description: "Toggle the debug log panel",
|
||||
},
|
||||
];
|
||||
|
||||
/// Return the static catalog of commands.
|
||||
pub fn all() -> &'static [CommandSpec] {
|
||||
COMMANDS
|
||||
}
|
||||
|
||||
/// Return the default suggestion list (all command keywords).
|
||||
pub fn default_suggestions() -> Vec<CommandSpec> {
|
||||
COMMANDS.to_vec()
|
||||
}
|
||||
|
||||
/// Generate keyword suggestions for the given input.
|
||||
pub fn suggestions(input: &str) -> Vec<CommandSpec> {
|
||||
let trimmed = input.trim();
|
||||
if trimmed.is_empty() {
|
||||
return default_suggestions();
|
||||
}
|
||||
|
||||
let mut matches: Vec<(usize, usize, CommandSpec)> = COMMANDS
|
||||
.iter()
|
||||
.filter_map(|spec| {
|
||||
match_score(spec.keyword, trimmed).map(|score| (score.0, score.1, *spec))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
return default_suggestions();
|
||||
}
|
||||
|
||||
matches.sort_by(|a, b| {
|
||||
a.0.cmp(&b.0)
|
||||
.then(a.1.cmp(&b.1))
|
||||
.then(a.2.keyword.cmp(b.2.keyword))
|
||||
});
|
||||
|
||||
matches.into_iter().map(|(_, _, spec)| spec).collect()
|
||||
}
|
||||
|
||||
pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
|
||||
let query = query.trim();
|
||||
if query.is_empty() {
|
||||
return Some((usize::MAX, candidate.len()));
|
||||
}
|
||||
|
||||
let candidate_normalized = candidate.trim().to_lowercase();
|
||||
if candidate_normalized.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let query_normalized = query.to_lowercase();
|
||||
|
||||
if candidate_normalized == query_normalized {
|
||||
Some((0, candidate.len()))
|
||||
} else if candidate_normalized.starts_with(&query_normalized) {
|
||||
Some((1, 0))
|
||||
} else if let Some(pos) = candidate_normalized.find(&query_normalized) {
|
||||
Some((2, pos))
|
||||
} else if is_subsequence(&candidate_normalized, &query_normalized) {
|
||||
Some((3, candidate.len()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn suggestions_prioritize_agent_start() {
|
||||
let results = suggestions("agent st");
|
||||
assert!(!results.is_empty());
|
||||
assert_eq!(results[0].keyword, "agent start");
|
||||
assert!(results.iter().any(|spec| spec.keyword == "agent stop"));
|
||||
}
|
||||
}
|
||||
|
||||
fn is_subsequence(text: &str, pattern: &str) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut pattern_chars = pattern.chars();
|
||||
let mut current = match pattern_chars.next() {
|
||||
Some(ch) => ch,
|
||||
None => return true,
|
||||
};
|
||||
|
||||
for ch in text.chars() {
|
||||
if ch == current {
|
||||
match pattern_chars.next() {
|
||||
Some(next_ch) => current = next_ch,
|
||||
None => return true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
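For reference, match_score ranks candidates in four tiers: 0 for an exact match, 1 for a prefix match, 2 for a substring match (tie-broken by match position), and 3 for a subsequence match; suggestions then sorts by tier, secondary score, and keyword. A small test in the spirit of the existing one could pin the tiers down; this is a sketch that would sit inside the mod tests block above, not part of the commit.

    #[test]
    fn match_score_orders_by_tier() {
        // Exact match beats prefix, which beats substring, which beats subsequence.
        assert_eq!(match_score("model", "model"), Some((0, 5)));
        assert_eq!(match_score("model info", "model"), Some((1, 0)));
        assert_eq!(match_score("stop-agent", "agent"), Some((2, 5)));
        assert_eq!(match_score("layout save", "los"), Some((3, 11)));
        assert_eq!(match_score("quit", "zzz"), None);

        // Equal prefix scores fall back to keyword order, so the shorter
        // "mode" sorts ahead of "model" and the other model commands.
        assert_eq!(suggestions("mod")[0].keyword, "mode");
    }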
107
crates/owlen-tui/src/commands/registry.rs
Normal file
@@ -0,0 +1,107 @@
use std::collections::HashMap;

use owlen_core::ui::FocusedPanel;

use crate::widgets::model_picker::FilterMode;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AppCommand {
    OpenModelPicker(Option<FilterMode>),
    OpenCommandPalette,
    CycleFocusForward,
    CycleFocusBackward,
    FocusPanel(FocusedPanel),
    ComposerSubmit,
    EnterCommandMode,
    ToggleDebugLog,
}

#[derive(Debug)]
pub struct CommandRegistry {
    commands: HashMap<String, AppCommand>,
}

impl CommandRegistry {
    pub fn new() -> Self {
        let mut commands = HashMap::new();

        commands.insert(
            "model.open_all".to_string(),
            AppCommand::OpenModelPicker(None),
        );
        commands.insert(
            "model.open_local".to_string(),
            AppCommand::OpenModelPicker(Some(FilterMode::LocalOnly)),
        );
        commands.insert(
            "model.open_cloud".to_string(),
            AppCommand::OpenModelPicker(Some(FilterMode::CloudOnly)),
        );
        commands.insert(
            "model.open_available".to_string(),
            AppCommand::OpenModelPicker(Some(FilterMode::Available)),
        );
        commands.insert("palette.open".to_string(), AppCommand::OpenCommandPalette);
        commands.insert("focus.next".to_string(), AppCommand::CycleFocusForward);
        commands.insert("focus.prev".to_string(), AppCommand::CycleFocusBackward);
        commands.insert(
            "focus.files".to_string(),
            AppCommand::FocusPanel(FocusedPanel::Files),
        );
        commands.insert(
            "focus.chat".to_string(),
            AppCommand::FocusPanel(FocusedPanel::Chat),
        );
        commands.insert(
            "focus.thinking".to_string(),
            AppCommand::FocusPanel(FocusedPanel::Thinking),
        );
        commands.insert(
            "focus.input".to_string(),
            AppCommand::FocusPanel(FocusedPanel::Input),
        );
        commands.insert(
            "focus.code".to_string(),
            AppCommand::FocusPanel(FocusedPanel::Code),
        );
        commands.insert("composer.submit".to_string(), AppCommand::ComposerSubmit);
        commands.insert("mode.command".to_string(), AppCommand::EnterCommandMode);
        commands.insert("debug.toggle".to_string(), AppCommand::ToggleDebugLog);

        Self { commands }
    }

    pub fn resolve(&self, command: &str) -> Option<AppCommand> {
        self.commands.get(command).copied()
    }
}

impl Default for CommandRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_known_command() {
        let registry = CommandRegistry::new();
        assert_eq!(
            registry.resolve("focus.next"),
            Some(AppCommand::CycleFocusForward)
        );
        assert_eq!(
            registry.resolve("model.open_cloud"),
            Some(AppCommand::OpenModelPicker(Some(FilterMode::CloudOnly)))
        );
    }

    #[test]
    fn resolve_unknown_command() {
        let registry = CommandRegistry::new();
        assert_eq!(registry.resolve("does.not.exist"), None);
    }
}
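CommandRegistry maps stable string identifiers such as "focus.next" to AppCommand values, which makes it a natural lookup table for user-configurable keybindings. Below is a hedged sketch of that kind of wiring; the chord strings, the bindings map, and the dispatch helper are illustrative assumptions, not the crate's actual keybinding schema.

use std::collections::HashMap;

// Assumes this snippet lives inside the owlen-tui crate, where the re-export above applies.
use crate::commands::{AppCommand, CommandRegistry};

// Hypothetical keybinding table: key chord -> command identifier (illustrative only).
fn dispatch(
    bindings: &HashMap<&str, &str>,
    registry: &CommandRegistry,
    chord: &str,
) -> Option<AppCommand> {
    bindings
        .get(chord)
        .and_then(|identifier| registry.resolve(identifier))
}

fn demo() {
    let registry = CommandRegistry::new();
    let mut bindings = HashMap::new();
    bindings.insert("tab", "focus.next");
    bindings.insert("ctrl+p", "palette.open");

    assert_eq!(dispatch(&bindings, &registry, "tab"), Some(AppCommand::CycleFocusForward));
    assert_eq!(dispatch(&bindings, &registry, "ctrl+q"), None); // unbound chord
}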
@@ -1,6 +1,6 @@
pub use owlen_core::config::{
    default_config_path, ensure_ollama_config, ensure_provider_config, session_timeout, Config,
    GeneralSettings, InputSettings, StorageSettings, UiSettings, DEFAULT_CONFIG_PATH,
    Config, DEFAULT_CONFIG_PATH, GeneralSettings, IconMode, InputSettings, StorageSettings,
    UiSettings, default_config_path, ensure_ollama_config, ensure_provider_config, session_timeout,
};

/// Attempt to load configuration from default location
@@ -17,6 +17,22 @@ pub enum Event {
    Tick,
}

/// Convert a raw crossterm event into an application event.
pub fn from_crossterm_event(raw: crossterm::event::Event) -> Option<Event> {
    match raw {
        crossterm::event::Event::Key(key) => {
            if key.kind == KeyEventKind::Press {
                Some(Event::Key(key))
            } else {
                None
            }
        }
        crossterm::event::Event::Resize(width, height) => Some(Event::Resize(width, height)),
        crossterm::event::Event::Paste(text) => Some(Event::Paste(text)),
        _ => None,
    }
}

/// Event handler that captures terminal events and sends them to the application
pub struct EventHandler {
    sender: mpsc::UnboundedSender<Event>,
@@ -52,20 +68,8 @@ impl EventHandler {
            if event::poll(timeout).unwrap_or(false) {
                match event::read() {
                    Ok(event) => {
                        match event {
                            crossterm::event::Event::Key(key) => {
                                // Only handle KeyEventKind::Press to avoid duplicate events
                                if key.kind == KeyEventKind::Press {
                                    let _ = self.sender.send(Event::Key(key));
                                }
                            }
                            crossterm::event::Event::Resize(width, height) => {
                                let _ = self.sender.send(Event::Resize(width, height));
                            }
                            crossterm::event::Event::Paste(text) => {
                                let _ = self.sender.send(Event::Paste(text));
                            }
                            _ => {}
                        if let Some(converted) = from_crossterm_event(event) {
                            let _ = self.sender.send(converted);
                        }
                    }
                    Err(_) => {
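Extracting from_crossterm_event makes the press-only filtering testable without a live terminal: key releases and repeats map to None, so a terminal that reports both press and release for one keystroke still yields a single application event. A sketch of such a test, assuming the workspace's crossterm version exposes KeyEvent::new_with_kind and that the test sits in the same module as the function:

#[test]
fn key_releases_are_filtered() {
    use crossterm::event::{Event as CtEvent, KeyCode, KeyEvent, KeyEventKind, KeyModifiers};

    // KeyEvent::new defaults to KeyEventKind::Press.
    let press = CtEvent::Key(KeyEvent::new(KeyCode::Char('q'), KeyModifiers::NONE));
    let release = CtEvent::Key(KeyEvent::new_with_kind(
        KeyCode::Char('q'),
        KeyModifiers::NONE,
        KeyEventKind::Release,
    ));

    assert!(from_crossterm_event(press).is_some());
    assert!(from_crossterm_event(release).is_none());

    // Resize and paste events pass straight through.
    assert!(from_crossterm_event(CtEvent::Resize(80, 24)).is_some());
}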
Some files were not shown because too many files have changed in this diff